Commit b3d74616 authored by Alvarez, Gonzalo's avatar Alvarez, Gonzalo
Browse files

TensorExatn::TensorBlobType is now safer

parent 90dbdf64
Loading
Loading
Loading
Loading
+28 −7
Original line number Diff line number Diff line
@@ -40,7 +40,29 @@ public:
	typedef PsimagLite::Vector<SizeType>::Type VectorSizeType;
	typedef typename PsimagLite::Vector<ComplexOrRealType>::Type VectorComplexOrRealType;
	typedef std::pair<PsimagLite::String, SizeType> PairStringSizeType;
	typedef ComplexOrRealType* TensorBlobType;

	class BlobAndSizeConst {

	public:

		BlobAndSizeConst() : size_(0), ptr_(0) {}

		BlobAndSizeConst(SizeType size, const ComplexOrRealType* ptr)
		    : size_(size), ptr_(ptr)
		{}

		const ComplexOrRealType& operator[](SizeType ind) const
		{
			assert(ptr_);
			assert(ind < size_);
			return ptr_[ind];
		}

		SizeType size_;
		const ComplexOrRealType* ptr_;
	};

	typedef BlobAndSizeConst TensorBlobType;

	// Tensor with only one dimension
	Tensor(PsimagLite::String name, SizeType dim0, SizeType ins)
@@ -225,8 +247,8 @@ public:
		// Return tensor data_ at args, const version

		SizeType index = pack(args);
		const ComplexOrRealType* ptr = this->data();
		return ptr[index];
		TensorBlobType tbt = this->data();
		return tbt[index];
	}

	void setValue(const VectorSizeType& args, const ComplexOrRealType& val)
@@ -241,17 +263,16 @@ public:
		ptr[index] = val;
	}

	const ComplexOrRealType* data() const
	TensorBlobType data() const
	{

		std::shared_ptr<talsh::Tensor> ptr = exatn::getLocalTensor(name_);
		const ComplexOrRealType** ptr2 = 0;
		bool ret = ptr->getDataAccessHostConst(ptr2);
		checkTalshErrorCode(ret, "getLocalTensor");
		return *ptr2;
		return TensorBlobType(ptr->getVolume(), *ptr2);
	}

	void setData(const ComplexOrRealType* data)
	void setData(const TensorBlobType& data)
	{
		std::shared_ptr<talsh::Tensor> ptr = exatn::getLocalTensor(name_);
		ComplexOrRealType** ptr2 = 0;
+1 −6
Original line number Diff line number Diff line
@@ -67,7 +67,6 @@ public:
	LanczosSolverType;
	typedef typename TensorType::TensorBlobType TensorBlobType;
	typedef typename TensorEvalSlowType::SymmetryLocalType SymmetryLocalType;
	typedef typename PsimagLite::Stack<TensorBlobType>::Type StackVectorType;

	TensorOptimizer(IoInType& io,
	                PsimagLite::String nameToOptimize,
@@ -455,11 +454,7 @@ private:
	SymmetryLocalType* symmLocal_;
	bool verbose_;
	PsimagLite::Random48<double> rng_;
#ifndef NO_EXATN
	const ComplexOrRealType* savedTensor_;
#else
	VectorType savedTensor_;
#endif
	TensorBlobType savedTensor_;
}; // class TensorOptimizer
} // namespace Mera
#endif // TENSOROPTIMIZER_H