Commit 2a35ec9c authored by Dmitry I. Lyakh's avatar Dmitry I. Lyakh

Implemented TensorNetwork::getContractionCost, working on operation list generation ...

parent 4215b165
......@@ -15,7 +15,7 @@ double ContractionSeqOptimizerDummy::determineContractionSequence(const TensorNe
std::list<ContrTriple> & contr_seq)
{
double flops = 0.0;
//`Finish
//`Finish: Requires TensorNetwork::iterator
return flops;
}
......
......@@ -6,6 +6,7 @@ Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include "tensor_network.hpp"
#include "tensor_symbol.hpp"
#include "contraction_seq_optimizer_factory.hpp"
#include <iostream>
#include <assert.h>
......@@ -300,6 +301,15 @@ void TensorNetwork::updateConnections(unsigned int tensor_id)
}
void TensorNetwork::invalidateContractionSequence()
{
operations_.clear();
contraction_seq_.clear();
contraction_seq_flops_ = 0.0;
return;
}
double TensorNetwork::determineContractionSequence(ContractionSeqOptimizer & contr_seq_optimizer)
{
assert(finalized_ != 0); //tensor network must be in finalized state
......@@ -465,7 +475,7 @@ bool TensorNetwork::appendTensor(unsigned int tensor_id,
return false;
}
}
contraction_seq_.clear(); //invalidate previously cached tensor contraction sequence
invalidateContractionSequence(); //invalidate previously cached tensor contraction sequence
finalized_ = 1; //implicit leg pairing always keeps the tensor network in a finalized state
return true;
}
......@@ -561,7 +571,7 @@ bool TensorNetwork::appendTensorNetwork(TensorNetwork && network,
);
}
this->updateConnections(0); //update connections in just appended input tensors
contraction_seq_.clear(); //invalidate previously cached tensor contraction sequence
invalidateContractionSequence(); //invalidate previously cached tensor contraction sequence
finalized_ = 1; //implicit leg pairing always keeps the primary tensor network in a finalized state
return true;
}
......@@ -656,7 +666,7 @@ bool TensorNetwork::deleteTensor(unsigned int tensor_id)
tensor = nullptr;
auto num_deleted = tensors_.erase(tensor_id);
assert(num_deleted == 1);
contraction_seq_.clear(); //invalidate previously cached tensor contraction sequence
invalidateContractionSequence(); //invalidate previously cached tensor contraction sequence
return true;
}
......@@ -735,7 +745,7 @@ bool TensorNetwork::mergeTensors(unsigned int left_id, unsigned int right_id, un
assert(num_deleted == 1);
//Update connections:
this->updateConnections(result_id);
contraction_seq_.clear(); //invalidate previously cached tensor contraction sequence
invalidateContractionSequence(); //invalidate previously cached tensor contraction sequence
return true;
}
......@@ -743,8 +753,33 @@ bool TensorNetwork::mergeTensors(unsigned int left_id, unsigned int right_id, un
double TensorNetwork::getContractionCost(unsigned int left_id, unsigned int right_id,
double * arithm_intensity, bool adjust_cost)
{
double flops = 0.0;
//`Finish
double flops = 0.0, left_vol = 1.0, right_vol = 1.0, contr_vol = 1.0;
if(left_id != 0 && right_id != 0){
const auto * left_tensor = this->getTensorConn(left_id);
assert(left_tensor != nullptr);
const auto left_rank = left_tensor->getNumLegs();
const auto * right_tensor = this->getTensorConn(right_id);
assert(right_tensor != nullptr);
const auto right_rank = right_tensor->getNumLegs();
const auto & right_legs = right_tensor->getTensorLegs();
for(unsigned int i = 0; i < left_rank; ++i){
left_vol *= static_cast<double>(left_tensor->getDimExtent(i));
}
for(unsigned int i = 0; i < right_rank; ++i){
double dim_ext = static_cast<double>(right_tensor->getDimExtent(i));
if(right_legs[i].getTensorId() == left_id) contr_vol *= dim_ext; //contracted dimension
right_vol *= dim_ext;
}
flops = left_vol * right_vol / contr_vol;
if(arithm_intensity != nullptr) *arithm_intensity = flops / (left_vol + right_vol);
if(adjust_cost){ //increase the "effective" flop count if arithmetic intensity is low
//`Finish: flops *= f(arithm_intensity): [max --> 1]
}
}else{
std::cout << "#ERROR(TensorNetwork::getContractionCost): Invalid request: " <<
"The output tensor of the tensor network (tensor 0) cannot be contracted!" << std::endl;
flops = -1.0; //error
}
return flops;
}
......@@ -752,7 +787,23 @@ double TensorNetwork::getContractionCost(unsigned int left_id, unsigned int righ
std::list<std::shared_ptr<TensorOperation>> TensorNetwork::getOperationList(const std::string & contr_seq_opt_name)
{
std::list<std::shared_ptr<TensorOperation>> ops;
//`Finish
auto iter = optimizers.find(contr_seq_opt_name);
if(iter == optimizers.end()){ //not cached
auto & optimizer_factory = *(ContractionSeqOptimizerFactory::get());
auto optimizer = optimizer_factory.createContractionSeqOptimizer(contr_seq_opt_name);
if(optimizer){
auto res = optimizers.emplace(std::make_pair(contr_seq_opt_name,
std::shared_ptr<ContractionSeqOptimizer>(std::move(optimizer))));
assert(res.second);
iter = res.first;
}else{
std::cout << "#ERROR(TensorNetwork::getOperationList): Invalid request: " <<
"Tensor contraction sequence optimizer" << contr_seq_opt_name << "has not been registered before!" << std::endl;
assert(false);
}
}
double flops = determineContractionSequence(*(iter->second));
//`Finish: Copy tensor network, contract according to the sequence, record each contraction as operation
return ops;
}
......
......@@ -208,6 +208,9 @@ protected:
/** Updates tensor network linking when a tensor has its connections modified. **/
void updateConnections(unsigned int tensor_id); //in: id of the tensor whose connections were modified
/** Invalidates cached tensor contraction sequence. **/
void invalidateContractionSequence();
/** Determines a pseudo-optimal tensor contraction sequence required for evaluating the tensor network.
Returns an estimate of the total flop count required by the returned contraction sequence. **/
double determineContractionSequence(ContractionSeqOptimizer & contr_seq_optimizer);
......@@ -219,8 +222,9 @@ private:
std::string name_; //tensor network name
std::unordered_map<unsigned int, TensorConn> tensors_; //tensors connected to each other via legs (tensor connections)
//map: Non-negative tensor id --> Connected tensor
std::list<ContrTriple> contraction_seq_; //cached tensor contraction sequence
double contraction_seq_flops_; //flop estimate for the determined tensor contraction sequence
std::list<ContrTriple> contraction_seq_; //cached tensor contraction sequence
std::list<std::shared_ptr<TensorOperation>> operations_; //cached tensor operations required for evaluating the tensor network
};
} //namespace numerics
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment