/** ExaTN: tensor_shape.hpp
    (Doxygen source listing; the original header comment, source lines 1-11, was elided.) **/
12 #ifndef EXATN_NUMERICS_TENSOR_SHAPE_HPP_
13 #define EXATN_NUMERICS_TENSOR_SHAPE_HPP_
14 
15 #include "tensor_basic.hpp"
16 
17 #include <iostream>
18 #include <fstream>
19 #include <type_traits>
20 #include <initializer_list>
21 #include <vector>
22 
23 #include <cassert>
24 
25 namespace exatn{
26 
27 namespace numerics{
28 
30 public:
31 
33  template<typename T>
34  TensorShape(std::initializer_list<T> extents);
35  template<typename T>
36  TensorShape(const std::vector<T> & extents);
38  TensorShape();
39 
40  TensorShape(const TensorShape & tens_shape) = default;
41  TensorShape & operator=(const TensorShape & tens_shape) = default;
42  TensorShape(TensorShape && tens_shape) noexcept = default;
43  TensorShape & operator=(TensorShape && tens_shape) noexcept = default;
44  virtual ~TensorShape() = default;
45 
47  void printIt() const;
48  void printItFile(std::ofstream & output_file) const;
49 
51  unsigned int getRank() const;
52 
54  DimExtent getDimExtent(unsigned int dim_id) const;
55 
57  const std::vector<DimExtent> & getDimExtents() const;
58 
60  bool isCongruentTo(const TensorShape & another) const;
61 
63  void resetDimension(unsigned int dim_id, DimExtent extent);
64 
66  void deleteDimension(unsigned int dim_id);
67 
69  void appendDimension(DimExtent dim_extent);
70 
71 private:
72 
73  std::vector<DimExtent> extents_; //tensor dimension extents
74 };
75 
76 
77 //TEMPLATES:
78 template<typename T>
79 TensorShape::TensorShape(std::initializer_list<T> extents):
80 extents_(extents.size())
81 {
82  static_assert(std::is_integral<T>::value,"FATAL(TensorShape::TensorShape): TensorShape extent type must be integral!");
83 
84  //DEBUG:
85  for(const auto & extent: extents){
86  if(extent < 0) std::cout << "ERROR(TensorShape::TensorShape): Negative dimension extent passed!" << std::endl;
87  assert(extent >= 0);
88  }
89 
90  int i = 0;
91  for(const auto & extent: extents) extents_[i++] = static_cast<DimExtent>(extent);
92 }
93 
94 template<typename T>
95 TensorShape::TensorShape(const std::vector<T> & extents):
96 extents_(extents.size())
97 {
98  static_assert(std::is_integral<T>::value,"FATAL(TensorShape::TensorShape): TensorShape extent type must be integral!");
99 
100  //DEBUG:
101  for(const auto & extent: extents){
102  if(extent < 0) std::cout << "ERROR(TensorShape::TensorShape): Negative dimension extent passed!" << std::endl;
103  assert(extent >= 0);
104  }
105 
106  int i = 0;
107  for(const auto & extent: extents) extents_[i++] = static_cast<DimExtent>(extent);
108 }
109 
110 } //namespace numerics
111 
112 } //namespace exatn
113 
114 #endif //EXATN_NUMERICS_TENSOR_SHAPE_HPP_
/* Doxygen cross-reference index (member -> declaration -> definition location):
   exatn::numerics::TensorShape::printIt
     void printIt() const
     Definition: tensor_shape.cpp:20
   exatn::numerics::TensorShape::getDimExtents
     const std::vector< DimExtent > & getDimExtents() const
     Definition: tensor_shape.cpp:59
   exatn::numerics::TensorShape::deleteDimension
     void deleteDimension(unsigned int dim_id)
     Definition: tensor_shape.cpp:81
   exatn::numerics::TensorShape::resetDimension
     void resetDimension(unsigned int dim_id, DimExtent extent)
     Definition: tensor_shape.cpp:74
   exatn::numerics::TensorShape
     Definition: tensor_shape.hpp:29
   exatn
     Definition: DriverClient.hpp:10
   exatn::numerics::TensorShape::TensorShape
     TensorShape()
     Definition: tensor_shape.cpp:16
   exatn::numerics::TensorShape::getDimExtent
     DimExtent getDimExtent(unsigned int dim_id) const
     Definition: tensor_shape.cpp:53
   exatn::numerics::TensorShape::appendDimension
     void appendDimension(DimExtent dim_extent)
     Definition: tensor_shape.cpp:88
   exatn::numerics::TensorShape::getRank
     unsigned int getRank() const
     Definition: tensor_shape.cpp:48
   exatn::numerics::TensorShape::isCongruentTo
     bool isCongruentTo(const TensorShape &another) const
     Definition: tensor_shape.cpp:64
*/