ExaTN
tensor_op_transform.hpp
1 
13 #ifndef EXATN_NUMERICS_TENSOR_OP_TRANSFORM_HPP_
14 #define EXATN_NUMERICS_TENSOR_OP_TRANSFORM_HPP_
15 
16 #include "Identifiable.hpp"
17 
18 #include "tensor_basic.hpp"
19 #include "tensor_operation.hpp"
20 
21 #include "tensor_method.hpp"

#include <utility>
22 
23 namespace exatn{
24 
25 namespace numerics{
26 
28 public:
29 
31 
32  TensorOpTransform(const TensorOpTransform &) = default; //copies share functor_ (shared_ptr copy, not a deep copy of the functor)
33  TensorOpTransform & operator=(const TensorOpTransform &) = default;
34  TensorOpTransform(TensorOpTransform &&) noexcept = default; //explicitly defaulted: a user-declared destructor would otherwise suppress the implicit move operations
35  TensorOpTransform & operator=(TensorOpTransform &&) noexcept = default;
36  virtual ~TensorOpTransform() = default; //virtual: instances are handed out as std::unique_ptr<TensorOperation> (see createNew)
37 
39  virtual bool isSet() const override; //base-class override; definition lives in tensor_op_transform.cpp
40 
42  virtual int accept(runtime::TensorNodeExecutor & node_executor,
43  runtime::TensorOpExecHandle * exec_handle) override; //executor dispatch hook defined in tensor_op_transform.cpp; presumably forwards this operation to node_executor and fills *exec_handle — confirm there
44 
46  static std::unique_ptr<TensorOperation> createNew(); //factory returning an owning std::unique_ptr<TensorOperation>; presumably a fresh TensorOpTransform — confirm in tensor_op_transform.cpp
47 
48  void resetFunctor(std::shared_ptr<talsh::TensorFunctor<Identifiable>> functor){
49  functor_ = functor;
50  return;
51  }
52 
53  int apply(talsh::Tensor & local_tensor){
54  if(functor_) return functor_->apply(local_tensor);
55  return 0;
56  }
57 
58 private:
59 
60  std::shared_ptr<talsh::TensorFunctor<Identifiable>> functor_; //tensor functor (method) invoked by apply(); may be empty, in which case apply() returns 0
61 
62 };
63 
64 } //namespace numerics
65 
66 } //namespace exatn
67 
68 #endif //EXATN_NUMERICS_TENSOR_OP_TRANSFORM_HPP_
exatn
Definition: DriverClient.hpp:10
exatn::numerics::TensorOperation
Definition: tensor_operation.hpp:36
exatn::numerics::TensorOpTransform
Definition: tensor_op_transform.hpp:27
exatn::numerics::TensorOpTransform::createNew
static std::unique_ptr< TensorOperation > createNew()
Definition: tensor_op_transform.cpp:34
exatn::numerics::TensorOpTransform::isSet
virtual bool isSet() const override
Definition: tensor_op_transform.cpp:23
exatn::runtime::TensorNodeExecutor
Definition: tensor_node_executor.hpp:36
exatn::numerics::TensorOpTransform::accept
virtual int accept(runtime::TensorNodeExecutor &node_executor, runtime::TensorOpExecHandle *exec_handle) override
Definition: tensor_op_transform.cpp:28