/** ExaTN:: Tensor Runtime: Tensor graph executor
REVISION: 2021/12/22

Copyright (C) 2018-2021 Dmitry Lyakh, Tiffany Mintz, Alex McCaskey
Copyright (C) 2018-2021 Oak Ridge National Laboratory (UT-Battelle)

Rationale:
 (a) Tensor graph executor traverses the tensor graph (DAG) and
     executes all its nodes while respecting node dependencies.
     Each DAG node is executed by a concrete tensor node executor
     (tensor operation stored in the DAG node accepts a polymorphic
     tensor node executor which then executes that tensor operation).
     The execution of each DAG node is generally asynchronous.
**/

#ifndef EXATN_RUNTIME_TENSOR_GRAPH_EXECUTOR_HPP_
#define EXATN_RUNTIME_TENSOR_GRAPH_EXECUTOR_HPP_

#include "Identifiable.hpp"

#include "tensor_graph.hpp"
#include "tensor_network_queue.hpp"

#include "tensor_node_executor.hpp"
#include "tensor_operation.hpp"

#include "param_conf.hpp"

#include "timers.hpp"

#include <memory>
#include <atomic>

#include <iostream>
#include <fstream>
#include <iomanip>

#include "errors.hpp"
namespace exatn {
namespace runtime {

42
class TensorGraphExecutor : public Identifiable, public Cloneable<TensorGraphExecutor> {
43
44
45

public:

46
  TensorGraphExecutor():
47
   node_executor_(nullptr), num_ops_issued_(0), process_rank_(-1), global_process_rank_(-1),
48
49
   logging_(0), stopping_(false), active_(false), serialize_(false), validation_tracing_(false),
   time_start_(exatn::Timer::timeInSecHR())
50
51
  {
  }
52
53
54

  TensorGraphExecutor(const TensorGraphExecutor &) = delete;
  TensorGraphExecutor & operator=(const TensorGraphExecutor &) = delete;
55
56
  TensorGraphExecutor(TensorGraphExecutor &&) = delete;
  TensorGraphExecutor & operator=(TensorGraphExecutor &&) = delete;
57
58
59
60

  virtual ~TensorGraphExecutor(){
    resetLoggingLevel();
  }
61

62
  /** Sets/resets the DAG node executor (tensor operation executor). **/
63
64
65
66
  virtual void resetNodeExecutor(std::shared_ptr<TensorNodeExecutor> node_executor,
                                 const ParamConf & parameters,
                                 unsigned int process_rank,
                                 unsigned int global_process_rank) {
67
    process_rank_.store(process_rank);
68
    global_process_rank_.store(global_process_rank);
69
    node_executor_ = node_executor;
70
71
72
73
74
75
76
    if(node_executor_){
      if(logging_.load() != 0){
        logfile_ << "[" << std::fixed << std::setprecision(6) << exatn::Timer::timeInSecHR(getTimeStampStart())
                 << "](TensorGraphExecutor)[EXEC_THREAD]: Initializing the node executor ... "; //debug
      }
      node_executor_->initialize(parameters);
      if(logging_.load() != 0){
77
        logfile_ << "Successfully initialized [" << std::fixed << std::setprecision(6)
78
79
80
81
                 << exatn::Timer::timeInSecHR(getTimeStampStart()) << "]" << std::endl; //debug
        logfile_.flush();
      }
    }
82
    return;
83
84
  }

85
86
87
  /** Resets the logging level (0:none). **/
  void resetLoggingLevel(int level = 0) {
    if(logging_.load() == 0){
88
89
      while(level != 0 && global_process_rank_.load() < 0);
      if(level != 0) logfile_.open("exatn_exec_thread."+std::to_string(global_process_rank_.load())+".log", std::ios::out | std::ios::trunc);
90
91
92
93
94
95
96
    }else{
      if(level == 0) logfile_.close();
    }
    logging_.store(level);
    return;
  }

97
98
99
100
101
102
103
104
  /** Enforces serialized (synchronized) execution of the DAG. **/
  void resetSerialization(bool serialize,
                          bool validation_trace = false) {
    serialize_.store(serialize);
    validation_tracing_.store(serialize && validation_trace);
    return;
  }

105
106
107
108
109
110
  /** Activates/deactivates dry run (no actual computations). **/
  void activateDryRun(bool dry_run) {
   while(!node_executor_);
   return node_executor_->activateDryRun(dry_run);
  }

111
112
113
114
115
116
  /** Activates mixed-precision fast math on all devices (if available). **/
  void activateFastMath() {
    while(!node_executor_);
    return node_executor_->activateFastMath();
  }

117
  /** Returns the Host memory buffer size in bytes provided by the node executor. **/
118
  std::size_t getMemoryBufferSize() const {
119
    while(!node_executor_);
120
    return node_executor_->getMemoryBufferSize();
121
122
  }

123
124
125
126
127
128
  /** Returns the current value of the total Flop count executed by the node executor. **/
  double getTotalFlopCount() const {
    while(!node_executor_);
    return node_executor_->getTotalFlopCount();
  }

129
130
  /** Traverses the DAG and executes all its nodes (operations).
      [THREAD: This function is executed by the execution thread] **/
131
132
  virtual void execute(TensorGraph & dag) = 0;

133
134
135
136
  /** Traverses the list of tensor networks and executes them as a whole.
      [THREAD: This function is executed by the execution thread] **/
  virtual void execute(TensorNetworkQueue & tensor_network_queue) = 0;

137
138
139
  /** Regulates the tensor prefetch depth (0 turns prefetch off). **/
  virtual void setPrefetchDepth(unsigned int depth) = 0;

140
  /** Factory method **/
141
142
  virtual std::shared_ptr<TensorGraphExecutor> clone() = 0;

143
144
145
146
147
148
149
  /** Returns a local copy of a given tensor slice. **/
  std::shared_ptr<talsh::Tensor> getLocalTensor(const numerics::Tensor & tensor,
                 const std::vector<std::pair<DimOffset,DimExtent>> & slice_spec) {
    assert(node_executor_);
    return node_executor_->getLocalTensor(tensor,slice_spec);
  }

150
  /** Signals to stop execution of the DAG until later resume
151
      and waits until the execution has actually stopped.
152
153
154
155
156
157
158
      [THREAD: This function is executed by the main thread] **/
  void stopExecution() {
    stopping_.store(true);   //this signal will be picked by the execution thread
    while(active_.load()){}; //once the DAG execution is stopped the execution thread will set active_ to FALSE
    return;
  }

159
160
  inline double getTimeStampStart() const {return time_start_;}

161
162
163
164
  inline std::size_t incrementOpCounter() {return ++num_ops_issued_;}

  inline std::size_t getOpCounter() const {return num_ops_issued_.load();}

165
protected:
166

167
  std::shared_ptr<TensorNodeExecutor> node_executor_; //intr-node tensor operation executor
168
  std::atomic<std::size_t> num_ops_issued_; //total number of issued tensor operations
169
  std::atomic<int> process_rank_; //current process rank
170
  std::atomic<int> global_process_rank_; //current global process rank (in MPI_COMM_WORLD)
171
172
173
  std::atomic<int> logging_;      //logging level (0:none)
  std::atomic<bool> stopping_;    //signal to pause the execution thread
  std::atomic<bool> active_;      //TRUE while the execution thread is executing DAG operations
174
175
  std::atomic<bool> serialize_;   //serialization of the DAG execution
  std::atomic<bool> validation_tracing_; //validation tracing flag
176
  const double time_start_;       //start time stamp
177
  std::ofstream logfile_;         //logging file stream (output)
178
179
180
181
182
183
};

} //namespace runtime
} //namespace exatn

#endif //EXATN_RUNTIME_TENSOR_GRAPH_EXECUTOR_HPP_