Commit 8cb327e0 authored by Alvarez, Gonzalo's avatar Alvarez, Gonzalo
Browse files

TensorSlow: fixed for evaluator changes

parent 7ad755a5
Loading
Loading
Loading
Loading
+1 −3
Original line number Diff line number Diff line
@@ -311,9 +311,7 @@ private:
			assert(ind < energyTerms_.size());
			SrepStatementType* ptr = energyTerms_[ind];
			if (!ptr) return 0.0;
			TensorEvalType tensorEval(*ptr, tensors_);
//			                         nameToIndexLut_,
	//		                         symmLocal_);
			TensorEvalType tensorEval(*ptr, tensors_, nameToIndexLut_, symmLocal_);

			typename TensorEvalType::HandleType handle = tensorEval();
			while (!handle.done());
+1 −1
Original line number Diff line number Diff line
@@ -121,7 +121,7 @@ public:
		outputTensor(eq).setSizes(dimensions);

		// evaluate environment
		TensorEvalType tensorEval(eq, tensors_);
		TensorEvalType tensorEval(eq, tensors_, nameToIndexLut_, symmLocal_);

		typename TensorEvalType::HandleType handle = tensorEval();

+0 −29
Original line number Diff line number Diff line
@@ -23,33 +23,4 @@ along with MERA++. If not, see <http://www.gnu.org/licenses/>.
#include "TensorEvalSlow.h"
#endif

//#include "Tensor.h"
//#include "SrepStatement.h"
//#include "TensorEvalHandle.h"

//namespace Mera {

//template<typename ComplexOrRealType>
//class TensorEvalBase {

//public:

//	typedef TensorEvalHandle HandleType;
//	typedef Tensor<ComplexOrRealType> TensorType;
//	typedef typename PsimagLite::Vector<TensorType*>::Type VectorTensorType;
//	typedef typename PsimagLite::Vector<SizeType>::Type VectorSizeType;
//	typedef SrepStatement<ComplexOrRealType> SrepStatementType;
//	typedef typename SrepStatementType::PairStringSizeType PairStringSizeType;
//	typedef std::map<PairStringSizeType,SizeType> MapPairStringSizeType;
//	typedef typename PsimagLite::Vector<PairStringSizeType>::Type VectorPairStringSizeType;

//	virtual ~TensorEvalBase() {}

//	virtual HandleType operator()() = 0;

//	virtual void printResult(std::ostream& os) const = 0;

//	virtual SizeType nameToIndexLut(PsimagLite::String) = 0;
//};
//} // namespace Mera
#endif // TENSOREVALBASE_H
+26 −25
Original line number Diff line number Diff line
@@ -21,39 +21,40 @@ along with MERA++. If not, see <http://www.gnu.org/licenses/>.
#include "TensorSrep.h"
#include <map>
#include "TensorBreakup.h"
#include "TensorEvalBase.h"
#include "SymmetryLocal.h"
#include "BLAS.h"
#include "PsimagLite.h"
#include "NameToIndexLut.h"
#include "Tensor.h"
#include "SrepStatement.h"
#include "TensorEvalHandle.h"

namespace Mera {

template<typename ComplexOrRealType>
class TensorEvalSlow : public TensorEvalBase<ComplexOrRealType> {
class TensorEval {

	typedef TensorSrep TensorSrepType;

public:

	typedef TensorEvalBase<ComplexOrRealType> TensorEvalBaseType;
	typedef typename TensorEvalBaseType::SrepStatementType SrepStatementType;
	typedef typename TensorEvalBaseType::HandleType HandleType;
	typedef typename TensorEvalBaseType::TensorType TensorType;
	typedef typename TensorEvalBaseType::VectorTensorType VectorTensorType;
	typedef typename TensorEvalBaseType::VectorSizeType VectorSizeType;
	typedef typename TensorEvalBaseType::PairStringSizeType PairStringSizeType;
	typedef typename TensorEvalBaseType::VectorPairStringSizeType VectorPairStringSizeType;
	typedef typename TensorEvalBaseType::MapPairStringSizeType MapPairStringSizeType;
	typedef TensorEvalHandle HandleType;
	typedef Tensor<ComplexOrRealType> TensorType;
	typedef typename PsimagLite::Vector<TensorType*>::Type VectorTensorType;
	typedef typename PsimagLite::Vector<SizeType>::Type VectorSizeType;
	typedef SrepStatement<ComplexOrRealType> SrepStatementType;
	typedef typename SrepStatementType::PairStringSizeType PairStringSizeType;
	typedef std::map<PairStringSizeType,SizeType> MapPairStringSizeType;
	typedef typename PsimagLite::Vector<PairStringSizeType>::Type VectorPairStringSizeType;
	typedef SymmetryLocal SymmetryLocalType;
	typedef typename PsimagLite::Vector<SrepStatementType*>::Type VectorSrepStatementType;
	typedef TensorBreakup::VectorStringType VectorStringType;
	typedef typename TensorType::MatrixType MatrixType;
	typedef SymmetryLocal SymmetryLocalType;
	typedef SymmetryLocalType::VectorVectorSizeType VectorVectorSizeType;

	static const SizeType EVAL_BREAKUP = TensorBreakup::EVAL_BREAKUP;

	TensorEvalSlow(const SrepStatementType& tSrep,
	TensorEval(const SrepStatementType& tSrep,
	           const VectorTensorType& vt,
	           NameToIndexLut<TensorType>& nameToIndexLUT,
	           SymmetryLocalType* symmLocal,
@@ -105,7 +106,7 @@ public:
				veqs[j]->canonicalize();
			veqs[j]->rhs().simplify(empty);

			TensorEvalSlow tEval(*(veqs[j]),
			TensorEval tEval(*(veqs[j]),
			                 data_,
			                 nameToIndexLUT_,
			                 symmLocal_,
@@ -122,7 +123,7 @@ public:
		}
	}

	~TensorEvalSlow()
	~TensorEval()
	{
		for (SizeType i = 0; i < garbage_.size(); ++i) {
			delete garbage_[i];
@@ -652,9 +653,9 @@ private:
		return buffer;
	}

	TensorEvalSlow(const TensorEvalSlow& other);
	TensorEval(const TensorEval& other);

	TensorEvalSlow& operator=(const TensorEvalSlow& other);
	TensorEval& operator=(const TensorEval& other);

	SrepStatementType srepStatement_;
	VectorTensorType data_;
+2 −4
Original line number Diff line number Diff line
@@ -47,10 +47,8 @@ int main(int argc, char **argv)
	vt[1]->setToConstant(1.5);

	Mera::SrepStatement<double> srepEq(str);
	TensorEvalType tensorEval(srepEq, vt);
	        //Mera::NameToIndexLut<TensorType> nameToIndexLut(vt);
	//tensorEval = new TensorEvalSlowType(srepEq, vt, nameToIndexLut, 0,false);

	Mera::NameToIndexLut<TensorType> nameToIndexLut(vt);
	TensorEvalType tensorEval(srepEq, vt, nameToIndexLut, 0, false);
	TensorEvalType::HandleType handle = tensorEval();

	while (!handle.done());