/** ExaTN:: Tensor Runtime: Tensor graph node executor
REVISION: 2022/01/17

Copyright (C) 2018-2022 Dmitry Lyakh, Tiffany Mintz, Alex McCaskey
Copyright (C) 2018-2022 Oak Ridge National Laboratory (UT-Battelle)

Rationale:
 (a) Tensor node executor provides actual implementation of registered
     tensor operations. Actual tensor operations are submitted for
     generally asynchronous execution via the .execute method overloads,
     which return an asynchronous execution handle associated with the
     submitted tensor operation. After submission, the completion status
     of the tensor operation can be checked or enforced via the .sync
     method by providing the asynchronous execution handle previously
     returned by the .execute method.
**/

#ifndef EXATN_RUNTIME_TENSOR_NODE_EXECUTOR_HPP_
#define EXATN_RUNTIME_TENSOR_NODE_EXECUTOR_HPP_

21
22
#include "Identifiable.hpp"

23
#include "tensor_op_factory.hpp"
24
#include "tensor.hpp"
25
#include "space_register.hpp"
26

27
28
#include "param_conf.hpp"

29
30
#include "timers.hpp"

31
#include <vector>
32
33
#include <memory>

34
35
#include "errors.hpp"

36
37
38
39
//Forward declaration of talsh::Tensor: only the name is needed here
//(it is used as the pointee type of a std::shared_ptr return value).
namespace talsh {
class Tensor;
} //namespace talsh

40
41
42
namespace exatn {
namespace runtime {

43
/** Abstract interface of a tensor graph node executor.
    Every member function except the destructor is pure virtual,
    so a concrete executor must implement the entire interface:
    service setup, memory/Flop queries, asynchronous submission of
    tensor operations (.execute overloads), their synchronization
    (.sync), and access to local tensor data. **/
class TensorNodeExecutor : public Identifiable, public Cloneable<TensorNodeExecutor> {

public:

  virtual ~TensorNodeExecutor() = default;

  /** Explicitly initializes the underlying numerical service, if needed. **/
  virtual void initialize(const ParamConf & parameters) = 0;

  /** Activates dry run (no actual computations).
      NOTE(review): the boolean argument presumably turns the dry run
      mode on/off -- confirm against the implementations. **/
  virtual void activateDryRun(bool dry_run) = 0;

  /** Activates mixed-precision fast math on all devices (if available). **/
  virtual void activateFastMath() = 0;

  /** Returns the Host memory buffer size in bytes provided by the node executor. **/
  virtual std::size_t getMemoryBufferSize() const = 0;

  /** Returns the current memory usage by all allocated tensors.
      Note that the returned value includes buffer fragmentation overhead.
      NOTE(review): free_mem looks like an output argument receiving the
      amount of free memory -- confirm against the implementations. **/
  virtual std::size_t getMemoryUsage(std::size_t * free_mem) const = 0;

  /** Returns the current value of the total Flop count executed by the node executor. **/
  virtual double getTotalFlopCount() const = 0;

  /** Executes the tensor operation found in a DAG node asynchronously,
      returning the execution handle in exec_handle that can later be
      used for testing for completion of the operation execution.
      Returns an integer error code (0:Success).
      One overload is declared per concrete tensor operation kind. **/
  virtual int execute(numerics::TensorOpCreate & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpDestroy & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpTransform & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpSlice & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpInsert & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpAdd & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpContract & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpDecomposeSVD3 & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpDecomposeSVD2 & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpOrthogonalizeSVD & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpOrthogonalizeMGS & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpFetch & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpUpload & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpBroadcast & op,
                      TensorOpExecHandle * exec_handle) = 0;
  virtual int execute(numerics::TensorOpAllreduce & op,
                      TensorOpExecHandle * exec_handle) = 0;

  /** Synchronizes the execution of a previously submitted tensor operation.
      NOTE(review): presumably returns TRUE once the operation has completed,
      with its status delivered via error_code, and only tests for completion
      when wait is FALSE -- confirm against the implementations.
      The default value of wait is bound statically at the call site, thus
      overriders should not declare a different default. **/
  virtual bool sync(TensorOpExecHandle op_handle,
                    int * error_code,
                    bool wait = true) = 0;

  /** Synchronizes the execution of all currently progressing tensor operations. **/
  virtual bool sync() = 0;

  /** Discards a previously submitted tensor operation. **/
  virtual bool discard(TensorOpExecHandle op_handle) = 0;

  /** Initiates tensor operand prefetch for a given tensor operation. **/
  virtual bool prefetch(const numerics::TensorOperation & op) = 0;

  /** Clears all internal node executor caches. **/
  virtual void clearCache() = 0;

  /** Returns a local copy of a given tensor slice.
      The slice is specified by a vector of {DimOffset,DimExtent} pairs
      (NOTE(review): presumably one pair per tensor dimension -- confirm). **/
  virtual std::shared_ptr<talsh::Tensor> getLocalTensor(const numerics::Tensor & tensor,
                         const std::vector<std::pair<DimOffset,DimExtent>> & slice_spec) = 0;

  /** Returns a non-owning pointer to a local tensor data image on a given device.
      If unsuccessful, returns nullptr. **/
  virtual void * getTensorImage(const numerics::Tensor & tensor,   //in: tensor
                                int device_kind,                   //in: device kind (implementation specific)
                                int device_id,                     //in: device id: [0,1,2,..]
                                std::size_t * size = nullptr) const = 0; //out: tensor data image size in bytes

  /** Clones the node executor, returning the clone via a shared pointer. **/
  virtual std::shared_ptr<TensorNodeExecutor> clone() = 0;
};

} //namespace runtime
} //namespace exatn

#endif //EXATN_RUNTIME_TENSOR_NODE_EXECUTOR_HPP_