/** ExaTN:: Tensor Runtime: Tensor graph executor: Lazy
REVISION: 2022/01/17

Copyright (C) 2018-2022 Dmitry Lyakh
Copyright (C) 2018-2022 Oak Ridge National Laboratory (UT-Battelle)
**/

8
9
#include "graph_executor_lazy.hpp"

10
11
#include "talshxx.hpp"

12
13
14
15
#ifdef CUQUANTUM
#include "cuquantum_executor.hpp"
#endif

16
17
18
#include <iostream>
#include <iomanip>

19
20
#include "errors.hpp"

21
22
23
#ifndef NDEBUG
#define DEBUG
#endif
24

25
26
27
namespace exatn {
namespace runtime {

28
29
void LazyGraphExecutor::resetNodeExecutor(std::shared_ptr<TensorNodeExecutor> node_executor,
                                          const ParamConf & parameters,
30
                                          unsigned int num_processes,
31
32
33
                                          unsigned int process_rank,
                                          unsigned int global_process_rank)
{
34
  TensorGraphExecutor::resetNodeExecutor(node_executor,parameters,num_processes,process_rank,global_process_rank);
35
#ifdef CUQUANTUM
36
37
38
  if(node_executor){
    cuquantum_executor_ = std::make_shared<CuQuantumExecutor>(
      [this](const numerics::Tensor & tensor, int device_kind, int device_id, std::size_t * size){
39
        void * data_ptr = this->node_executor_->getTensorImage(tensor,device_kind,device_id,size);
40
        return data_ptr;
41
      },
42
      cuquantum_pipe_depth_,
43
44
      num_processes,
      process_rank
45
46
    );
  }
47
#endif
48
  return;
49
50
51
}


52
/** Executes a tensor operation DAG to completion.
    The executor repeatedly (a) issues all dependency-free nodes to the
    node executor, (b) inspects the current node's dependencies (initiating
    prefetch within the prefetch window when dependencies are unresolved),
    (c) tests asynchronously executing nodes for completion, and
    (d) advances to the next idle node within the pipeline window.
    The DAG may grow while being executed; sizes are re-read each pass.
    @param dag Tensor operation DAG to execute (mutated in place). **/
void LazyGraphExecutor::execute(TensorGraph & dag) {

  //Traversal state over the (possibly growing) DAG:
  struct Progress {
    VertexIdType num_nodes; //total number of nodes in the DAG (may grow)
    VertexIdType front;     //the first unexecuted node in the DAG
    VertexIdType current;   //the current node in the DAG
  };

  Progress progress{dag.getNumNodes(),dag.getFrontNode(),0};
  progress.current = progress.front;

  //Advances progress.current to the next idle node within the pipeline
  //window [front, front + pipeline_depth); returns TRUE if a new idle
  //node distinct from the previous current node was found:
  auto find_next_idle_node = [this,&dag,&progress] () {
    const auto prev_node = progress.current;
    progress.front = dag.getFrontNode();
    progress.num_nodes = dag.getNumNodes();
    if(progress.front < progress.num_nodes){
      ++progress.current;
      if(progress.current >= progress.num_nodes){
        //Wrap around to the front of the DAG:
        progress.current = progress.front;
        if(progress.current == prev_node) ++progress.current;
      }else{
        if(progress.current >= (progress.front + this->getPipelineDepth())){
          //Stay within the pipeline window; wrap to the front:
          progress.current = progress.front;
          if(progress.current == prev_node) ++progress.current;
        }
      }
      //Scan forward for the first idle node:
      while(progress.current < progress.num_nodes){
        if(dag.nodeIdle(progress.current)) break;
        ++progress.current;
      }
    }else{ //all DAG nodes have been executed
      progress.current = progress.front; //end-of-DAG
    }
    progress.num_nodes = dag.getNumNodes(); //update DAG size again
    return (progress.current < progress.num_nodes && progress.current != prev_node);
  };

  //Checks whether the current node is idle with all dependencies resolved,
  //registering it as dependency-free if so; otherwise initiates operand
  //prefetch when the node lies within the prefetch window. Returns TRUE
  //if the node is ready for execution:
  auto inspect_node_dependencies = [this,&dag,&progress] () {
    bool ready_for_execution = false;
    if(progress.current < progress.num_nodes){
      auto & dag_node = dag.getNodeProperties(progress.current);
      ready_for_execution = dag_node.isIdle();
      if(ready_for_execution){ //node is idle
        ready_for_execution = ready_for_execution && dag.nodeDependenciesResolved(progress.current);
        if(ready_for_execution){ //all node dependencies resolved (or none)
          auto registered = dag.registerDependencyFreeNode(progress.current);
          if(registered && logging_.load() > 1) logfile_ << "DAG node detected with all dependencies resolved: " << progress.current << std::endl;
        }else{ //node still has unresolved dependencies, try prefetching
          if(progress.current < (progress.front + this->getPrefetchDepth())){
            auto prefetching = this->node_executor_->prefetch(*(dag_node.getOperation()));
            if(logging_.load() != 0 && prefetching){
              logfile_ << "[" << std::fixed << std::setprecision(6) << exatn::Timer::timeInSecHR(getTimeStampStart())
                       << "](LazyGraphExecutor)[EXEC_THREAD]: Initiated prefetch for tensor operation "
                       << progress.current << std::endl;
#ifdef DEBUG
              logfile_.flush();
#endif
            }
          }
        }
      }
    }
    return ready_for_execution;
  };

  //Extracts one dependency-free node and submits it to the node executor.
  //Handles immediate completion, asynchronous deferral, temporary resource
  //shortage (TRY_LATER: node is returned to the dependency-free pool), and
  //fatal submission errors. Returns TRUE if a node was actually issued:
  auto issue_ready_node = [this,&dag,&progress] () {
    if(logging_.load() > 2){
      logfile_ << "DAG current list of dependency free nodes:";
      auto free_nodes = dag.getDependencyFreeNodes();
      for(const auto & node: free_nodes) logfile_ << " " << node;
      logfile_ << std::endl;
    }
    VertexIdType node;
    bool issued = dag.extractDependencyFreeNode(&node);
    if(issued){
      auto & dag_node = dag.getNodeProperties(node);
      auto op = dag_node.getOperation();
      if(logging_.load() != 0){
        logfile_ << "[" << std::fixed << std::setprecision(6) << exatn::Timer::timeInSecHR(getTimeStampStart())
                 << "](LazyGraphExecutor)[EXEC_THREAD]: Submitting tensor operation "
                 << node << ": Opcode = " << static_cast<int>(op->getOpcode());
        if(logging_.load() > 1){
          logfile_ << ": Details:" << std::endl;
          op->printItFile(logfile_);
        }
#ifdef DEBUG
        logfile_.flush();
#endif
      }
      dag.setNodeExecuting(node);
      op->recordStartTime();
      TensorOpExecHandle exec_handle;
      auto error_code = op->accept(*(this->node_executor_),&exec_handle);
      if(logging_.load() != 0) logfile_ << ": Status = " << error_code;
      if(error_code == 0){ //tensor operation submitted for execution successfully
        if(logging_.load() != 0) logfile_ << ": Syncing ... ";
        auto synced = this->node_executor_->sync(exec_handle,&error_code,serialize_.load());
        if(synced){ //tensor operation has completed immediately
          op->recordFinishTime();
          dag.setNodeExecuted(node,error_code);
          if(error_code == 0){
            if(logging_.load() != 0){
              logfile_ << "Success [" << std::fixed << std::setprecision(6)
                       << exatn::Timer::timeInSecHR(getTimeStampStart()) << "]" << std::endl;
              std::size_t free_mem = 0;
              std::size_t used_mem = this->node_executor_->getMemoryUsage(&free_mem);
              logfile_ << "[" << exatn::Timer::timeInSecHR(getTimeStampStart()) << "]"
                       << " Total Flop count = " << getTotalFlopCount()
                       << "; Memory usage = " << used_mem << ", Free = " << free_mem << std::endl;
#ifdef DEBUG
              logfile_.flush();
#endif
            }
            op->dissociateTensorOperands();
            progress.num_nodes = dag.getNumNodes();
            //Advance the DAG front past all contiguously executed nodes:
            auto progressed = dag.progressFrontNode(node);
            if(progressed){
              progress.front = dag.getFrontNode();
              while(progress.front < progress.num_nodes){
                if(!(dag.nodeExecuted(progress.front))) break;
                dag.progressFrontNode(progress.front);
                progress.front = dag.getFrontNode();
              }
            }
            if(progressed && logging_.load() > 1) logfile_ << "DAG front node progressed to "
              << progress.front << " out of total of " << progress.num_nodes << std::endl;
          }else{ //immediate completion error
            if(logging_.load() != 0){
              logfile_ << "Failed: Error " << error_code << " [" << std::fixed << std::setprecision(6)
                       << exatn::Timer::timeInSecHR(getTimeStampStart()) << "]" << std::endl;
              logfile_.flush();
            }
            std::cout << "#ERROR(exatn::TensorRuntime::GraphExecutorLazy): Immediate completion error for tensor operation "
             << node << " with execution handle " << exec_handle << ": Error " << error_code << std::endl << std::flush;
            assert(false); //`Do I need to handle this case gracefully?
          }
        }else{ //tensor operation is still executing asynchronously
          dag.registerExecutingNode(node,exec_handle);
          if(logging_.load() != 0) logfile_ << "Deferred" << std::endl;
        }
      }else{ //tensor operation not submitted due to either temporary resource shortage or fatal error
        auto discarded = this->node_executor_->discard(exec_handle);
        dag.setNodeIdle(node); //return the node to the idle state
        auto registered = dag.registerDependencyFreeNode(node); assert(registered);
        issued = false;
        if(error_code == TRY_LATER){ //temporary shortage of resources
          if(logging_.load() != 0) logfile_ << ": Postponed" << std::endl;
        }else{ //fatal error
          if(logging_.load() != 0) logfile_.flush();
          std::cout << "#ERROR(exatn::TensorRuntime::GraphExecutorLazy): Failed to submit tensor operation "
           << node << " with execution handle " << exec_handle << ": Error " << error_code << std::endl << std::flush;
          assert(false); //`Do I need to handle this case gracefully?
        }
      }
    }
    return issued;
  };

  //Polls all currently executing nodes for completion, finalizing the
  //completed ones and advancing the DAG front when possible:
  auto test_nodes_for_completion = [this,&dag,&progress] () {
    auto executing_nodes = dag.executingNodesBegin();
    while(executing_nodes != dag.executingNodesEnd()){
      int error_code;
      auto exec_handle = executing_nodes->second;
      auto synced = this->node_executor_->sync(exec_handle,&error_code,serialize_.load());
      if(synced){ //tensor operation has completed
        VertexIdType node;
        executing_nodes = dag.extractExecutingNode(executing_nodes,&node);
        auto & dag_node = dag.getNodeProperties(node);
        auto op = dag_node.getOperation();
        op->recordFinishTime();
        dag.setNodeExecuted(node,error_code);
        if(error_code == 0){
          if(logging_.load() != 0){
            logfile_ << "[" << std::fixed << std::setprecision(6) << exatn::Timer::timeInSecHR(getTimeStampStart())
                     << "](LazyGraphExecutor)[EXEC_THREAD]: Synced tensor operation "
                     << node << ": Opcode = " << static_cast<int>(op->getOpcode()) << std::endl;
            std::size_t free_mem = 0;
            std::size_t used_mem = this->node_executor_->getMemoryUsage(&free_mem);
            logfile_ << "[" << exatn::Timer::timeInSecHR(getTimeStampStart()) << "]"
                     << " Total Flop count = " << getTotalFlopCount()
                     << "; Memory usage = " << used_mem << ", Free = " << free_mem << std::endl;
#ifdef DEBUG
            logfile_.flush();
#endif
          }
          op->dissociateTensorOperands();
          progress.num_nodes = dag.getNumNodes();
          //Advance the DAG front past all contiguously executed nodes:
          auto progressed = dag.progressFrontNode(node);
          if(progressed){
            progress.front = dag.getFrontNode();
            while(progress.front < progress.num_nodes){
              if(!(dag.nodeExecuted(progress.front))) break;
              dag.progressFrontNode(progress.front);
              progress.front = dag.getFrontNode();
            }
          }
          if(progressed && logging_.load() > 1) logfile_ << "DAG front node progressed to "
            << progress.front << " out of total of " << progress.num_nodes << std::endl;
        }else{ //deferred completion error
          if(logging_.load() != 0){
            logfile_ << "[" << std::fixed << std::setprecision(6) << exatn::Timer::timeInSecHR(getTimeStampStart())
                     << "](LazyGraphExecutor)[EXEC_THREAD]: Failed to sync tensor operation "
                     << node << ": Opcode = " << static_cast<int>(op->getOpcode()) << std::endl;
            logfile_.flush();
          }
          std::cout << "#ERROR(exatn::TensorRuntime::GraphExecutorLazy): Deferred completion error for tensor operation "
           << node << " with execution handle " << exec_handle << ": Error " << error_code << std::endl << std::flush;
          assert(false); //`Do I need to handle this case gracefully?
        }
      }else{ //tensor operation has not completed yet
        ++executing_nodes;
      }
    }
    return;
  };

  if(logging_.load() != 0){
    logfile_ << "DAG entry list of dependency free nodes:";
    auto free_nodes = dag.getDependencyFreeNodes();
    for(const auto & node: free_nodes) logfile_ << " " << node;
    logfile_ << std::endl << std::flush;
  }

  //Main execution loop (runs until the DAG front reaches the DAG end):
  bool not_done = (progress.front < progress.num_nodes);
  while(not_done){
    //Try to issue all idle DAG nodes that are ready for execution:
    while(issue_ready_node());
    //Inspect whether the current node can be issued:
    auto node_ready = inspect_node_dependencies();
    //Test the currently executing DAG nodes for completion:
    test_nodes_for_completion();
    //Find the next idle DAG node:
    not_done = find_next_idle_node() || (progress.front < progress.num_nodes);
  }
  return;
}

288
289

/** Executes a queue of tensor networks via the cuQuantum backend.
    Repeatedly sweeps the queue, keeping at most cuquantum_pipe_depth_
    networks in flight: idle/front networks are progressed via sync,
    unsubmitted ones are submitted, and completed ones are removed from
    the queue head. Blocks until the queue is empty. Without cuQuantum
    support, the queue is required to be empty.
    @param tensor_network_queue Queue of tensor networks to execute
           (consumed in place). **/
void LazyGraphExecutor::execute(TensorNetworkQueue & tensor_network_queue) {
#ifdef CUQUANTUM
  //std::cout << "#DEBUG(exatn::runtime::LazyGraphExecutor::execute): Started executing the tensor network queue via cuQuantum: "
  //          << tensor_network_queue.getSize() << " networks detected" << std::endl;
  assert(node_executor_);
  //Synchronize the node executor before handing tensors to cuQuantum:
  bool synced = node_executor_->sync(); assert(synced);
  node_executor_->clearCache();
  //Process the tensor network queue:
  while(!tensor_network_queue.isEmpty()){
    tensor_network_queue.reset();
    bool not_over = !tensor_network_queue.isOver();
    while(not_over){
      const auto current_pos = tensor_network_queue.getCurrentPos();
      if(current_pos < cuquantum_pipe_depth_){ //within the execution pipeline window
        const auto current = tensor_network_queue.getCurrent();
        const auto exec_handle = current->second;
        int error_code = 0;
        int64_t num_slices = 0;
        ExecutionTimings timings;
        auto exec_stat = tensor_network_queue.checkExecStatus(exec_handle);
        if(exec_stat == TensorNetworkQueue::ExecStat::Idle || current_pos == 0){
          exec_stat = cuquantum_executor_->sync(exec_handle,&error_code,&num_slices,&timings); //this call will progress tensor network execution
          assert(error_code == 0);
        }
        if(exec_stat == TensorNetworkQueue::ExecStat::None){ //not yet submitted
          if(logging_.load() != 0){
            logfile_ << "[" << std::fixed << std::setprecision(6) << exatn::Timer::timeInSecHR(getTimeStampStart())
                     << "](LazyGraphExecutor)[EXEC_THREAD]: Submitting to cuQuantum tensor network " << exec_handle << std::endl;
#ifdef DEBUG
            logfile_.flush();
#endif
          }
          const auto exec_conf = tensor_network_queue.getExecConfiguration(exec_handle);
          exec_stat = cuquantum_executor_->execute(current->first,exec_conf.first,exec_conf.second,exec_handle);
          if(logging_.load() != 0){
            logfile_ << "[" << std::fixed << std::setprecision(6) << exatn::Timer::timeInSecHR(getTimeStampStart())
                     << "](LazyGraphExecutor)[EXEC_THREAD]: Submitted to cuQuantum tensor network " << exec_handle
                     << ": Status = " << static_cast<int>(exec_stat) << std::endl;
#ifdef DEBUG
            logfile_.flush();
#endif
          }
          if(exec_stat != TensorNetworkQueue::ExecStat::None){
            tensor_network_queue.updateExecStatus(exec_handle,exec_stat);
            //std::cout << "#DEBUG(exatn::runtime::LazyGraphExecutor::execute): Submitted tensor network to cuQuantum\n";
          }
          not_over = tensor_network_queue.next();
        }else if(exec_stat == TensorNetworkQueue::ExecStat::Completed){ //finished execution
          if(logging_.load() != 0){
            logfile_ << "[" << std::fixed << std::setprecision(6) << exatn::Timer::timeInSecHR(getTimeStampStart())
                     << "](LazyGraphExecutor)[EXEC_THREAD]: Completed via cuQuantum tensor network " << exec_handle
                     << ": NumSlices = " << num_slices << "; Time (ms): In{" << timings.data_in
                     << "}, Prep{" << timings.prepare << "}, Comp{" << timings.compute
                     << "}, Out{" << timings.data_out << "}" << std::endl;
#ifdef DEBUG
            logfile_.flush();
#endif
          }
          tensor_network_queue.updateExecStatus(exec_handle,exec_stat);
          //Completion must occur at the queue head (in-order retirement):
          assert(current_pos == 0);
          tensor_network_queue.remove();
          //std::cout << "#DEBUG(exatn::runtime::LazyGraphExecutor::execute): Completed tensor network execution via cuQuantum\n";
          not_over = !tensor_network_queue.isOver();
        }else{ //still executing: record the updated status and move on
          tensor_network_queue.updateExecStatus(exec_handle,exec_stat);
          not_over = tensor_network_queue.next();
        }
      }else{ //pipeline window exhausted: restart the sweep from the queue head
        not_over = false;
      }
    }
  }
  cuquantum_executor_->sync();
  //std::cout << "#DEBUG(exatn::runtime::LazyGraphExecutor::execute): Finished executing the tensor network queue via cuQuantum\n";
#else
  assert(tensor_network_queue.isEmpty());
#endif
  return;
}

370
371
372
373
374
375
376
377
378
379
380
/** Returns the total Flop count processed so far: the node executor's
    count plus, when built with cuQuantum support, the cuQuantum executor's
    count. Blocks (busy-waits) until the executor pointers are set.
    NOTE(review): these spin loops read shared_ptr members without any
    visible synchronization — presumably another thread sets them during
    startup; confirm the memory-ordering/data-race safety and consider
    yielding inside the loop instead of spinning hot. **/
double LazyGraphExecutor::getTotalFlopCount() const
{
  while(!node_executor_); //spin until the node executor has been installed
  double flops = node_executor_->getTotalFlopCount();
#ifdef CUQUANTUM
  while(!cuquantum_executor_); //spin until the cuQuantum executor has been installed
  flops += cuquantum_executor_->getTotalFlopCount();
#endif
  return flops;
}

381
382
} //namespace runtime
} //namespace exatn