52 #ifndef EXATN_NUMERICS_TENSOR_NETWORK_HPP_
53 #define EXATN_NUMERICS_TENSOR_NETWORK_HPP_
55 #include "tensor_basic.hpp"
56 #include "tensor_connected.hpp"
57 #include "tensor_op_factory.hpp"
58 #include "network_build_factory.hpp"
59 #include "contraction_seq_optimizer.hpp"
61 #include <unordered_map>
#include <utility>
/** Mutable iterator type of std::unordered_map<unsigned int, TensorConn>,
    i.e. the type returned by begin() and end(). */
75 using Iterator =
typename std::unordered_map<unsigned int, TensorConn>::iterator;
/** Read-only iterator type of the same map type,
    i.e. the type returned by cbegin() and cend(). */
76 using ConstIterator =
typename std::unordered_map<unsigned int, TensorConn>::const_iterator;
84 std::shared_ptr<Tensor> output_tensor,
85 const std::vector<TensorLeg> & output_legs);
88 const std::string & tensor_network,
89 const std::map<std::string,std::shared_ptr<Tensor>> & tensors);
92 std::shared_ptr<Tensor> output_tensor,
/** Returns the name of the tensor network as a read-only string reference.
    Callable on a const network.
    NOTE(review): the returned reference presumably aliases a data member, so it
    would stay valid only while this object is alive -- confirm in tensor_network.cpp. */
124 const std::string &
getName()
const;
/** Renames the tensor network.
    @param name The new name, passed by const reference.
    NOTE(review): any side effects beyond changing the name are not visible from
    this declaration -- see the definition in tensor_network.cpp. */
127 void rename(
const std::string & name);
/** Returns a shared pointer to a tensor of the network, selected by its id.
    @param tensor_id Unsigned integer id of the requested tensor.
    @param conjugated Optional output pointer, nullptr by default.
           NOTE(review): presumably receives the tensor's conjugation flag when
           non-null -- confirm in tensor_network.cpp.
    @return std::shared_ptr<Tensor>. NOTE(review): the value returned for an
            unknown id is not visible here (presumably an empty pointer) -- confirm. */
131 std::shared_ptr<Tensor>
getTensor(
unsigned int tensor_id,
132 bool * conjugated =
nullptr);
138 inline Iterator
begin() {
return tensors_.begin();}
140 inline Iterator
end() {
return tensors_.end();}
142 inline ConstIterator
cbegin()
const {
return tensors_.cbegin();}
144 inline ConstIterator
cend()
const {
return tensors_.cend();}
/** Finalizes the tensor network.
    @param check_validity Off by default. NOTE(review): presumably requests an
           additional consistency check of the network during finalization --
           confirm in tensor_network.cpp.
    @return bool. NOTE(review): presumably true on success -- confirm. */
148 bool finalize(
bool check_validity =
false);
157 std::shared_ptr<Tensor> tensor,
158 const std::vector<TensorLeg> & connections,
159 bool conjugated =
false,
160 bool leg_matching_check =
true);
170 std::shared_ptr<Tensor> tensor,
171 const std::vector<std::pair<unsigned int, unsigned int>> & pairing,
172 const std::vector<LegDirection> & leg_dir = std::vector<LegDirection>{},
173 bool conjugated =
false);
179 std::shared_ptr<Tensor> tensor,
180 const std::vector<unsigned int> & pairing,
181 bool conjugated =
false);
193 const std::vector<std::pair<unsigned int, unsigned int>> & pairing);
207 const std::vector<unsigned int> & pairing);
224 unsigned int right_id,
225 unsigned int result_id,
226 std::string * contr_pattern =
nullptr);
233 unsigned int left_tensor_id,
234 const std::string & left_tensor_name,
235 unsigned int right_tensor_id,
236 const std::string & right_tensor_name,
238 const std::vector<int> & right_dims);
249 unsigned int right_id,
250 double * arithm_intensity =
nullptr,
251 bool adjust_cost =
false);
/** Returns a mutable reference to a list of shared pointers to TensorOperation.
    @param contr_seq_opt_name String defaulting to "dummy".
           NOTE(review): presumably the name of the contraction sequence optimizer
           to use -- confirm in tensor_network.cpp.
    NOTE(review): the result is returned by non-const reference, so it presumably
    aliases the operations_ member and callers can modify it -- confirm. */
254 std::list<std::shared_ptr<TensorOperation>> &
getOperationList(
const std::string & contr_seq_opt_name =
"dummy");
/** Integer flag concerning the output tensor.
    NOTE(review): presumably what isExplicit() reports; the meaning of each value
    is not visible here -- confirm in tensor_network.cpp. */
290 int explicit_output_;
/** Tensor storage: maps an unsigned int key to a TensorConn. This is the
    container traversed by begin()/end()/cbegin()/cend().
    NOTE(review): the key is presumably the tensor id -- confirm. */
293 std::unordered_map<unsigned int, TensorConn> tensors_;
/** NOTE(review): presumably the flop count associated with contraction_seq_ -- confirm. */
295 double contraction_seq_flops_;
/** List of ContrTriple entries.
    NOTE(review): presumably the stored tensor contraction sequence -- confirm. */
296 std::list<ContrTriple> contraction_seq_;
/** List of shared pointers to TensorOperation; same type as the list returned
    by getOperationList(). */
297 std::list<std::shared_ptr<TensorOperation>> operations_;
302 template<
typename... Args>
303 inline std::shared_ptr<numerics::TensorNetwork> makeSharedTensorNetwork(Args&&... args)
305 return std::make_shared<numerics::TensorNetwork>(args...);
310 #endif //EXATN_NUMERICS_TENSOR_NETWORK_HPP_
unsigned int getMaxTensorId() const
Definition: tensor_network.cpp:213
Definition: tensor_network.hpp:72
void conjugate()
Definition: tensor_network.cpp:1093
void updateConnectionsFromInputTensors()
Definition: tensor_network.cpp:344
unsigned int getNumTensors() const
Definition: tensor_network.cpp:207
const std::string & getName() const
Definition: tensor_network.cpp:223
bool finalize(bool check_validity=false)
Definition: tensor_network.cpp:272
bool deleteTensor(unsigned int tensor_id)
Definition: tensor_network.cpp:856
ConstIterator cbegin() const
Definition: tensor_network.hpp:142
bool appendTensorGate(unsigned int tensor_id, std::shared_ptr< Tensor > tensor, const std::vector< unsigned int > &pairing, bool conjugated=false)
Definition: tensor_network.cpp:541
Definition: tensor_shape.hpp:29
bool checkConnections()
Definition: tensor_network.cpp:313
bool appendTensorNetworkGate(TensorNetwork &&network, const std::vector< unsigned int > &pairing)
Definition: tensor_network.cpp:736
Definition: DriverClient.hpp:10
unsigned int getRank() const
Definition: tensor_network.cpp:200
bool appendTensorNetwork(TensorNetwork &&network, const std::vector< std::pair< unsigned int, unsigned int >> &pairing)
Definition: tensor_network.cpp:640
double getContractionCost(unsigned int left_id, unsigned int right_id, double *arithm_intensity=nullptr, bool adjust_cost=false)
Definition: tensor_network.cpp:1100
Iterator begin()
Definition: tensor_network.hpp:138
bool mergeTensors(unsigned int left_id, unsigned int right_id, unsigned int result_id, std::string *contr_pattern=nullptr)
Definition: tensor_network.cpp:922
void rename(const std::string &name)
Definition: tensor_network.cpp:229
Iterator end()
Definition: tensor_network.hpp:140
const std::vector< TensorLeg > * getTensorConnections(unsigned int tensor_id)
Definition: tensor_network.cpp:264
void invalidateContractionSequence()
Definition: tensor_network.cpp:353
ConstIterator cend() const
Definition: tensor_network.hpp:144
bool reorderOutputModes(const std::vector< unsigned int > &order)
Definition: tensor_network.cpp:828
bool isExplicit() const
Definition: tensor_network.cpp:188
void printIt() const
Definition: tensor_network.cpp:168
TensorNetwork()
Definition: tensor_network.cpp:30
std::shared_ptr< Tensor > getTensor(unsigned int tensor_id, bool *conjugated=nullptr)
Definition: tensor_network.cpp:255
bool appendTensor(unsigned int tensor_id, std::shared_ptr< Tensor > tensor, const std::vector< TensorLeg > &connections, bool conjugated=false, bool leg_matching_check=true)
Definition: tensor_network.cpp:374
bool isFinalized() const
Definition: tensor_network.cpp:194
bool splitTensor(unsigned int tensor_id, unsigned int left_tensor_id, const std::string &left_tensor_name, unsigned int right_tensor_id, const std::string &right_tensor_name, const TensorShape &contracted_dims, const std::vector< int > &right_dims)
Definition: tensor_network.cpp:1007
void updateConnections(unsigned int tensor_id)
Definition: tensor_network.cpp:323
std::list< std::shared_ptr< TensorOperation > > & getOperationList(const std::string &contr_seq_opt_name="dummy")
Definition: tensor_network.cpp:1134
bool isEmpty() const
Definition: tensor_network.cpp:182
TensorConn * getTensorConn(unsigned int tensor_id)
Definition: tensor_network.cpp:236
double determineContractionSequence(ContractionSeqOptimizer &contr_seq_optimizer)
Definition: tensor_network.cpp:362
Definition: network_builder.hpp:25
std::vector< TensorConn * > getTensorConnAll()
Definition: tensor_network.cpp:244