// TensorGraph.hpp
#ifndef XACC_QUANTUM_GRAPH_HPP_
#define XACC_QUANTUM_GRAPH_HPP_

#include "Identifiable.hpp"
#include "tensor_operation.hpp"

#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace exatn {

// Placeholder alias kept "for now". Nothing else in this header refers to
// TensorOp; the node class below stores a TensorOperation instead.
// NOTE(review): presumably a leftover from an earlier design -- confirm
// that no other translation unit uses it before removing.
using TensorOp = int;

// Vertex payload stored in a TensorGraph: a tensor operation plus
// bookkeeping flags. All members are public and may be set directly.
class TensorOpNode {
public:
  // Create a node that holds no operation (op is null).
  TensorOpNode() : op(nullptr) {}

  // Create a node that shares ownership of the given operation.
  // The by-value argument is moved into place to avoid an extra
  // reference-count increment/decrement.
  TensorOpNode(std::shared_ptr<TensorOperation> o) : op(std::move(o)) {}

  // The tensor operation attached to this node; null for a default node.
  std::shared_ptr<TensorOperation> op;

  // True once this node's operation has been executed
  // (see TensorGraph::setNodeExecuted).
  bool executed = false;

  // NOTE(review): presumably marks a node that carries no real work --
  // confirm against the graph implementations.
  bool is_noop = false;

  // Identifier of this node. Zero-initialized so that reading it from a
  // freshly constructed node is well defined (it was previously left
  // uninitialized).
  // NOTE(review): presumably assigned by whoever adds the node to a graph
  // -- confirm.
  int id = 0;

  // Add any other info you need
};

// Public Graph API
// Public Graph API
//
// Abstract interface for a directed graph whose vertices carry TensorOpNode
// values. Vertices are addressed by integer index. Every operation is pure
// virtual; concrete graph back ends supply the implementation.
class TensorGraph : public Identifiable, public Cloneable<TensorGraph> {
public:
  // Virtual destructor so an implementation can always be destroyed safely
  // through a TensorGraph pointer, independent of what the base classes
  // declare.
  virtual ~TensorGraph() = default;

  // Add an edge between src and tgt, this is
  // a directed edge
  virtual void addEdge(const TensorOpNode &srcNode, const TensorOpNode &tgtNode) = 0;

  // Add a vertex carrying the given node (lvalue and rvalue overloads).
  virtual void addVertex(TensorOpNode &opNode) = 0;
  virtual void addVertex(TensorOpNode &&opNode) = 0;

  // For now lets assume as you build it,
  // you can't change the structure or the node values
  // virtual void removeEdge(const TensorOpNode &srcNode, const TensorOpNode &tgtNode) = 0;
  // virtual void setVertexProperties(const int index, TensorOpNode &opNode) = 0;
  // virtual void setVertexProperties(const int index, TensorOpNode &&opNode) = 0;

  // Get the TensorOpNode at the given index.
  // The result is a reference to a shared_ptr; copy the shared_ptr if the
  // node must be kept beyond the lifetime of whatever the reference refers
  // to inside the implementation.
  virtual const std::shared_ptr<TensorOpNode> &getVertexProperties(const int index) = 0;

  // Flip the bool on the TensorOpNode to indicate this
  // node has been executed
  virtual void setNodeExecuted(const int index) = 0;

  // Return true if an edge from srcIndex to tgtIndex exists.
  virtual bool edgeExists(const int srcIndex, const int tgtIndex) = 0;

  // Get how many vertices this vertex is connected to.
  virtual int degree(const int index) = 0;

  // Get all vertex indices this vertex is connected to.
  virtual std::vector<int> getNeighborList(const int index) = 0;

  // Number of edges in the graph.
  virtual int size() = 0;

  // Number of vertices in the graph.
  virtual int order() = 0;

  // Compute shortest path from start index.
  // NOTE(review): distances and paths are non-const references, presumably
  // filled in by the implementation; their exact layout is not specified
  // here -- confirm against the concrete graph class.
  virtual void computeShortestPath(int startIndex,
                                   std::vector<double> &distances,
                                   std::vector<int> &paths) = 0;

  // needed for plugin registry
  virtual std::shared_ptr<TensorGraph> clone() = 0;
};

} // namespace exatn
#endif // XACC_QUANTUM_GRAPH_HPP_