ExaTN
tensor_operation.hpp
1 
11 #ifndef EXATN_NUMERICS_TENSOR_OPERATION_HPP_
12 #define EXATN_NUMERICS_TENSOR_OPERATION_HPP_
13 
14 #include "tensor_basic.hpp"
15 #include "tensor.hpp"
16 
17 #include <memory>
18 #include <string>
19 #include <vector>
20 #include <complex>
21 
22 #include <iostream>
23 #include <fstream>
24 
25 namespace exatn{
26 
namespace runtime{
 // Forward declaration only: the executor is defined elsewhere, which lets
 // TensorOperation::accept() take it by reference without including its header.
 class TensorNodeExecutor;

 // Integer handle identifying a tensor operation submitted for execution.
 using TensorOpExecHandle = std::size_t;
}
33 
34 namespace numerics{
35 
36 class TensorOperation{ //abstract
37 public:
38 
41  TensorOperation(TensorOpCode opcode, //tensor operation code
42  unsigned int num_operands, //required number of tensor operands
43  unsigned int num_scalars); //required number of scalar operands
44 
45  TensorOperation(const TensorOperation &) = default;
46  TensorOperation & operator=(const TensorOperation &) = default;
47  TensorOperation(TensorOperation &&) noexcept = default;
48  TensorOperation & operator=(TensorOperation &&) noexcept = default;
49  virtual ~TensorOperation() = default;
50 
52  virtual bool isSet() const = 0;
53 
58  virtual int accept(runtime::TensorNodeExecutor & node_executor,
59  runtime::TensorOpExecHandle * exec_handle) = 0;
60 
62  virtual void printIt() const;
63  virtual void printItFile(std::ofstream & output_file) const;
64 
66  TensorOpCode getOpcode() const;
67 
69  unsigned int getNumOperands() const;
70 
72  unsigned int getNumOperandsSet() const;
73 
75  TensorHashType getTensorOperandHash(unsigned int op_num) const;
76 
78  std::shared_ptr<Tensor> getTensorOperand(unsigned int op_num,
79  bool * conjugated = nullptr) const;
80 
82  void setTensorOperand(std::shared_ptr<Tensor> tensor,
83  bool conjugated = false);
84 
86  unsigned int getNumScalars() const;
87 
89  unsigned int getNumScalarsSet() const;
90 
92  std::complex<double> getScalar(unsigned int scalar_num) const;
93 
95  void setScalar(unsigned int scalar_num,
96  const std::complex<double> scalar);
97 
99  const std::string & getIndexPattern() const;
100 
103  void setIndexPattern(const std::string & pattern);
104 
106  void setId(std::size_t id);
107 
109  std::size_t getId() const;
110 
111 protected:
112 
113  std::string pattern_; //symbolic index pattern
114  std::vector<std::pair<std::shared_ptr<Tensor>,bool>> operands_; //tensor operands (non-owning pointers)
115  std::vector<std::complex<double>> scalars_; //additional scalars (prefactors)
116  unsigned int num_operands_; //number of required tensor operands
117  unsigned int num_scalars_; //number of required scalar arguments
118  TensorOpCode opcode_; //tensor operation code
119  std::size_t id_; //tensor operation id (unique integer identifier)
120 };
121 
122 using createTensorOpFn = std::unique_ptr<TensorOperation> (*)(void);
123 
124 } //namespace numerics
125 
126 } //namespace exatn
127 
128 #endif //EXATN_NUMERICS_TENSOR_OPERATION_HPP_
exatn::numerics::TensorOperation::getNumOperandsSet
unsigned int getNumOperandsSet() const
Definition: tensor_operation.cpp:69
exatn::numerics::TensorOperation::getNumOperands
unsigned int getNumOperands() const
Definition: tensor_operation.cpp:64
exatn::numerics::TensorOperation::getTensorOperandHash
TensorHashType getTensorOperandHash(unsigned int op_num) const
Definition: tensor_operation.cpp:74
exatn::numerics::TensorOperation::getNumScalars
unsigned int getNumScalars() const
Definition: tensor_operation.cpp:96
exatn::numerics::TensorOperation::getId
std::size_t getId() const
Definition: tensor_operation.cpp:137
exatn::numerics::TensorOperation::isSet
virtual bool isSet() const =0
exatn::numerics::TensorOperation::getOpcode
TensorOpCode getOpcode() const
Definition: tensor_operation.cpp:59
exatn::numerics::TensorOperation::getScalar
std::complex< double > getScalar(unsigned int scalar_num) const
Definition: tensor_operation.cpp:106
exatn::numerics::TensorOperation::getIndexPattern
const std::string & getIndexPattern() const
Definition: tensor_operation.cpp:119
exatn
Definition: DriverClient.hpp:10
exatn::numerics::TensorOperation::getNumScalarsSet
unsigned int getNumScalarsSet() const
Definition: tensor_operation.cpp:101
exatn::numerics::TensorOperation::setIndexPattern
void setIndexPattern(const std::string &pattern)
Definition: tensor_operation.cpp:124
exatn::numerics::TensorOperation
Definition: tensor_operation.hpp:36
exatn::numerics::TensorOperation::accept
virtual int accept(runtime::TensorNodeExecutor &node_executor, runtime::TensorOpExecHandle *exec_handle)=0
exatn::numerics::TensorOperation::printIt
virtual void printIt() const
Definition: tensor_operation.cpp:22
exatn::numerics::TensorOperation::TensorOperation
TensorOperation(TensorOpCode opcode, unsigned int num_operands, unsigned int num_scalars)
Definition: tensor_operation.cpp:15
exatn::numerics::TensorOperation::setTensorOperand
void setTensorOperand(std::shared_ptr< Tensor > tensor, bool conjugated=false)
Definition: tensor_operation.cpp:88
exatn::runtime::TensorNodeExecutor
Definition: tensor_node_executor.hpp:36
exatn::numerics::TensorOperation::getTensorOperand
std::shared_ptr< Tensor > getTensorOperand(unsigned int op_num, bool *conjugated=nullptr) const
Definition: tensor_operation.cpp:79
exatn::numerics::TensorOperation::setScalar
void setScalar(unsigned int scalar_num, const std::complex< double > scalar)
Definition: tensor_operation.cpp:112
exatn::numerics::TensorOperation::setId
void setId(std::size_t id)
Definition: tensor_operation.cpp:131