ExaTN
node_executor_talsh.hpp
1 
11 #ifndef EXATN_RUNTIME_TALSH_NODE_EXECUTOR_HPP_
12 #define EXATN_RUNTIME_TALSH_NODE_EXECUTOR_HPP_
13 
14 #include "tensor_node_executor.hpp"
15 
16 #include "talshxx.hpp"
17 
18 #include <unordered_map>
19 #include <vector>
20 #include <memory>
21 
22 namespace exatn {
23 namespace runtime {
24 
26 
27 public:
28 
29  TalshNodeExecutor() = default;
30  TalshNodeExecutor(const TalshNodeExecutor &) = delete;
31  TalshNodeExecutor & operator=(const TalshNodeExecutor &) = delete;
32  TalshNodeExecutor(TalshNodeExecutor &&) noexcept = delete;
33  TalshNodeExecutor & operator=(TalshNodeExecutor &&) noexcept = delete;
34  virtual ~TalshNodeExecutor();
35 
36  void initialize() override;
37 
39  TensorOpExecHandle * exec_handle) override;
41  TensorOpExecHandle * exec_handle) override;
43  TensorOpExecHandle * exec_handle) override;
45  TensorOpExecHandle * exec_handle) override;
47  TensorOpExecHandle * exec_handle) override;
48 
49  bool sync(TensorOpExecHandle op_handle,
50  int * error_code,
51  bool wait = false) override;
52 
53  std::shared_ptr<talsh::Tensor> getLocalTensor(const numerics::Tensor & tensor,
54  const std::vector<std::pair<DimOffset,DimExtent>> & slice_spec) override;
55 
56  const std::string name() const override {return "talsh-node-executor";}
57  const std::string description() const override {return "TALSH tensor graph node executor";}
58  std::shared_ptr<TensorNodeExecutor> clone() override {return std::make_shared<TalshNodeExecutor>();}
59 
60 protected:
62  std::unordered_map<numerics::TensorHashType,std::shared_ptr<talsh::Tensor>> tensors_;
64  std::unordered_map<TensorOpExecHandle,std::shared_ptr<talsh::TensorTask>> tasks_;
66  static bool talsh_initialized_;
69 };
70 
71 
73 inline int get_talsh_tensor_element_kind(TensorElementType element_type)
74 {
75  int talsh_data_kind = NO_TYPE;
76  switch(element_type){
77  case TensorElementType::REAL32: talsh_data_kind = R4; break;
78  case TensorElementType::REAL64: talsh_data_kind = R8; break;
79  case TensorElementType::COMPLEX32: talsh_data_kind = C4; break;
80  case TensorElementType::COMPLEX64: talsh_data_kind = C8; break;
81  }
82  return talsh_data_kind;
83 }
84 
86 inline TensorElementType get_exatn_tensor_element_kind(int element_type)
87 {
88  switch(element_type){
89  case R4: return TensorElementType::REAL32;
90  case R8: return TensorElementType::REAL64;
91  case C4: return TensorElementType::COMPLEX32;
92  case C8: return TensorElementType::COMPLEX64;
93  }
94  return TensorElementType::VOID;
95 }
96 
97 } //namespace runtime
98 } //namespace exatn
99 
100 #endif //EXATN_RUNTIME_TALSH_NODE_EXECUTOR_HPP_
exatn::numerics::Tensor
Definition: tensor.hpp:63
exatn::runtime::TalshNodeExecutor::getLocalTensor
std::shared_ptr< talsh::Tensor > getLocalTensor(const numerics::Tensor &tensor, const std::vector< std::pair< DimOffset, DimExtent >> &slice_spec) override
Definition: node_executor_talsh.cpp:243
exatn::runtime::TalshNodeExecutor::execute
int execute(numerics::TensorOpCreate &op, TensorOpExecHandle *exec_handle) override
Definition: node_executor_talsh.cpp:66
exatn::runtime::TalshNodeExecutor::tasks_
std::unordered_map< TensorOpExecHandle, std::shared_ptr< talsh::TensorTask > > tasks_
Definition: node_executor_talsh.hpp:64
exatn::runtime::TalshNodeExecutor::talsh_node_exec_count_
static int talsh_node_exec_count_
Definition: node_executor_talsh.hpp:68
exatn
Definition: DriverClient.hpp:10
exatn::numerics::TensorOpAdd
Definition: tensor_op_add.hpp:21
exatn::runtime::TalshNodeExecutor
Definition: node_executor_talsh.hpp:25
exatn::runtime::TalshNodeExecutor::talsh_initialized_
static bool talsh_initialized_
Definition: node_executor_talsh.hpp:66
exatn::runtime::TalshNodeExecutor::tensors_
std::unordered_map< numerics::TensorHashType, std::shared_ptr< talsh::Tensor > > tensors_
Definition: node_executor_talsh.hpp:62
exatn::numerics::TensorOpDestroy
Definition: tensor_op_destroy.hpp:21
exatn::numerics::TensorOpTransform
Definition: tensor_op_transform.hpp:27
exatn::runtime::TalshNodeExecutor::initialize
void initialize() override
Definition: node_executor_talsh.cpp:25
exatn::runtime::TalshNodeExecutor::sync
bool sync(TensorOpExecHandle op_handle, int *error_code, bool wait=false) override
Definition: node_executor_talsh.cpp:219
exatn::numerics::TensorOpContract
Definition: tensor_op_contract.hpp:22
exatn::numerics::TensorOpCreate
Definition: tensor_op_create.hpp:21
exatn::runtime::TensorNodeExecutor
Definition: tensor_node_executor.hpp:36