Commit 4e88b67f authored by Dmitry I. Lyakh's avatar Dmitry I. Lyakh

Introduced ContractionSeqOptimizer interface.

parent 100a215c
......@@ -21,6 +21,7 @@ add_library(${LIBRARY_NAME}
tensor_op_factory.cpp
network_builder_mps.cpp
network_build_factory.cpp
contraction_seq_optimizer_dummy.cpp
tensor_network.cpp)
target_include_directories(${LIBRARY_NAME}
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer
REVISION: 2019/09/04
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
/** Rationale:
**/
#ifndef EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_HPP_
#define EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_HPP_
#include "tensor_basic.hpp"
#include <list>
namespace exatn{
namespace numerics{
//Tensor contraction triple:
struct ContrTriple{
unsigned int result_id; //id of the tensor-result (new)
unsigned int left_id; //id of the left input tensor (old)
unsigned int right_id; //id of the right input tensor (old)
};
class TensorNetwork;
class ContractionSeqOptimizer{
public:
virtual double determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq) = 0;
};
} //namespace numerics
} //namespace exatn
#endif //EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_HPP_
/** ExaTN::Numerics: Tensor contraction sequence optimizer: Dummy
REVISION: 2019/09/04
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include "contraction_seq_optimizer_dummy.hpp"
#include "tensor_network.hpp"
namespace exatn{
namespace numerics{
double ContractionSeqOptimizerDummy::determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq)
{
double flops = 0.0;
//`Finish
return flops;
}
} //namespace numerics
} //namespace exatn
/** ExaTN::Numerics: Tensor contraction sequence optimizer: Dummy
REVISION: 2019/09/04
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
/** Rationale:
**/
#ifndef EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_DUMMY_HPP_
#define EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_DUMMY_HPP_
#include "contraction_seq_optimizer.hpp"
namespace exatn{
namespace numerics{
class ContractionSeqOptimizerDummy: public ContractionSeqOptimizer{
public:
virtual double determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq) override;
};
} //namespace numerics
} //namespace exatn
#endif //EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_DUMMY_HPP_
/** ExaTN::Numerics: Tensor network
REVISION: 2019/08/15
REVISION: 2019/09/04
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -19,7 +19,7 @@ namespace exatn{
namespace numerics{
TensorNetwork::TensorNetwork():
explicit_output_(0), finalized_(0)
explicit_output_(0), finalized_(0), contraction_seq_flops_(0.0)
{
tensors_.emplace( //output tensor (id = 0)
std::make_pair(
......@@ -31,7 +31,7 @@ TensorNetwork::TensorNetwork():
TensorNetwork::TensorNetwork(const std::string & name):
explicit_output_(0), finalized_(0), name_(name)
explicit_output_(0), finalized_(0), name_(name), contraction_seq_flops_(0.0)
{
tensors_.emplace( //output tensor (id = 0)
std::make_pair(
......@@ -45,7 +45,7 @@ TensorNetwork::TensorNetwork(const std::string & name):
TensorNetwork::TensorNetwork(const std::string & name,
std::shared_ptr<Tensor> output_tensor,
const std::vector<TensorLeg> & output_legs):
explicit_output_(1), finalized_(0), name_(name)
explicit_output_(1), finalized_(0), name_(name), contraction_seq_flops_(0.0)
{
tensors_.emplace( //output tensor (id = 0)
std::make_pair(
......@@ -59,7 +59,7 @@ TensorNetwork::TensorNetwork(const std::string & name,
TensorNetwork::TensorNetwork(const std::string & name,
const std::string & tensor_network,
const std::map<std::string,std::shared_ptr<Tensor>> & tensors):
explicit_output_(1), finalized_(0), name_(name)
explicit_output_(1), finalized_(0), name_(name), contraction_seq_flops_(0.0)
{
//Convert tensor hypernetwork into regular tensor network, if needed:
//`Finish
......@@ -137,7 +137,7 @@ TensorNetwork::TensorNetwork(const std::string & name,
TensorNetwork::TensorNetwork(const std::string & name,
std::shared_ptr<Tensor> output_tensor,
NetworkBuilder & builder):
explicit_output_(0), finalized_(0), name_(name)
explicit_output_(0), finalized_(0), name_(name), contraction_seq_flops_(0.0)
{
tensors_.emplace( //output tensor (id = 0)
std::make_pair(
......@@ -295,6 +295,16 @@ void TensorNetwork::updateConnections(unsigned int tensor_id)
}
double TensorNetwork::determineContractionSequence(ContractionSeqOptimizer & contr_seq_optimizer)
{
assert(finalized_ != 0); //tensor network must be in finalized state
if(contraction_seq_.empty()){
contraction_seq_flops_ = contr_seq_optimizer.determineContractionSequence(*this,contraction_seq_);
}
return contraction_seq_flops_;
}
bool TensorNetwork::appendTensor(unsigned int tensor_id, //in: tensor id (unique within the tensor network)
std::shared_ptr<Tensor> tensor, //in: appended tensor
const std::vector<TensorLeg> & connections) //in: tensor connections (fully specified)
......@@ -450,6 +460,7 @@ bool TensorNetwork::appendTensor(unsigned int tensor_id,
return false;
}
}
contraction_seq_.clear(); //invalidate previously cached tensor contraction sequence
finalized_ = 1; //implicit leg pairing always keeps the tensor network in a finalized state
return true;
}
......@@ -545,6 +556,7 @@ bool TensorNetwork::appendTensorNetwork(TensorNetwork && network,
);
}
this->updateConnections(0); //update connections in just appended input tensors
contraction_seq_.clear(); //invalidate previously cached tensor contraction sequence
finalized_ = 1; //implicit leg pairing always keeps the primary tensor network in a finalized state
return true;
}
......@@ -639,6 +651,7 @@ bool TensorNetwork::deleteTensor(unsigned int tensor_id)
tensor = nullptr;
auto num_deleted = tensors_.erase(tensor_id);
assert(num_deleted == 1);
contraction_seq_.clear(); //invalidate previously cached tensor contraction sequence
return true;
}
......@@ -717,6 +730,7 @@ bool TensorNetwork::mergeTensors(unsigned int left_id, unsigned int right_id, un
assert(num_deleted == 1);
//Update connections:
this->updateConnections(result_id);
contraction_seq_.clear(); //invalidate previously cached tensor contraction sequence
return true;
}
......
/** ExaTN::Numerics: Tensor network
REVISION: 2019/08/09
REVISION: 2019/09/04
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -56,10 +56,12 @@ Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include "tensor_connected.hpp"
#include "tensor_op_factory.hpp"
#include "network_build_factory.hpp"
#include "contraction_seq_optimizer.hpp"
#include <unordered_map>
#include <map>
#include <vector>
#include <list>
#include <string>
#include <memory>
......@@ -82,9 +84,9 @@ public:
TensorNetwork(const std::string & name, //in: tensor network name
std::shared_ptr<Tensor> output_tensor, //in: fully specified output tensor of the tensor network
const std::vector<TensorLeg> & output_legs); //in: fully specified output tensor legs
/** Creates a named tensor network from a symbolic tensor network expression and vector of tensors. **/
TensorNetwork(const std::string & name, //in: tensor network name
const std::string & tensor_network, //in: tensor network expression (symbolic math expression)
/** Creates a named tensor network from a symbolic tensor network expression and a container of tensors. **/
TensorNetwork(const std::string & name, //in: tensor network name
const std::string & tensor_network, //in: tensor network expression (symbolic math expression)
const std::map<std::string,std::shared_ptr<Tensor>> & tensors); //in: participating tensors identified by their names
/** Builds a named tensor network from a template implemented by a custom tensor network builder. **/
TensorNetwork(const std::string & name, //in: tensor network name
......@@ -194,6 +196,10 @@ protected:
/** Updates tensor network linking when a tensor has its connections modified. **/
void updateConnections(unsigned int tensor_id); //in: id of the tensor whose connections were modified
/** Determines a pseudo-optimal tensor contraction sequence required for evaluating the tensor network.
Returns an estimate of the total flop count required by the returned contraction sequence. **/
double determineContractionSequence(ContractionSeqOptimizer & contr_seq_optimizer);
private:
int explicit_output_; //whether or not the output tensor has been fully specified during construction
......@@ -201,6 +207,8 @@ private:
std::string name_; //tensor network name
std::unordered_map<unsigned int, TensorConn> tensors_; //tensors connected to each other via legs (tensor connections)
//map: Non-negative tensor id --> Connected tensor
std::list<ContrTriple> contraction_seq_; //cached tensor contraction sequence
double contraction_seq_flops_; //flop estimate for the determined tensor contraction sequence
};
} //namespace numerics
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment