Commit 704f1a53 authored by Dmitry I. Lyakh's avatar Dmitry I. Lyakh

Implemented infrastructure for tensor contraction sequence caching.

parent 1cfeba60
......@@ -31,6 +31,7 @@ add_library(${LIBRARY_NAME}
network_builder_mps.cpp
network_builder_tree.cpp
network_build_factory.cpp
contraction_seq_optimizer.cpp
contraction_seq_optimizer_dummy.cpp
contraction_seq_optimizer_heuro.cpp
contraction_seq_optimizer_greed.cpp
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer: Base
REVISION: 2020/07/06
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
#include "contraction_seq_optimizer.hpp"
#include "tensor_network.hpp"
#include "metis_graph.hpp"
namespace exatn{
namespace numerics{
std::unordered_map<std::string,std::pair<MetisGraph,std::list<ContrTriple>>> ContractionSeqOptimizer::cached_contr_seqs_;
bool ContractionSeqOptimizer::cacheContractionSequence(const TensorNetwork & network)
{
auto res = cached_contr_seqs_.emplace(network.getName(),
std::make_pair(MetisGraph(network),network.exportContractionSequence()));
return res.second;
}
bool ContractionSeqOptimizer::eraseContractionSequence(const TensorNetwork & network)
{
auto num_deleted = cached_contr_seqs_.erase(network.getName());
return (num_deleted == 1);
}
const std::list<ContrTriple> * ContractionSeqOptimizer::findContractionSequence(const TensorNetwork & network)
{
auto iter = cached_contr_seqs_.find(network.getName());
if(iter != cached_contr_seqs_.end()){
MetisGraph network_graph(network);
if(network_graph == iter->second.first) return &(iter->second.second);
};
return nullptr;
}
} //namespace numerics
} //namespace exatn
/** ExaTN::Numerics: Tensor contraction sequence optimizer
REVISION: 2019/11/08
REVISION: 2020/07/06
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
/** Rationale:
**/
......@@ -14,6 +14,7 @@ Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include <list>
#include <memory>
#include <unordered_map>
#include <functional>
namespace exatn{
......@@ -28,6 +29,7 @@ struct ContrTriple{
};
class TensorNetwork;
class MetisGraph;
class ContractionSeqOptimizer{
......@@ -46,6 +48,27 @@ public:
virtual double determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq,
std::function<unsigned int ()> intermediate_num_generator) = 0;
/** Caches the determined pseudo-optimal tensor contraction sequence for a given
tensor network for a later retrieval for the same tensor networks. Returns TRUE
on success, FALSE in case this tensor network has already been cached before. **/
static bool cacheContractionSequence(const TensorNetwork & network); //in: tensor network with a determined tensor contraction sequence
/** Erases the previously cached tensor contraction sequence for a given tensor
network and returns TRUE, or returns FALSE in case it has not been cached before. **/
static bool eraseContractionSequence(const TensorNetwork & network); //in: tensor network
/** Retrieves a previously cached tensor contraction sequence for a given tensor
network. Returns nullptr in case it has not been cached before. **/
static const std::list<ContrTriple> * findContractionSequence(const TensorNetwork & network); //in: tensor network
private:
/** Cached tensor contraction sequences. **/
static std::unordered_map<std::string, //tensor network name
std::pair<MetisGraph, //METIS graph of the tensor network
std::list<ContrTriple> //tensor contraction sequence
>> cached_contr_seqs_;
};
using createContractionSeqOptimizerFn = std::unique_ptr<ContractionSeqOptimizer> (*)(void);
......
/** ExaTN::Numerics: Graph k-way partitioning via METIS
REVISION: 2020/05/19
REVISION: 2020/07/06
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -11,6 +11,7 @@ Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
#include <iostream>
#include <unordered_map>
#include <algorithm>
#include <tuple>
#include <cmath>
#include <cassert>
......@@ -191,6 +192,19 @@ MetisGraph::MetisGraph(const MetisGraph & parent, //in: parti
}
bool operator==(const MetisGraph & lhs, const MetisGraph & rhs)
{
return std::tie(lhs.num_vertices_,lhs.xadj_,lhs.adjncy_,lhs.vwgt_,lhs.adjwgt_,lhs.renumber_)
== std::tie(rhs.num_vertices_,rhs.xadj_,rhs.adjncy_,rhs.vwgt_,rhs.adjwgt_,rhs.renumber_);
}
bool operator!=(const MetisGraph & lhs, const MetisGraph & rhs)
{
return !(lhs == rhs);
}
void MetisGraph::clearPartitions()
{
tpwgts_.clear();
......
/** ExaTN::Numerics: Graph k-way partitioning via METIS
REVISION: 2020/05/19
REVISION: 2020/07/06
Copyright (C) 2018-2020 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2020 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -63,6 +63,10 @@ public:
MetisGraph & operator=(MetisGraph &&) noexcept = default;
~MetisGraph() = default;
/** Compares two METIS graphs for equality/inequality. **/
friend bool operator==(const MetisGraph & lhs, const MetisGraph & rhs);
friend bool operator!=(const MetisGraph & lhs, const MetisGraph & rhs);
/** Clears the current partitioning. **/
void clearPartitions();
......@@ -142,6 +146,9 @@ private:
idx_t num_cross_edges_; //number of cross edges in the edge cut
};
bool operator==(const MetisGraph & lhs, const MetisGraph & rhs);
bool operator!=(const MetisGraph & lhs, const MetisGraph & rhs);
} //namespace numerics
} //namespace exatn
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment