Commit 4215b165 authored by Dmitry I. Lyakh's avatar Dmitry I. Lyakh

Implemented ContractionSeqOptimizer factory.

parent 4e88b67f
......@@ -22,6 +22,7 @@ add_library(${LIBRARY_NAME}
network_builder_mps.cpp
network_build_factory.cpp
contraction_seq_optimizer_dummy.cpp
contraction_seq_optimizer_factory.cpp
tensor_network.cpp)
target_include_directories(${LIBRARY_NAME}
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer
REVISION: 2019/09/04
REVISION: 2019/09/05
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -13,6 +13,7 @@ Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include "tensor_basic.hpp"
#include <list>
#include <memory>
namespace exatn{
......@@ -36,6 +37,8 @@ public:
std::list<ContrTriple> & contr_seq) = 0;
};
using createContractionSeqOptimizerFn = std::unique_ptr<ContractionSeqOptimizer> (*)(void);
} //namespace numerics
} //namespace exatn
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer: Dummy
REVISION: 2019/09/04
REVISION: 2019/09/05
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -19,6 +19,12 @@ double ContractionSeqOptimizerDummy::determineContractionSequence(const TensorNe
return flops;
}
std::unique_ptr<ContractionSeqOptimizer> ContractionSeqOptimizerDummy::createNew()
{
return std::unique_ptr<ContractionSeqOptimizer>(new ContractionSeqOptimizerDummy());
}
} //namespace numerics
} //namespace exatn
/** ExaTN::Numerics: Tensor contraction sequence optimizer: Dummy
REVISION: 2019/09/04
REVISION: 2019/09/05
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -22,6 +22,8 @@ public:
virtual double determineContractionSequence(const TensorNetwork & network,
std::list<ContrTriple> & contr_seq) override;
static std::unique_ptr<ContractionSeqOptimizer> createNew();
};
} //namespace numerics
......
/** ExaTN::Numerics: Tensor contraction sequence optimizer factory
REVISION: 2019/09/05
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include "contraction_seq_optimizer_factory.hpp"
namespace exatn{
namespace numerics{
ContractionSeqOptimizerFactory::ContractionSeqOptimizerFactory()
{
registerContractionSeqOptimizer("dummy",&ContractionSeqOptimizerDummy::createNew);
}
void ContractionSeqOptimizerFactory::registerContractionSeqOptimizer(const std::string & name,
createContractionSeqOptimizerFn creator)
{
factory_map_[name] = creator;
return;
}
std::unique_ptr<ContractionSeqOptimizer> ContractionSeqOptimizerFactory::createContractionSeqOptimizer(const std::string & name)
{
auto it = factory_map_.find(name);
if(it != factory_map_.end()) return (it->second)();
return std::unique_ptr<ContractionSeqOptimizer>(nullptr);
}
ContractionSeqOptimizerFactory * ContractionSeqOptimizerFactory::get()
{
static ContractionSeqOptimizerFactory single_instance;
return &single_instance;
}
} //namespace numerics
} //namespace exatn
/** ExaTN::Numerics: Tensor contraction sequence optimizer factory
REVISION: 2019/09/05
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
/** Rationale:
(a) Creates tensor contraction sequence optimizers of desired kind.
**/
#ifndef EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_FACTORY_HPP_
#define EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_FACTORY_HPP_
#include "tensor_basic.hpp"
#include "contraction_seq_optimizer.hpp"
#include "contraction_seq_optimizer_dummy.hpp"
#include <string>
#include <memory>
#include <map>
namespace exatn{
namespace numerics{
class ContractionSeqOptimizerFactory{
public:
ContractionSeqOptimizerFactory(const ContractionSeqOptimizerFactory &) = delete;
ContractionSeqOptimizerFactory & operator=(const ContractionSeqOptimizerFactory &) = delete;
ContractionSeqOptimizerFactory(ContractionSeqOptimizerFactory &&) noexcept = default;
ContractionSeqOptimizerFactory & operator=(ContractionSeqOptimizerFactory &&) noexcept = default;
~ContractionSeqOptimizerFactory() = default;
/** Registers a new tensor contraction optimizer subtype to produce instances of. **/
void registerContractionSeqOptimizer(const std::string & name, createContractionSeqOptimizerFn creator);
/** Creates a new instance of a desired subtype. **/
std::unique_ptr<ContractionSeqOptimizer> createContractionSeqOptimizer(const std::string & name);
/** Returns a pointer to the ContractionSeqOptimizerFactory singleton. **/
static ContractionSeqOptimizerFactory * get();
private:
ContractionSeqOptimizerFactory(); //private ctor
std::map<std::string,createContractionSeqOptimizerFn> factory_map_;
};
} //namespace numerics
} //namespace exatn
#endif //EXATN_NUMERICS_CONTRACTION_SEQ_OPTIMIZER_FACTORY_HPP_
/** ExaTN::Numerics: Tensor network
REVISION: 2019/09/04
REVISION: 2019/09/05
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -13,11 +13,16 @@ Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
#include <string>
#include <vector>
#include <map>
#include <memory>
namespace exatn{
namespace numerics{
//Tensor contraction sequence optmizers:
std::map<std::string,std::shared_ptr<ContractionSeqOptimizer>> optimizers;
TensorNetwork::TensorNetwork():
explicit_output_(0), finalized_(0), contraction_seq_flops_(0.0)
{
......@@ -734,6 +739,23 @@ bool TensorNetwork::mergeTensors(unsigned int left_id, unsigned int right_id, un
return true;
}
double TensorNetwork::getContractionCost(unsigned int left_id, unsigned int right_id,
double * arithm_intensity, bool adjust_cost)
{
double flops = 0.0;
//`Finish
return flops;
}
std::list<std::shared_ptr<TensorOperation>> TensorNetwork::getOperationList(const std::string & contr_seq_opt_name)
{
std::list<std::shared_ptr<TensorOperation>> ops;
//`Finish
return ops;
}
} //namespace numerics
} //namespace exatn
/** ExaTN::Numerics: Tensor network
REVISION: 2019/09/04
REVISION: 2019/09/05
Copyright (C) 2018-2019 Dmitry I. Lyakh (Liakh)
Copyright (C) 2018-2019 Oak Ridge National Laboratory (UT-Battelle) **/
......@@ -178,6 +178,18 @@ public:
unsigned int right_id, //in: right tensor id (present in the tensor network)
unsigned int result_id); //in: result tensor id (absent in the tensor network, to be appended)
/** Returns the FMA flop count for a given contraction of two tensors identified by their ids
in the tensor network. Optionally returns the arithmetic intensity of the tensor contraction as well.
Additionally, it also allows rescaling of the tensor contraction cost with the adjustment
by the arithmetic intensity (lower arithmetic intensity will increase the cost). **/
double getContractionCost(unsigned int left_id, //in: left tensor id (present in the tensor network)
unsigned int right_id, //in: right tensor id (present in the tensor network)
double * arithm_intensity = nullptr, //out: arithmetic intensity of the tensor contraction
bool adjust_cost = false); //in: whether or not to adjust the flops cost due to arithmetic intensity
/** Returns the list of tensor operations required for evaluating the tensor network. **/
std::list<std::shared_ptr<TensorOperation>> getOperationList(const std::string & contr_seq_opt_name = "dummy");
protected:
/** Returns a non-owning pointer to a given tensor of the tensor network
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment