/** ExaTN:: Tensor Runtime: Task-based execution layer for tensor operations
REVISION: 2022/01/17

Copyright (C) 2018-2022 Dmitry Lyakh, Tiffany Mintz, Alex McCaskey
Copyright (C) 2018-2022 Oak Ridge National Laboratory (UT-Battelle)
**/

#include "tensor_runtime.hpp"
#include "exatn_service.hpp"

#include "talshxx.hpp"

#ifdef MPI_ENABLED
#include "mpi.h"
#endif

#include <vector>
#include <iostream>

#include "errors.hpp"

//#define DEBUG
namespace exatn {
namespace runtime {
#ifdef MPI_ENABLED
static MPI_Comm global_mpi_comm; //MPI communicator used to initialize the tensor runtime

/** Constructs the tensor runtime over a given MPI communicator:
    Caches the communicator in global_mpi_comm, queries the intra-communicator
    size/rank plus the MPI_COMM_WORLD rank, instantiates the requested DAG
    (graph) executor service, and launches the dedicated DAG execution thread.
    The node executor is instantiated later by the execution thread itself
    (see executionThreadWorkflow). The executing_/scope_set_/alive_ flags all
    start FALSE; launchExecutionThread() flips alive_ to TRUE. **/
TensorRuntime::TensorRuntime(const MPICommProxy & communicator,
                             const ParamConf & parameters,
                             const std::string & graph_executor_name,
                             const std::string & node_executor_name):
 parameters_(parameters),
 graph_executor_name_(graph_executor_name), node_executor_name_(node_executor_name),
 current_dag_(nullptr), logging_(0), executing_(false), scope_set_(false), alive_(false)
{
#ifdef DEBUG
  const bool debugging = true;
#else
  const bool debugging = false;
#endif
  //Cache the communicator and query process counts/ranks (asserts on MPI failure):
  global_mpi_comm = *(communicator.get<MPI_Comm>());
  int mpi_error = MPI_Comm_size(global_mpi_comm,&num_processes_); assert(mpi_error == MPI_SUCCESS);
  mpi_error = MPI_Comm_rank(global_mpi_comm,&process_rank_); assert(mpi_error == MPI_SUCCESS);
  mpi_error = MPI_Comm_rank(MPI_COMM_WORLD,&global_process_rank_); assert(mpi_error == MPI_SUCCESS);
  //Instantiate the DAG executor service by name:
  graph_executor_ = exatn::getService<TensorGraphExecutor>(graph_executor_name_);
  if(debugging) std::cout << "#DEBUG(exatn::runtime::TensorRuntime)[MAIN_THREAD:Process " << process_rank_
                          << "]: DAG executor set to " << graph_executor_name_ << " + "
                          << node_executor_name_ << std::endl << std::flush;
  //Spawn the background execution thread (sets alive_ to TRUE):
  launchExecutionThread();
}
#else

/** Constructs the tensor runtime without MPI (single-process build):
    Sets the process count to 1 and both ranks to 0, instantiates the
    requested DAG (graph) executor service, and launches the dedicated
    DAG execution thread. The node executor is instantiated later by the
    execution thread itself (see executionThreadWorkflow). **/
TensorRuntime::TensorRuntime(const ParamConf & parameters,
                             const std::string & graph_executor_name,
                             const std::string & node_executor_name):
 parameters_(parameters),
 graph_executor_name_(graph_executor_name), node_executor_name_(node_executor_name),
 current_dag_(nullptr), logging_(0), executing_(false), scope_set_(false), alive_(false)
{
#ifdef DEBUG
  const bool debugging = true;
#else
  const bool debugging = false;
#endif
  //Single process, rank 0 everywhere:
  num_processes_ = 1; process_rank_ = 0; global_process_rank_ = 0;
  //Instantiate the DAG executor service by name:
  graph_executor_ = exatn::getService<TensorGraphExecutor>(graph_executor_name_);
  if(debugging) std::cout << "#DEBUG(exatn::runtime::TensorRuntime)[MAIN_THREAD]: DAG executor set to "
                          << graph_executor_name_ << " + " << node_executor_name_ << std::endl << std::flush;
  //Spawn the background execution thread (sets alive_ to TRUE):
  launchExecutionThread();
}
#endif
73
74
75
76
77


TensorRuntime::~TensorRuntime()
{
  if(alive_.load()){
78
    alive_.store(false); //signal for the execution thread to finish
79
    //std::cout << "#DEBUG(exatn::runtime::TensorRuntime)[MAIN_THREAD]: Waiting Execution Thread ... " << std::flush;
80
    exec_thread_.join(); //wait until the execution thread has finished
81
    //std::cout << "Joined" << std::endl << std::flush;
82
83
84
85
86
87
  }
}


void TensorRuntime::launchExecutionThread()
{
88
89
  if(!(alive_.load())){
    alive_.store(true);
90
    //std::cout << "#DEBUG(exatn::runtime::TensorRuntime)[MAIN_THREAD]: Launching Execution Thread ... " << std::flush;
91
    exec_thread_ = std::thread(&TensorRuntime::executionThreadWorkflow,this);
92
    //std::cout << "Done" << std::endl << std::flush;
93
94
  }
  return; //only the main thread returns to the client
95
96
97
98
99
}


/** Body of the dedicated execution thread: Instantiates the DAG node executor,
    then loops while alive_ is TRUE (alive_ is controlled by the main thread).
    While executing_ is TRUE (set by the main thread on submit/sync), it drives
    DAG execution, serves synchronous tensor-data requests, and also executes
    the standalone tensor network queue; executing_ is cleared by this thread
    only once the current DAG has no unexecuted nodes left. On shutdown the
    node executor is reset to null before the thread exits. **/
void TensorRuntime::executionThreadWorkflow()
{
  //Instantiate the DAG node executor inside the execution thread:
  graph_executor_->resetNodeExecutor(exatn::getService<TensorNodeExecutor>(node_executor_name_),
                                     parameters_,num_processes_,process_rank_,global_process_rank_);
  //std::cout << "#DEBUG(exatn::runtime::TensorRuntime)[EXEC_THREAD]: DAG node executor set to "
            //<< node_executor_name_ << std::endl << std::flush;
  while(alive_.load()){ //alive_ is set by the main thread
    while(executing_.load()){ //executing_ is set to TRUE by the main thread when new operations and syncs are submitted
      graph_executor_->execute(*current_dag_);
      processTensorDataRequests(); //process all outstanding client requests for tensor data (synchronous)
      if(current_dag_->hasUnexecutedNodes()){
        executing_.store(true); //reaffirm that DAG is still executing
      }else{
        //DAG drained: execute the standalone tensor network queue as well
        graph_executor_->execute(tensor_network_queue_);
        if(!(current_dag_->hasUnexecutedNodes())) executing_.store(false); //executing_ is set to FALSE by the execution thread
      }
    }
    processTensorDataRequests(); //process all outstanding client requests for tensor data (synchronous)
  }
  //Shutdown: drop the node executor before ending the thread:
  graph_executor_->resetNodeExecutor(std::shared_ptr<TensorNodeExecutor>(nullptr),
                                     parameters_,num_processes_,process_rank_,global_process_rank_);
  //std::cout << "#DEBUG(exatn::runtime::TensorRuntime)[EXEC_THREAD]: DAG node executor reset. End of life."
            //<< std::endl << std::flush;
  return; //end of execution thread life
}


125
126
127
128
129
130
void TensorRuntime::processTensorDataRequests()
{
  lockDataReqQ();
  for(auto & req: data_req_queue_){
    req.slice_promise_.set_value(graph_executor_->getLocalTensor(*(req.tensor_),req.slice_specs_));
  }
131
  data_req_queue_.clear();
132
133
134
135
136
  unlockDataReqQ();
  return;
}


137
138
void TensorRuntime::resetLoggingLevel(int level)
{
139
140
141
142
  while(!graph_executor_);
  graph_executor_->resetLoggingLevel(level);
  logging_ = level;
  return;
143
144
145
}


146
147
void TensorRuntime::resetSerialization(bool serialize, bool validation_trace)
{
148
149
  while(!graph_executor_);
  return graph_executor_->resetSerialization(serialize,validation_trace);
150
151
152
}


153
154
void TensorRuntime::activateDryRun(bool dry_run)
{
155
156
  while(!graph_executor_);
  return graph_executor_->activateDryRun(dry_run);
157
158
159
}


160
161
void TensorRuntime::activateFastMath()
{
162
163
  while(!graph_executor_);
  return graph_executor_->activateFastMath();
164
165
166
}


167
std::size_t TensorRuntime::getMemoryBufferSize() const
168
{
169
170
  while(!graph_executor_);
  return graph_executor_->getMemoryBufferSize();
171
172
173
}


174
175
176
177
178
179
180
std::size_t TensorRuntime::getMemoryUsage(std::size_t * free_mem) const
{
  while(!graph_executor_);
  return graph_executor_->getMemoryUsage(free_mem);
}


181
182
double TensorRuntime::getTotalFlopCount() const
{
183
184
  while(!graph_executor_);
  return graph_executor_->getTotalFlopCount();
185
186
187
}


/** Opens a new named scope: Closes the currently open scope (if any; opening
    a scope with the same name as the current one is a programming error),
    creates a fresh DAG under the given name via the "boost-digraph" service,
    registers it in the dags_ map, makes it the current DAG, and marks the
    scope as set. **/
void TensorRuntime::openScope(const std::string & scope_name) {
  assert(!scope_name.empty());
  // Complete the current scope first:
  if(currentScopeIsSet()){
    assert(scope_name != current_scope_); //re-opening the current scope is not allowed
    closeScope();
  }
  // Create new DAG with name given by scope name and store it in the dags map:
  auto new_dag = dags_.emplace(std::make_pair(
                                scope_name,
                                exatn::getService<TensorGraph>("boost-digraph")
                               )
                              );
  assert(new_dag.second); // make sure there was no other scope with the same name
  current_dag_ = (new_dag.first)->second; //storing a shared pointer to the DAG
  current_scope_ = scope_name; // change the name of the current scope
  scope_set_.store(true);
  return;
}

208
209

void TensorRuntime::pauseScope() {
210
  graph_executor_->stopExecution(); //execution thread will pause and reset executing_ to FALSE
211
212
213
214
  return;
}


/** Resumes a previously opened scope by name: Pauses the current scope (if
    any), spin-waits until the execution thread has stopped executing the
    previous DAG, switches the current DAG/scope to the named one, and sets
    executing_ to trigger DAG execution by the execution thread.
    NOTE(review): scope_name is looked up with operator[] — presumably it must
    name an existing scope; a missing name would insert an empty entry. **/
void TensorRuntime::resumeScope(const std::string & scope_name) {
  assert(!scope_name.empty());
  // Pause the current scope first:
  if(currentScopeIsSet()) pauseScope();
  while(executing_.load()){}; //wait until the execution thread stops executing previous DAG
  current_dag_ = dags_[scope_name]; //storing a shared pointer to the DAG
  current_scope_ = scope_name; // change the name of the current scope
  scope_set_.store(true);
  executing_.store(true); //will trigger DAG execution by the execution thread
  return;
}

/** Closes the current scope (no-op if none is set): Synchronizes on the full
    DAG, spin-waits until the execution thread has finished it, then clears the
    scope bookkeeping (flag, name, DAG pointer) and erases the DAG from the
    dags_ map. Exactly one DAG must have been registered under the scope name. **/
void TensorRuntime::closeScope() {
  if(currentScopeIsSet()){
    sync();
    while(executing_.load()){}; //wait until the execution thread has completed execution of the current DAG
    const std::string scope_name = current_scope_;
    scope_set_.store(false); //clear the flag before tearing down the scope state
    current_scope_ = "";
    current_dag_.reset();
    auto num_deleted = dags_.erase(scope_name);
    assert(num_deleted == 1); //the scope's DAG must have existed in the map
  }
  return;
}
241

242

243
VertexIdType TensorRuntime::submit(std::shared_ptr<TensorOperation> op) {
244
245
  assert(currentScopeIsSet());
  auto node_id = current_dag_->addOperation(op);
246
  op->setId(node_id);
247
  //current_dag_->printIt(); //debug
248
249
  executing_.store(true); //signal to the execution thread to execute the DAG
  return node_id;
250
251
}

/** Synchronizes on a specific tensor operation: Returns TRUE if the DAG node
    carrying the operation has been executed. If wait==true, spin-waits until
    it completes, repeatedly re-asserting executing_ so the execution thread
    keeps (or resumes) draining the DAG. Requires an open scope. **/
bool TensorRuntime::sync(TensorOperation & op, bool wait) {
  assert(currentScopeIsSet());
  executing_.store(true); //reactivate the execution thread to execute the DAG in case it was not active
  auto opid = op.getId();
  bool completed = current_dag_->nodeExecuted(opid);
  while(wait && (!completed)){
    executing_.store(true); //reactivate the execution thread to execute the DAG in case it was not active
    completed = current_dag_->nodeExecuted(opid);
  }
  return completed;
}

/** Synchronizes on a specific tensor: Returns TRUE once the current DAG has
    no unexecuted nodes updating the given tensor. If wait==true, spin-waits
    until that holds, repeatedly re-asserting executing_ so the execution
    thread keeps (or resumes) draining the DAG. Requires an open scope. **/
bool TensorRuntime::sync(const Tensor & tensor, bool wait) {
  //if(wait) std::cout << "#DEBUG(TensorRuntime::sync)[MAIN_THREAD]: Syncing on tensor " << tensor.getName() << " ... "; //debug
  assert(currentScopeIsSet());
  executing_.store(true); //reactivate the execution thread to execute the DAG in case it was not active
  bool completed = (current_dag_->getTensorUpdateCount(tensor) == 0);
  while(wait && (!completed)){
    executing_.store(true); //reactivate the execution thread to execute the DAG in case it was not active
    completed = (current_dag_->getTensorUpdateCount(tensor) == 0);
  }
  //if(wait) std::cout << "Synced" << std::endl; //debug
  return completed;
}

/** Synchronizes on the entire current DAG: Returns TRUE once the execution
    thread is no longer executing (all DAG nodes done). If wait==true,
    spin-waits until the DAG drains; after a successful waited sync, clears
    the DAG when it has grown beyond MAX_RUNTIME_DAG_SIZE to bound memory
    growth. Requires an open scope. **/
bool TensorRuntime::sync(bool wait) {
  //if(wait) std::cout << "#DEBUG(TensorRuntime::sync)[MAIN_THREAD]: Syncing ... "; //debug
  assert(currentScopeIsSet());
  if(current_dag_->hasUnexecutedNodes()) executing_.store(true); //make sure the execution thread is active
  bool still_working = executing_.load();
  while(wait && still_working){
    if(current_dag_->hasUnexecutedNodes()) executing_.store(true);
    still_working = executing_.load();
  }
  if(wait && (!still_working)){
    //Fully drained: compact an oversized DAG
    if(current_dag_->getNumNodes() > MAX_RUNTIME_DAG_SIZE){
      //std::cout << "Clearing DAG ... "; //debug
      current_dag_->clear();
      //std::cout << "Done; "; //debug
    }
  }
  //if(wait) std::cout << "Synced\n" << std::flush; //debug
  return !still_working;
}


#ifdef CUQUANTUM
/** Submits an entire tensor network for processing as a whole:
    Appends it to the tensor network queue, wakes the execution thread,
    and returns the execution handle used for later synchronization. **/
TensorOpExecHandle TensorRuntime::submit(std::shared_ptr<numerics::TensorNetwork> network,
                                         const MPICommProxy & communicator,
                                         unsigned int num_processes, unsigned int process_rank)
{
  const auto handle = tensor_network_queue_.append(network,communicator,num_processes,process_rank);
  executing_.store(true); //signal to the execution thread to execute the queue
  return handle;
}


/** Synchronizes on a previously submitted whole tensor network:
    Returns TRUE if the network's queue entry is either gone (ExecStat::None)
    or completed. If wait==true, spin-polls the queue status until synced;
    otherwise performs a single status check. The handle must be valid
    (non-zero). **/
bool TensorRuntime::syncNetwork(const TensorOpExecHandle exec_handle, bool wait)
{
  assert(exec_handle != 0);
  executing_.store(true); //reactivate the execution thread in case it was not active
  bool synced = false;
  while(!synced){
    const auto exec_stat = tensor_network_queue_.checkExecStatus(exec_handle);
    synced = (exec_stat == TensorNetworkQueue::ExecStat::None ||
              exec_stat == TensorNetworkQueue::ExecStat::Completed);
    if(!wait) break; //non-blocking mode: report current status only
  };
  return synced;
}
#endif


/** Requests a locally stored slice of a tensor: First blocks until all
    submitted update operations on the tensor have completed, then enqueues
    a data request (promise + slice specification) under the data-request
    queue lock; the execution thread fulfills it in
    processTensorDataRequests(). Returns the future side of the promise. **/
std::future<std::shared_ptr<talsh::Tensor>> TensorRuntime::getLocalTensor(std::shared_ptr<Tensor> tensor,
                                          const std::vector<std::pair<DimOffset,DimExtent>> & slice_spec)
{
  // Complete all submitted update operations on the tensor:
  auto synced = sync(*tensor,true); assert(synced);
  // Create promise-future pair:
  std::promise<std::shared_ptr<talsh::Tensor>> promised_slice;
  auto future_slice = promised_slice.get_future();
  // Schedule data request:
  lockDataReqQ();
  data_req_queue_.emplace_back(std::move(promised_slice),slice_spec,tensor);
  unlockDataReqQ();
  return future_slice;
}
} // namespace runtime
} // namespace exatn