Commit 8db39908 authored by Dmitry I. Lyakh's avatar Dmitry I. Lyakh

ContractionSeqOptimizer.determineContractionSequence() now accepts

lambda-based integer generator.
parent 60f04e4e
/** ExaTN::Numerics: Tensor contraction sequence optimizer
REVISION: 2019/09/08
REVISION: 2019/09/09
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -14,6 +14,7 @@ Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include <list>
#include <memory>
#include <functional>
namespace exatn{
......@@ -33,9 +34,13 @@ class ContractionSeqOptimizer{
public:
/** Determines the pseudo-optimal tensor contraction sequence required for
evaluating a given tensor network. The unique intermediate tensor id's are generated
by the provided intermediate number generator (each invocation returns a new id).
The latter can be conveniently passed as a lambda closure. **/
virtual double determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq,
unsigned int intermediate_num_begin) = 0;
std::function<unsigned int ()> intermediate_num_generator) = 0;
};
using createContractionSeqOptimizerFn = std::unique_ptr<ContractionSeqOptimizer> (*)(void);
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer: Dummy
REVISION: 2019/09/08
REVISION: 2019/09/09
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -13,7 +13,7 @@ namespace numerics{
double ContractionSeqOptimizerDummy::determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq,
unsigned int intermediate_num_begin)
std::function<unsigned int ()> intermediate_num_generator)
{
contr_seq.clear();
double flops = 0.0;
......@@ -32,11 +32,12 @@ double ContractionSeqOptimizerDummy::determineContractionSequence(const TensorNe
contr_seq.emplace_back(ContrTriple{0,curr_tensor,prev_tensor});
flops += net.getContractionCost(curr_tensor,prev_tensor);
}else{ //intermediate tensor contraction
contr_seq.emplace_back(ContrTriple{intermediate_num_begin,curr_tensor,prev_tensor});
auto intermediate_num = intermediate_num_generator();
contr_seq.emplace_back(ContrTriple{intermediate_num,curr_tensor,prev_tensor});
flops += net.getContractionCost(curr_tensor,prev_tensor);
auto merged = net.mergeTensors(curr_tensor,prev_tensor,intermediate_num_begin);
auto merged = net.mergeTensors(curr_tensor,prev_tensor,intermediate_num);
assert(merged);
prev_tensor = intermediate_num_begin++;
prev_tensor = intermediate_num;
}
}
}
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer: Dummy
REVISION: 2019/09/08
REVISION: 2019/09/09
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -22,7 +22,7 @@ public:
virtual double determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq,
unsigned int intermediate_num_begin) override;
std::function<unsigned int ()> intermediate_num_generator) override;
static std::unique_ptr<ContractionSeqOptimizer> createNew();
};
......
......@@ -16,7 +16,7 @@ static constexpr unsigned int NUM_WALKERS = 1024; //default number of walkers fo
double ContractionSeqOptimizerHeuro::determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq,
unsigned int intermediate_num_begin)
std::function<unsigned int ()> intermediate_num_generator)
{
contr_seq.clear();
double flops = 0.0;
......
......@@ -22,7 +22,7 @@ public:
virtual double determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq,
unsigned int intermediate_num_begin) override;
std::function<unsigned int ()> intermediate_num_generator) override;
static std::unique_ptr<ContractionSeqOptimizer> createNew();
};
......
......@@ -196,6 +196,16 @@ unsigned int TensorNetwork::getNumTensors() const
}
unsigned int TensorNetwork::getMaxTensorId() const
{
unsigned int max_id = 0;
for(const auto & kv: tensors_){
if(kv.first > max_id) max_id = kv.first;
}
return max_id;
}
const std::string & TensorNetwork::getName() const
{
return name_;
......@@ -314,7 +324,9 @@ double TensorNetwork::determineContractionSequence(ContractionSeqOptimizer & con
{
assert(finalized_ != 0); //tensor network must be in finalized state
if(contraction_seq_.empty()){
contraction_seq_flops_ = contr_seq_optimizer.determineContractionSequence(*this,contraction_seq_,this->getNumTensors()+1);
auto intermediate_num_begin = this->getMaxTensorId() + 1;
auto intermediate_num_generator = [intermediate_num_begin]() mutable {return intermediate_num_begin++;};
contraction_seq_flops_ = contr_seq_optimizer.determineContractionSequence(*this,contraction_seq_,intermediate_num_generator);
}
return contraction_seq_flops_;
}
......
......@@ -72,7 +72,6 @@ namespace numerics{
class TensorNetwork{
public:
using ContractionSequence = std::vector<std::pair<unsigned int, unsigned int>>; //pairs of contracted tensor id's
using Iterator = typename std::unordered_map<unsigned int, TensorConn>::iterator; //iterator
using ConstIterator = typename std::unordered_map<unsigned int, TensorConn>::const_iterator; //constant iterator
......@@ -115,6 +114,9 @@ public:
Note that the output tensor (tensor #0) is not counted here. **/
unsigned int getNumTensors() const;
/** Returns the maximal tensor id value used in the tensor network. **/
unsigned int getMaxTensorId() const;
/** Returns the name of the tensor network. **/
const std::string & getName() const;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment