Commit e4abd02b authored by Dmitry I. Lyakh's avatar Dmitry I. Lyakh

Enabled caching of optimized tensor network contraction sequences for reuse.

To activate this feature call exatn::activateContrSeqCaching().
parent 50324f8c
/** ExaTN::Numerics: General client header
REVISION: 2020/07/06
REVISION: 2020/07/08
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -632,6 +632,16 @@ inline void resetContrSeqOptimizer(const std::string & optimizer_name)
{return numericalServer->resetContrSeqOptimizer(optimizer_name);}
/** Turns on caching of optimized tensor contraction sequences,
    allowing previously computed sequences to be reused later. **/
inline void activateContrSeqCaching()
{
 numericalServer->activateContrSeqCaching();
}
/** Turns off caching of optimized tensor contraction sequences. **/
inline void deactivateContrSeqCaching()
{
 numericalServer->deactivateContrSeqCaching();
}
/** Sets the tensor runtime logging level (0 disables logging). **/
inline void resetRuntimeLoggingLevel(int level = 0)
{
 numericalServer->resetRuntimeLoggingLevel(level);
}
......
/** ExaTN::Numerics: Numerical server
REVISION: 2020/07/07
REVISION: 2020/07/08
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -29,7 +29,7 @@ NumServer::NumServer(const MPICommProxy & communicator,
const ParamConf & parameters,
const std::string & graph_executor_name,
const std::string & node_executor_name):
contr_seq_optimizer_("metis"), intra_comm_(communicator)
contr_seq_optimizer_("metis"), contr_seq_caching_(false), intra_comm_(communicator)
{
int mpi_error = MPI_Comm_size(*(communicator.get<MPI_Comm>()),&num_processes_); assert(mpi_error == MPI_SUCCESS);
mpi_error = MPI_Comm_rank(*(communicator.get<MPI_Comm>()),&process_rank_); assert(mpi_error == MPI_SUCCESS);
......@@ -47,7 +47,7 @@ NumServer::NumServer(const MPICommProxy & communicator,
NumServer::NumServer(const ParamConf & parameters,
const std::string & graph_executor_name,
const std::string & node_executor_name):
contr_seq_optimizer_("metis")
contr_seq_optimizer_("metis"), contr_seq_caching_(false)
{
num_processes_ = 1; process_rank_ = 0;
process_world_ = std::make_shared<ProcessGroup>(intra_comm_,num_processes_); //intra-communicator is empty here
......@@ -103,9 +103,22 @@ void NumServer::reconfigureTensorRuntime(const ParamConf & parameters,
}
#endif
void NumServer::resetContrSeqOptimizer(const std::string & optimizer_name)
void NumServer::resetContrSeqOptimizer(const std::string & optimizer_name, bool caching)
{
contr_seq_optimizer_ = optimizer_name;
contr_seq_caching_ = caching;
return;
}
/** Enables caching of optimized tensor contraction sequences for reuse. **/
void NumServer::activateContrSeqCaching()
{
 contr_seq_caching_ = true;
}
/** Disables caching of optimized tensor contraction sequences. **/
void NumServer::deactivateContrSeqCaching()
{
 contr_seq_caching_ = false;
}
......@@ -321,6 +334,10 @@ bool NumServer::submit(const ProcessGroup & process_group,
<< process_group.getMemoryLimitPerProcess() << std::endl << std::flush; //debug
//Get tensor operation list:
if(contr_seq_caching_ && network.exportContractionSequence().empty()){ //check whether the optimal tensor contraction sequence is already available from the past
auto cached_seq = ContractionSeqOptimizer::findContractionSequence(network);
if(cached_seq.first != nullptr) network.importContractionSequence(*(cached_seq.first),cached_seq.second);
}
auto & op_list = network.getOperationList(contr_seq_optimizer_,(num_procs > 1));
const double max_intermediate_presence_volume = network.getMaxIntermediatePresenceVolume();
double max_intermediate_volume = network.getMaxIntermediateVolume();
......
/** ExaTN::Numerics: Numerical server
REVISION: 2020/07/06
REVISION: 2020/07/08
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -134,7 +134,14 @@ public:
/** Resets the tensor contraction sequence optimizer that is
invoked when evaluating tensor networks. **/
void resetContrSeqOptimizer(const std::string & optimizer_name);
void resetContrSeqOptimizer(const std::string & optimizer_name, //in: tensor contraction sequence optimizer name
bool caching = false); //whether or not optimized tensor contraction sequence will be cached for later reuse
/** Activates optimized tensor contraction sequence caching for later reuse. **/
void activateContrSeqCaching();
/** Deactivates optimized tensor contraction sequence caching. **/
void deactivateContrSeqCaching();
/** Resets the runtime logging level (0:none). **/
void resetRuntimeLoggingLevel(int level = 0);
......@@ -572,6 +579,7 @@ private:
std::list<std::shared_ptr<Tensor>> implicit_tensors_; //tensors created implicitly by the runtime (for garbage collection)
std::string contr_seq_optimizer_; //tensor contraction sequence optimizer invoked when evaluating tensor networks
bool contr_seq_caching_; //regulates whether or not to cache pseudo-optimal tensor contraction orders for later reuse
std::map<std::string,std::shared_ptr<TensorMethod>> ext_methods_; //external tensor methods
std::map<std::string,std::shared_ptr<BytePacket>> ext_data_; //external data
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer: Base
REVISION: 2020/07/06
REVISION: 2020/07/08
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -12,14 +12,17 @@ namespace exatn{
namespace numerics{
//Definition of the static cache: tensor network name --> cached optimized contraction sequence.
//Note: the obsolete pre-commit definition (value type std::pair<MetisGraph,std::list<ContrTriple>>)
//was left as diff residue above this one; keeping both is ill-formed, so only the new one remains.
std::unordered_map<std::string,ContractionSeqOptimizer::CachedContrSeq> ContractionSeqOptimizer::cached_contr_seqs_;
/** Caches the optimized tensor contraction sequence of the given tensor network,
    together with its METIS graph and FMA flop count, for later reuse.
    Returns TRUE on success; returns FALSE if the network carries no contraction
    sequence, or if an entry already exists under the same network name
    (unordered_map::emplace does not overwrite).
    Note: the pre-commit body was left interleaved with the new one as diff
    residue; only the new body is kept here. **/
bool ContractionSeqOptimizer::cacheContractionSequence(const TensorNetwork & network)
{
 if(!(network.exportContractionSequence().empty())){
  //The braced temporary is already an rvalue, so no std::move is needed:
  auto res = cached_contr_seqs_.emplace(network.getName(),
              CachedContrSeq{std::make_shared<MetisGraph>(network),network.exportContractionSequence(),network.getFMAFlops()});
  return res.second;
 }
 return false;
}
......@@ -30,14 +33,14 @@ bool ContractionSeqOptimizer::eraseContractionSequence(const TensorNetwork & net
}
const std::list<ContrTriple> * ContractionSeqOptimizer::findContractionSequence(const TensorNetwork & network)
std::pair<const std::list<ContrTriple> *, double> ContractionSeqOptimizer::findContractionSequence(const TensorNetwork & network)
{
auto iter = cached_contr_seqs_.find(network.getName());
if(iter != cached_contr_seqs_.end()){
MetisGraph network_graph(network);
if(network_graph == iter->second.first) return &(iter->second.second);
if(network_graph == *(iter->second.graph)) return std::make_pair(&(iter->second.contr_seq),iter->second.fma_flops);
};
return nullptr;
return std::pair<const std::list<ContrTriple> *, double> {nullptr,0.0};
}
} //namespace numerics
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer
REVISION: 2020/07/06
REVISION: 2020/07/08
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -59,16 +59,21 @@ public:
static bool eraseContractionSequence(const TensorNetwork & network); //in: tensor network
/** Retrieves a previously cached tensor contraction sequence for a given tensor
network. Returns nullptr in case it has not been cached before. **/
static const std::list<ContrTriple> * findContractionSequence(const TensorNetwork & network); //in: tensor network
network and its corresponding FMA flop count. Returns {nullptr,0.0} in case
no previously cached tensor contraction sequence has been found. **/
static std::pair<const std::list<ContrTriple> *, double> findContractionSequence(const TensorNetwork & network); //in: tensor network
private:
//Cached optimized tensor contraction sequence:
struct CachedContrSeq{
 std::shared_ptr<MetisGraph> graph; //METIS graph of the tensor network (compared against a freshly built graph to validate cache hits)
 std::list<ContrTriple> contr_seq; //optimized tensor contraction sequence for the tensor network
 double fma_flops; //FMA flop count corresponding to the stored tensor contraction sequence
};
/** Cached tensor contraction sequences. **/
static std::unordered_map<std::string, //tensor network name
std::pair<MetisGraph, //METIS graph of the tensor network
std::list<ContrTriple> //tensor contraction sequence
>> cached_contr_seqs_;
static std::unordered_map<std::string,CachedContrSeq> cached_contr_seqs_; //tensor network name --> optimized tensor contraction sequence
};
using createContractionSeqOptimizerFn = std::unique_ptr<ContractionSeqOptimizer> (*)(void);
......
/** ExaTN::Numerics: Tensor network
REVISION: 2020/06/02
REVISION: 2020/07/08
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -509,12 +509,12 @@ double TensorNetwork::determineContractionSequence(ContractionSeqOptimizer & con
}
void TensorNetwork::importContractionSequence(const std::list<ContrTriple> & contr_sequence)
void TensorNetwork::importContractionSequence(const std::list<ContrTriple> & contr_sequence, double fma_flops)
{
assert(finalized_ != 0); //tensor network must be in finalized state
contraction_seq_.clear();
contraction_seq_ = contr_sequence;
contraction_seq_flops_ = 0.0; //flop count is unknown yet
contraction_seq_flops_ = fma_flops; //flop count may be unknown yet (defaults to zero)
max_intermediate_presence_volume_ = 0.0; //max cumulative volume of intermediates present at a time
max_intermediate_volume_ = 0.0; //max intermediate tensor volume is unknown yet
max_intermediate_rank_ = 0; //max intermediate tensor rank
......@@ -522,8 +522,9 @@ void TensorNetwork::importContractionSequence(const std::list<ContrTriple> & con
}
const std::list<ContrTriple> & TensorNetwork::exportContractionSequence() const
const std::list<ContrTriple> & TensorNetwork::exportContractionSequence(double * fma_flops) const
{
if(fma_flops != nullptr) *fma_flops = contraction_seq_flops_;
return contraction_seq_;
}
......
/** ExaTN::Numerics: Tensor network
REVISION: 2020/06/01
REVISION: 2020/07/08
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -324,10 +324,11 @@ public:
bool adjust_cost = false); //in: whether or not to adjust the flops cost due to arithmetic intensity
/** Imports and caches an externally provided tensor contraction sequence. **/
void importContractionSequence(const std::list<ContrTriple> & contr_sequence);
void importContractionSequence(const std::list<ContrTriple> & contr_sequence, //in: imported tensor contraction sequence
double fma_flops = 0.0); //in: FMA flop count for the imported tensor contraction sequence
/** Returns the currently stored tensor contraction sequence, if any. **/
const std::list<ContrTriple> & exportContractionSequence() const;
const std::list<ContrTriple> & exportContractionSequence(double * fma_flops = nullptr) const; //out: FMA flop count for the imported tensor contraction sequence
/** Returns the list of tensor operations required for evaluating the tensor network.
Parameter universal_indices set to TRUE will activate the universal index numeration
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment