ExaTN
tensor_exec_state.hpp
1 
39 #ifndef EXATN_RUNTIME_TENSOR_EXEC_STATE_HPP_
40 #define EXATN_RUNTIME_TENSOR_EXEC_STATE_HPP_
41 
#include "tensor_operation.hpp"
#include "tensor.hpp"

#include <atomic>
#include <cstddef>
#include <list>
#include <memory>
#include <unordered_map>
#include <vector>
49 
50 namespace exatn {
51 namespace runtime {
52 
53 // Tensor Graph node id (DirectedBoostGraph vertex descriptor):
54 using VertexIdType = std::size_t; //must match with boost::graph vertex descriptor type
55 
56 // Tensor implementation:
57 using numerics::TensorHashType; //each numerics::Tensor has its unique integer hash
58 using numerics::Tensor;
59 using numerics::TensorOperation;
60 
61 
63 
64 protected:
65 
66  struct TensorExecInfo {
67  std::atomic<std::size_t> update_count; //total number of outstanding updates on a given Tensor in the current DAG
68  std::atomic<int> rw_epoch; //>0: number of current epoch reads; -1: current epoch write (single)
69  std::vector<VertexIdType> rw_epoch_nodes; //nodes participating in the current R/W epoch (either read or write)
70 
71  TensorExecInfo(): update_count(0), rw_epoch(0) {}
72  TensorExecInfo(const TensorExecInfo &) = delete;
73  TensorExecInfo & operator=(const TensorExecInfo &) = delete;
74  TensorExecInfo(TensorExecInfo &&) noexcept = delete;
75  TensorExecInfo & operator=(TensorExecInfo &&) noexcept = delete;
76  ~TensorExecInfo() = default;
77  };
78 
79 public:
80 
81  TensorExecState(): front_node_(0) {}
82 
83  TensorExecState(const TensorExecState &) = delete;
84  TensorExecState & operator=(const TensorExecState &) = delete;
85  TensorExecState(TensorExecState &&) noexcept = default;
86  TensorExecState & operator=(TensorExecState &&) noexcept = default;
87  ~TensorExecState() = default;
88 
92  const std::vector<VertexIdType> * getTensorEpochNodes(const Tensor & tensor,
93  int * epoch);
95  int registerTensorRead(const Tensor & tensor,
96  VertexIdType node_id);
98  int registerTensorWrite(const Tensor & tensor,
99  VertexIdType node_id);
100 
103  std::size_t registerWriteCompletion(const Tensor & tensor);
105  std::size_t getTensorUpdateCount(const Tensor & tensor);
106 
108  void registerDependencyFreeNode(VertexIdType node_id);
111  bool extractDependencyFreeNode(VertexIdType * node_id);
112 
114  void registerExecutingNode(VertexIdType node_id);
116  bool extractExecutingNode(VertexIdType * node_id);
117 
120  bool progressFrontNode(VertexIdType node_executed);
121 
123  VertexIdType getFrontNode() const;
124 
125 private:
128  std::unordered_map<TensorHashType,std::shared_ptr<TensorExecInfo>> tensor_info_;
130  std::list<VertexIdType> nodes_ready_;
132  std::list<VertexIdType> nodes_executing_;
134  VertexIdType front_node_;
135 };
136 
137 } // namespace runtime
138 } // namespace exatn
139 
140 #endif //EXATN_RUNTIME_TENSOR_EXEC_STATE_HPP_
exatn::numerics::Tensor
Definition: tensor.hpp:63
exatn::runtime::TensorExecState::registerWriteCompletion
std::size_t registerWriteCompletion(const Tensor &tensor)
Definition: tensor_exec_state.cpp:62
exatn::runtime::TensorExecState::registerTensorWrite
int registerTensorWrite(const Tensor &tensor, VertexIdType node_id)
Definition: tensor_exec_state.cpp:44
exatn::runtime::TensorExecState::registerExecutingNode
void registerExecutingNode(VertexIdType node_id)
Definition: tensor_exec_state.cpp:94
exatn::runtime::TensorExecState::extractDependencyFreeNode
bool extractDependencyFreeNode(VertexIdType *node_id)
Definition: tensor_exec_state.cpp:84
exatn::runtime::TensorExecState::getFrontNode
VertexIdType getFrontNode() const
Definition: tensor_exec_state.cpp:117
exatn
Definition: DriverClient.hpp:10
exatn::runtime::TensorExecState::registerDependencyFreeNode
void registerDependencyFreeNode(VertexIdType node_id)
Definition: tensor_exec_state.cpp:78
exatn::runtime::TensorExecState::getTensorEpochNodes
const std::vector< VertexIdType > * getTensorEpochNodes(const Tensor &tensor, int *epoch)
Definition: tensor_exec_state.cpp:16
exatn::runtime::TensorExecState::extractExecutingNode
bool extractExecutingNode(VertexIdType *node_id)
Definition: tensor_exec_state.cpp:100
exatn::runtime::TensorExecState::progressFrontNode
bool progressFrontNode(VertexIdType node_executed)
Definition: tensor_exec_state.cpp:110
exatn::runtime::TensorExecState::getTensorUpdateCount
std::size_t getTensorUpdateCount(const Tensor &tensor)
Definition: tensor_exec_state.cpp:70
exatn::runtime::TensorExecState::TensorExecInfo
Definition: tensor_exec_state.hpp:66
exatn::runtime::TensorExecState::registerTensorRead
int registerTensorRead(const Tensor &tensor, VertexIdType node_id)
Definition: tensor_exec_state.cpp:27
exatn::runtime::TensorExecState
Definition: tensor_exec_state.hpp:62