Commit 7ad755a5 authored by Alvarez, Gonzalo's avatar Alvarez, Gonzalo
Browse files

TensorEval option is now compile time

parent 19721cb4
Loading
Loading
Loading
Loading
+12 −21
Original line number Diff line number Diff line
@@ -20,8 +20,6 @@ along with MERA++. If not, see <http://www.gnu.org/licenses/>.
#define MERASOLVER_H
#include "InputNg.h"
#include "TensorSrep.h"
#include "TensorEvalSlow.h"
#include "TensorEvalNew.h"
#include "TensorOptimizer.h"
#include "InputCheck.h"
#include "ModelSelector.h"
@@ -44,17 +42,15 @@ class MeraSolver {
	typedef typename TensorOptimizerType::ParametersForSolverType ParametersForSolverType;
	typedef typename TensorOptimizerType::SrepStatementType SrepStatementType;
	typedef typename TensorOptimizerType::VectorSrepStatementType VectorSrepStatementType;
	typedef typename TensorOptimizerType::TensorEvalBaseType TensorEvalBaseType;
	typedef typename TensorEvalBaseType::PairStringSizeType PairStringSizeType;
	typedef typename TensorEvalBaseType::VectorPairStringSizeType VectorPairStringSizeType;
	typedef typename TensorEvalBaseType::TensorType TensorType;
	typedef typename TensorEvalBaseType::VectorTensorType VectorTensorType;
	typedef typename TensorOptimizerType::TensorEvalType TensorEvalType;
	typedef typename TensorEvalType::PairStringSizeType PairStringSizeType;
	typedef typename TensorEvalType::VectorPairStringSizeType VectorPairStringSizeType;
	typedef typename TensorEvalType::TensorType TensorType;
	typedef typename TensorEvalType::VectorTensorType VectorTensorType;
	typedef ModelBase<ComplexOrRealType> ModelBaseType;
	typedef ModelSelector<ModelBaseType> ModelType;
	typedef typename TensorOptimizerType::SymmetryLocalType SymmetryLocalType;
	typedef typename TensorOptimizerType::ParametersForMeraType ParametersForMeraType;
	typedef TensorEvalSlow<ComplexOrRealType> TensorEvalSlowType;
	typedef TensorEvalNew<ComplexOrRealType> TensorEvalNewType;
	typedef PsimagLite::Vector<PsimagLite::String>::Type VectorStringType;

	static const int EVAL_BREAKUP = TensorOptimizerType::EVAL_BREAKUP;
@@ -315,20 +311,15 @@ private:
			assert(ind < energyTerms_.size());
			SrepStatementType* ptr = energyTerms_[ind];
			if (!ptr) return 0.0;
			TensorEvalBaseType* tensorEval = TensorOptimizerType::ParallelEnvironHelperType::
			        getTensorEvalPtr(paramsForMera_.evaluator,
			                         *ptr,
			                         tensors_,
			                         nameToIndexLut_,
			                         symmLocal_);
			TensorEvalType tensorEval(*ptr, tensors_);
//			                         nameToIndexLut_,
	//		                         symmLocal_);

			typename TensorEvalBaseType::HandleType handle = tensorEval->operator()();
			typename TensorEvalType::HandleType handle = tensorEval();
			while (!handle.done());

			VectorSizeType args(1,0);
			const SizeType index = tensorEval->nameToIndexLut("e" + ttos(ind));
			delete tensorEval;
			tensorEval = 0;
			const SizeType index = tensorEval.nameToIndexLut("e" + ttos(ind));
			assert(index < tensors_.size());
			return tensors_[index]->operator()(args);
		}
@@ -430,7 +421,7 @@ private:
	{
		TensorSrep tsrep(meraStr_);
		SizeType maxLegs = 2.0*paramsForMera_.hamiltonianConnection.size();
		SymmetryLocal* symmLocal = new SymmetryLocal(tsrep.size(), model_().qOne(), maxLegs);
		SymmetryLocalType* symmLocal = new SymmetryLocalType(tsrep.size(), model_().qOne(), maxLegs);
		DimensionSrep<SymmetryLocal> dimSrep(meraStr_, *symmLocal, paramsForMera_.m);
		PsimagLite::String dsrep = dimSrep() + dsrepEnvirons_;

+29 −38
Original line number Diff line number Diff line
@@ -18,11 +18,10 @@ along with MERA++. If not, see <http://www.gnu.org/licenses/>.
#ifndef PARALLELENVIRONHELPER_H
#define PARALLELENVIRONHELPER_H
#include "Matrix.h"
#include "TensorEvalSlow.h"
#include "TensorEvalBase.h"
#include "TensorEvalNew.h"
#include "Vector.h"
#include "TensorStanza.h"
#include "NameToIndexLut.h"

namespace  Mera {

@@ -31,23 +30,21 @@ class ParallelEnvironHelper {

public:

	typedef TensorEvalBase<ComplexOrRealType> TensorEvalBaseType;
	typedef TensorEvalSlow<ComplexOrRealType> TensorEvalSlowType;
	typedef TensorEvalNew<ComplexOrRealType> TensorEvalNewType;
	typedef TensorEval<ComplexOrRealType> TensorEvalType;
	typedef PsimagLite::Matrix<ComplexOrRealType> MatrixType;
	typedef typename TensorEvalBaseType::PairStringSizeType PairStringSizeType;
	typedef typename TensorEvalBaseType::TensorType TensorType;
	typedef typename TensorEvalBaseType::VectorTensorType VectorTensorType;
	typedef typename TensorEvalBaseType::SrepStatementType SrepStatementType;
	typedef typename TensorEvalType::PairStringSizeType PairStringSizeType;
	typedef typename TensorEvalType::TensorType TensorType;
	typedef typename TensorEvalType::VectorTensorType VectorTensorType;
	typedef typename TensorEvalType::SrepStatementType SrepStatementType;
	typedef typename PsimagLite::Vector<SrepStatementType*>::Type VectorSrepStatementType;
	typedef PsimagLite::Vector<TensorStanza::IndexDirectionEnum>::Type VectorDirType;
	typedef PsimagLite::Vector<bool>::Type VectorBoolType;
	typedef PsimagLite::Vector<SizeType>::Type VectorSizeType;
	typedef typename PsimagLite::Vector<MatrixType*>::Type VectorMatrixType;
	typedef std::pair<SizeType,SizeType> PairSizeType;
	typedef typename TensorEvalBaseType::MapPairStringSizeType MapPairStringSizeType;
	typedef typename TensorEvalBaseType::VectorPairStringSizeType VectorPairStringSizeType;
	typedef typename TensorEvalSlowType::SymmetryLocalType SymmetryLocalType;
	typedef typename TensorEvalType::MapPairStringSizeType MapPairStringSizeType;
	typedef typename TensorEvalType::VectorPairStringSizeType VectorPairStringSizeType;
	typedef typename TensorEvalType::SymmetryLocalType SymmetryLocalType;

	ParallelEnvironHelper(VectorSrepStatementType& tensorSrep,
	                      PsimagLite::String evaluator,
@@ -124,17 +121,11 @@ public:
		outputTensor(eq).setSizes(dimensions);

		// evaluate environment
		TensorEvalBaseType* tensorEval =  getTensorEvalPtr(evaluator,
		                                                   eq,
		                                                   tensors_,
		                                                   nameToIndexLut_,
		                                                   symmLocal_);
		TensorEvalType tensorEval(eq, tensors_);

		typename TensorEvalBaseType::HandleType handle = tensorEval->operator()();
		while (!handle.done());
		typename TensorEvalType::HandleType handle = tensorEval();

		delete tensorEval;
		tensorEval = 0;
		while (!handle.done());

		// copy result into m
		SizeType count = 0;
@@ -146,23 +137,23 @@ public:
		} while (ProgramGlobals::nextIndex(freeIndices,dimensions,total));
	}

	static TensorEvalBaseType* getTensorEvalPtr(PsimagLite::String evaluator,
	                                            const SrepStatementType& srep,
	                                            VectorTensorType& tensors,
	                                            NameToIndexLut<TensorType>& nameToIndexLut,
	                                            SymmetryLocalType* symmLocal)
	{
		TensorEvalBaseType* tensorEval = 0;
		if (evaluator == "slow") {
			tensorEval = new TensorEvalSlowType(srep, tensors, nameToIndexLut, symmLocal);
		} else if (evaluator == "new") {
			tensorEval = new TensorEvalNewType(srep, tensors);
		} else {
			throw PsimagLite::RuntimeError("Unknown evaluator " + evaluator + "\n");
		}

		return tensorEval;
	}
//	static TensorEvalBaseType* getTensorEvalPtr(PsimagLite::String evaluator,
//	                                            const SrepStatementType& srep,
//	                                            VectorTensorType& tensors,
//	                                            NameToIndexLut<TensorType>& nameToIndexLut,
//	                                            SymmetryLocalType* symmLocal)
//	{
//		TensorEvalBaseType* tensorEval = 0;
//		if (evaluator == "slow") {
//			tensorEval = new TensorEvalSlowType(srep, tensors, nameToIndexLut, symmLocal);
//		} else if (evaluator == "new") {
//			tensorEval = new TensorEvalNewType(srep, tensors);
//		} else {
//			throw PsimagLite::RuntimeError("Unknown evaluator " + evaluator + "\n");
//		}

//		return tensorEval;
//	}

private:

+27 −21
Original line number Diff line number Diff line
@@ -17,33 +17,39 @@ along with MERA++. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TENSOREVALBASE_H
#define TENSOREVALBASE_H
#include "Tensor.h"
#include "SrepStatement.h"
#include "TensorEvalHandle.h"
#ifndef NO_EXATN
#include "TensorEvalNew.h"
#else
#include "TensorEvalSlow.h"
#endif

namespace Mera {
//#include "Tensor.h"
//#include "SrepStatement.h"
//#include "TensorEvalHandle.h"

template<typename ComplexOrRealType>
class TensorEvalBase {
//namespace Mera {

public:
//template<typename ComplexOrRealType>
//class TensorEvalBase {

	typedef TensorEvalHandle HandleType;
	typedef Tensor<ComplexOrRealType> TensorType;
	typedef typename PsimagLite::Vector<TensorType*>::Type VectorTensorType;
	typedef typename PsimagLite::Vector<SizeType>::Type VectorSizeType;
	typedef SrepStatement<ComplexOrRealType> SrepStatementType;
	typedef typename SrepStatementType::PairStringSizeType PairStringSizeType;
	typedef std::map<PairStringSizeType,SizeType> MapPairStringSizeType;
	typedef typename PsimagLite::Vector<PairStringSizeType>::Type VectorPairStringSizeType;
//public:

	virtual ~TensorEvalBase() {}
//	typedef TensorEvalHandle HandleType;
//	typedef Tensor<ComplexOrRealType> TensorType;
//	typedef typename PsimagLite::Vector<TensorType*>::Type VectorTensorType;
//	typedef typename PsimagLite::Vector<SizeType>::Type VectorSizeType;
//	typedef SrepStatement<ComplexOrRealType> SrepStatementType;
//	typedef typename SrepStatementType::PairStringSizeType PairStringSizeType;
//	typedef std::map<PairStringSizeType,SizeType> MapPairStringSizeType;
//	typedef typename PsimagLite::Vector<PairStringSizeType>::Type VectorPairStringSizeType;

	virtual HandleType operator()() = 0;
//	virtual ~TensorEvalBase() {}

	virtual void printResult(std::ostream& os) const = 0;
//	virtual HandleType operator()() = 0;

	virtual SizeType nameToIndexLut(PsimagLite::String) = 0;
};
} // namespace Mera
//	virtual void printResult(std::ostream& os) const = 0;

//	virtual SizeType nameToIndexLut(PsimagLite::String) = 0;
//};
//} // namespace Mera
#endif // TENSOREVALBASE_H
+17 −13
Original line number Diff line number Diff line
@@ -17,26 +17,30 @@ along with MERA++. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TENSOREVALNEW_H
#define TENSOREVALNEW_H
#include "TensorEvalBase.h"

#include "Tensor.h"
#include "SrepStatement.h"
#include "TensorEvalHandle.h"
#include "SymmetryLocal.h"

namespace Mera {

template<typename ComplexOrRealType>
class TensorEvalNew : public TensorEvalBase<ComplexOrRealType> {

	typedef TensorEvalBase<ComplexOrRealType> TensorEvalBaseType;
	typedef typename TensorEvalBaseType::SrepStatementType SrepStatementType;
	typedef typename TensorEvalBaseType::HandleType HandleType;
	typedef typename TensorEvalBaseType::TensorType TensorType;
	typedef typename TensorEvalBaseType::VectorTensorType VectorTensorType;
	typedef typename TensorEvalBaseType::VectorSizeType VectorSizeType;
	typedef typename TensorEvalBaseType::PairStringSizeType PairStringSizeType;
	typedef typename TensorEvalBaseType::MapPairStringSizeType MapPairStringSizeType;
	typedef typename TensorEvalBaseType::VectorPairStringSizeType VectorPairStringSizeType;
class TensorEval {

public:

	TensorEvalNew(const SrepStatementType& tSrep,
	typedef TensorEvalHandle HandleType;
	typedef Tensor<ComplexOrRealType> TensorType;
	typedef typename PsimagLite::Vector<TensorType*>::Type VectorTensorType;
	typedef typename PsimagLite::Vector<SizeType>::Type VectorSizeType;
	typedef SrepStatement<ComplexOrRealType> SrepStatementType;
	typedef typename SrepStatementType::PairStringSizeType PairStringSizeType;
	typedef std::map<PairStringSizeType,SizeType> MapPairStringSizeType;
	typedef typename PsimagLite::Vector<PairStringSizeType>::Type VectorPairStringSizeType;
	typedef SymmetryLocal SymmetryLocalType;

	TensorEval(const SrepStatementType& tSrep,
	              const VectorTensorType& vt)
	{}

+10 −11
Original line number Diff line number Diff line
@@ -19,8 +19,7 @@ along with MERA++. If not, see <http://www.gnu.org/licenses/>.
#define TENSOROPTIMIZER_H
#include "Vector.h"
#include "TensorSrep.h"
#include "TensorEvalSlow.h"
#include "TensorEvalNew.h"
#include "TensorEvalBase.h"
#include <algorithm>
#include "Sort.h"
#include "Matrix.h"
@@ -47,26 +46,25 @@ public:

	typedef ParallelEnvironHelper<ComplexOrRealType> ParallelEnvironHelperType;
	typedef ParametersForMera<ComplexOrRealType> ParametersForMeraType;
	typedef TensorEvalBase<ComplexOrRealType> TensorEvalBaseType;
	typedef TensorEvalSlow<ComplexOrRealType> TensorEvalSlowType;
	typedef TensorEval<ComplexOrRealType> TensorEvalType;
	typedef typename PsimagLite::Real<ComplexOrRealType>::Type RealType;
	typedef typename PsimagLite::Vector<RealType>::Type VectorRealType;
	typedef typename TensorEvalBaseType::PairStringSizeType PairStringSizeType;
	typedef typename TensorEvalBaseType::VectorPairStringSizeType VectorPairStringSizeType;
	typedef typename TensorEvalBaseType::TensorType TensorType;
	typedef typename TensorEvalBaseType::VectorTensorType VectorTensorType;
	typedef typename TensorEvalBaseType::SrepStatementType SrepStatementType;
	typedef typename TensorEvalType::PairStringSizeType PairStringSizeType;
	typedef typename TensorEvalType::VectorPairStringSizeType VectorPairStringSizeType;
	typedef typename TensorEvalType::TensorType TensorType;
	typedef typename TensorEvalType::VectorTensorType VectorTensorType;
	typedef typename TensorEvalType::SrepStatementType SrepStatementType;
	typedef typename TensorEvalType::MapPairStringSizeType MapPairStringSizeType;
	typedef typename PsimagLite::Vector<SrepStatementType*>::Type VectorSrepStatementType;
	typedef typename ParallelEnvironHelperType::MatrixType MatrixType;
	typedef std::pair<SizeType,SizeType> PairSizeType;
	typedef typename TensorEvalBaseType::MapPairStringSizeType MapPairStringSizeType;
	typedef PsimagLite::ParametersForSolver<RealType> ParametersForSolverType;
	typedef typename PsimagLite::Vector<ComplexOrRealType>::Type VectorType;
	typedef PsimagLite::CrsMatrix<ComplexOrRealType> SparseMatrixType;
	typedef PsimagLite::LanczosSolver<ParametersForSolverType,SparseMatrixType,VectorType>
	LanczosSolverType;
	typedef typename TensorType::TensorBlobType TensorBlobType;
	typedef typename TensorEvalSlowType::SymmetryLocalType SymmetryLocalType;
	typedef typename TensorEvalType::SymmetryLocalType SymmetryLocalType;

	TensorOptimizer(IoInType& io,
	                PsimagLite::String nameToOptimize,
@@ -212,6 +210,7 @@ private:

	void saveTensor()
	{
		assert(indToOptimize_ < tensors_.size());
		savedTensor_ = tensors_[indToOptimize_]->data();
	}

Loading