ExaTN
tensor.hpp
1 
38 #ifndef EXATN_NUMERICS_TENSOR_HPP_
39 #define EXATN_NUMERICS_TENSOR_HPP_
40 
#include "tensor_basic.hpp"
#include "tensor_shape.hpp"
#include "tensor_signature.hpp"
#include "tensor_leg.hpp"

#include <assert.h>

#include <fstream>
#include <initializer_list>
#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
56 
57 namespace exatn{
58 
59 namespace numerics{
60 
61 using TensorHashType = std::size_t;
62 
63 class Tensor{
64 public:
65 
67  Tensor(const std::string & name, //tensor name
68  const TensorShape & shape, //tensor shape
69  const TensorSignature & signature); //tensor signature
72  Tensor(const std::string & name, //tensor name
73  const TensorShape & shape); //tensor shape
75  template<typename T>
76  Tensor(const std::string & name, //tensor name
77  std::initializer_list<T> extents, //tensor dimension extents
78  std::initializer_list<std::pair<SpaceId,SubspaceId>> subspaces); //tensor dimension defining subspaces
79  template<typename T>
80  Tensor(const std::string & name, //tensor name
81  const std::vector<T> & extents, //tensor dimension extents
82  const std::vector<std::pair<SpaceId,SubspaceId>> & subspaces); //tensor dimension defining subspaces
85  template<typename T>
86  Tensor(const std::string & name, //tensor name
87  std::initializer_list<T> extents); //tensor dimension extents
88  template<typename T>
89  Tensor(const std::string & name, //tensor name
90  const std::vector<T> & extents); //tensor dimension extents
92  Tensor(const std::string & name); //tensor name
101  Tensor(const std::string & name, //tensor name
102  const Tensor & left_tensor, //left tensor
103  const Tensor & right_tensor, //right tensor
104  const std::vector<TensorLeg> & contraction); //tensor contraction pattern
105 
106  Tensor(const Tensor & tensor) = default;
107  Tensor & operator=(const Tensor & tensor) = default;
108  Tensor(Tensor && tensor) noexcept = default;
109  Tensor & operator=(Tensor && tensor) noexcept = default;
110  virtual ~Tensor() = default;
111 
113  void printIt() const;
114  void printItFile(std::ofstream & output_file) const;
115 
117  const std::string & getName() const;
119  unsigned int getRank() const;
121  const TensorShape & getShape() const;
123  const TensorSignature & getSignature() const;
124 
126  DimExtent getDimExtent(unsigned int dim_id) const;
128  const std::vector<DimExtent> & getDimExtents() const;
129 
131  SpaceId getDimSpaceId(unsigned int dim_id) const;
132  SubspaceId getDimSubspaceId(unsigned int dim_id) const;
133  std::pair<SpaceId,SubspaceId> getDimSpaceAttr(unsigned int dim_id) const;
134 
137  bool isCongruentTo(const Tensor & another) const;
138 
140  void deleteDimension(unsigned int dim_id);
141 
143  void appendDimension(std::pair<SpaceId,SubspaceId> subspace,
144  DimExtent dim_extent);
145  void appendDimension(DimExtent dim_extent);
146 
150  std::shared_ptr<Tensor> createSubtensor(const std::string & name, //in: subtensor name
151  const std::vector<int> & mode_mask, //in: mode masks
152  int mask_val); //in: chosen mask value
153 
155  void setElementType(TensorElementType element_type);
156 
158  TensorElementType getElementType() const;
159 
161  void registerIsometry(const std::vector<unsigned int> & isometry);
162 
164  const std::list<std::vector<unsigned int>> & retrieveIsometries() const;
165 
167  TensorHashType getTensorHash() const;
168 
169 private:
170 
171  std::string name_; //tensor name
172  TensorShape shape_; //tensor shape
173  TensorSignature signature_; //tensor signature
174  TensorElementType element_type_; //tensor element type (optional)
175  std::list<std::vector<unsigned int>> isometries_; //available isometries (optional)
176 };
177 
178 
179 //TEMPLATES:
180 template<typename T>
181 Tensor::Tensor(const std::string & name,
182  std::initializer_list<T> extents,
183  std::initializer_list<std::pair<SpaceId,SubspaceId>> subspaces):
184 name_(name), shape_(extents), signature_(subspaces), element_type_(TensorElementType::VOID)
185 {
186  //DEBUG:
187  if(signature_.getRank() != shape_.getRank()) std::cout << "ERROR(Tensor::Tensor): Signature/Shape size mismatch!" << std::endl;
188  assert(signature_.getRank() == shape_.getRank());
189 }
190 
191 template<typename T>
192 Tensor::Tensor(const std::string & name,
193  const std::vector<T> & extents,
194  const std::vector<std::pair<SpaceId,SubspaceId>> & subspaces):
195 name_(name), shape_(extents), signature_(subspaces), element_type_(TensorElementType::VOID)
196 {
197  //DEBUG:
198  if(signature_.getRank() != shape_.getRank()) std::cout << "ERROR(Tensor::Tensor): Signature/Shape size mismatch!" << std::endl;
199  assert(signature_.getRank() == shape_.getRank());
200 }
201 
202 template<typename T>
203 Tensor::Tensor(const std::string & name,
204  std::initializer_list<T> extents):
205 name_(name), shape_(extents), signature_(static_cast<unsigned int>(extents.size())), element_type_(TensorElementType::VOID)
206 {
207 }
208 
209 template<typename T>
210 Tensor::Tensor(const std::string & name,
211  const std::vector<T> & extents):
212 name_(name), shape_(extents), signature_(static_cast<unsigned int>(extents.size())), element_type_(TensorElementType::VOID)
213 {
214 }
215 
216 } //namespace numerics
217 
218 template<typename... Args>
219 inline std::shared_ptr<numerics::Tensor> makeSharedTensor(Args&&... args)
220 {
221  return std::make_shared<numerics::Tensor>(args...);
222 }
223 
224 } //namespace exatn
225 
226 #endif //EXATN_NUMERICS_TENSOR_HPP_
exatn::numerics::Tensor
Definition: tensor.hpp:63
exatn::numerics::Tensor::Tensor
Tensor(const std::string &name, const TensorShape &shape, const TensorSignature &signature)
Definition: tensor.cpp:16
exatn::numerics::Tensor::getElementType
TensorElementType getElementType() const
Definition: tensor.cpp:191
exatn::numerics::Tensor::isCongruentTo
bool isCongruentTo(const Tensor &another) const
Definition: tensor.cpp:146
exatn::numerics::Tensor::deleteDimension
void deleteDimension(unsigned int dim_id)
Definition: tensor.cpp:152
exatn::numerics::Tensor::registerIsometry
void registerIsometry(const std::vector< unsigned int > &isometry)
Definition: tensor.cpp:196
exatn::numerics::Tensor::getDimSpaceId
SpaceId getDimSpaceId(unsigned int dim_id) const
Definition: tensor.cpp:131
exatn::numerics::TensorShape
Definition: tensor_shape.hpp:29
exatn::numerics::Tensor::getName
const std::string & getName() const
Definition: tensor.cpp:101
exatn
Definition: DriverClient.hpp:10
exatn::numerics::TensorSignature::getRank
unsigned int getRank() const
Definition: tensor_signature.cpp:63
exatn::numerics::Tensor::getTensorHash
TensorHashType getTensorHash() const
Definition: tensor.cpp:210
exatn::numerics::Tensor::appendDimension
void appendDimension(std::pair< SpaceId, SubspaceId > subspace, DimExtent dim_extent)
Definition: tensor.cpp:159
exatn::numerics::Tensor::getDimExtent
DimExtent getDimExtent(unsigned int dim_id) const
Definition: tensor.cpp:121
exatn::numerics::Tensor::setElementType
void setElementType(TensorElementType element_type)
Definition: tensor.cpp:186
exatn::numerics::TensorSignature
Definition: tensor_signature.hpp:35
exatn::numerics::Tensor::retrieveIsometries
const std::list< std::vector< unsigned int > > & retrieveIsometries() const
Definition: tensor.cpp:205
exatn::numerics::Tensor::getSignature
const TensorSignature & getSignature() const
Definition: tensor.cpp:116
exatn::numerics::Tensor::getDimExtents
const std::vector< DimExtent > & getDimExtents() const
Definition: tensor.cpp:126
exatn::numerics::TensorShape::getRank
unsigned int getRank() const
Definition: tensor_shape.cpp:48
exatn::numerics::Tensor::printIt
void printIt() const
Definition: tensor.cpp:89
exatn::numerics::Tensor::getShape
const TensorShape & getShape() const
Definition: tensor.cpp:111
exatn::numerics::Tensor::createSubtensor
std::shared_ptr< Tensor > createSubtensor(const std::string &name, const std::vector< int > &mode_mask, int mask_val)
Definition: tensor.cpp:172
exatn::numerics::Tensor::getRank
unsigned int getRank() const
Definition: tensor.cpp:106