Commit ce5af02c authored by Alvarez, Gonzalo's avatar Alvarez, Gonzalo
Browse files

TensorSlow: fix for setValue

parent 98b23738
Loading
Loading
Loading
Loading
+16 −72
Original line number Diff line number Diff line
@@ -43,10 +43,6 @@ public:
	Tensor(PsimagLite::String name, SizeType dim0, SizeType ins)
	    : name_(name), dimensions_(1, dim0), ins_(ins)
	{
		if (hasExatnBackend_) {
			return;
		}

		data_.resize(dim0, 0.0);
	}

@@ -56,11 +52,6 @@ public:
		SizeType n = dimensions_.size();
		if (n == 0) return;
		assert(0 < n);

		if (hasExatnBackend_) {
			return;
		}

		data_.resize(volume(), 0.0);
	}

@@ -82,10 +73,6 @@ public:
		for (SizeType i = ins; i < dimensions_.size(); ++i)
			douts *= dimensions_[i];

		if (hasExatnBackend_) {
			return;
		}

		for (SizeType x = 0; x < dins; ++x)
			if (x < douts)
				data_[x + x*dins] = value;
@@ -93,10 +80,6 @@ public:

	void setToRandom()
	{
		if (hasExatnBackend_) {
			return;
		}

		SizeType n = data_.size();
		ComplexOrRealType sum = 0.0;
		for (SizeType i = 0; i < n; ++i) {
@@ -114,10 +97,6 @@ public:

	void setToConstant(ComplexOrRealType value)
	{
		if (hasExatnBackend_) {
			return;
		}

		// use std:: function here instead of loop, FIXME
		SizeType n = data_.size();
		for (SizeType i = 0; i < n; ++i)
@@ -129,10 +108,6 @@ public:
		if (ins_ == 0) return;
		if (dimensions_.size() < ins_) return;

		if (hasExatnBackend_) {
			return;
		}

		if (dimensions_.size() == ins_) {
			SizeType rows = m.n_row();
			SizeType cols = m.n_col();
@@ -165,10 +140,6 @@ public:
		if (v == 0)
			throw PsimagLite::RuntimeError("Tensor::setSizes(...): dimensions == 0\n");

		if (hasExatnBackend_) {
			return;
		}

		data_.resize(v,0.0);
	}

@@ -197,34 +168,34 @@ public:
		return dimensions_[ind];
	}

	SizeType index(const VectorSizeType& args) const
	{
		return pack(args);
	}

	const ComplexOrRealType& operator()(const VectorSizeType& args) const
	{
		if (hasExatnBackend_) {
			return data_[0];
		SizeType index = pack(args);
		assert(index < data_.size());
		return data_[index];
	}

	void setValue(const VectorSizeType& args, const ComplexOrRealType& val)
	{
		SizeType index = pack(args);
		assert(index < data_.size());
		return data_[index];
		data_[index] = val;
	}

	// FIXME: GIVES AWAY INTERNALS!!
	ComplexOrRealType& operator()(const VectorSizeType& args)
	const VectorComplexOrRealType& data() const
	{
		if (hasExatnBackend_) {
			return data_[0];
		return data_;
	}

		SizeType index = pack(args);
		assert(index < data_.size());
		return data_[index];
	void setData(const VectorComplexOrRealType& data)
	{
		data_ = data;
	}

	PsimagLite::String name() const { return name_; }

private:

	SizeType pack(const VectorSizeType& args) const
	{
		assert(args.size() > 0);
@@ -242,27 +213,6 @@ public:
		return index;
	}

	const VectorComplexOrRealType& data() const
	{
		return data_;
	}

	void setData(const VectorComplexOrRealType& data)
	{
		data_ = data;
	}

	PsimagLite::String name() const { return name_; }

private:


#ifdef USE_EXATN
	static const bool hasExatnBackend_ = true;
	static exatn::numerics::TensorOpFactory* opFactory_;
#else
	static const bool hasExatnBackend_ = false;
#endif
	static PsimagLite::RandomForTests<ComplexOrRealType> rng_;
	PsimagLite::String name_;
	VectorSizeType dimensions_;
@@ -272,11 +222,5 @@ private:

template<typename ComplexOrRealType>
PsimagLite::RandomForTests<ComplexOrRealType> Tensor<ComplexOrRealType>::rng_(1234);

#ifdef USE_EXATN
template<typename ComplexOrRealType>
exatn::numerics::TensorOpFactory* Tensor<ComplexOrRealType>::opFactory_ = exatn::numerics::TensorOpFactory::get();
#endif

}
#endif // TENSOR_SLOW_H