Commit 1bc4a281 authored by Alvarez, Gonzalo's avatar Alvarez, Gonzalo
Browse files

MPS Simplify

parent e47e0feb
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -194,11 +194,12 @@ private:
	{
		assert(sites_ > 2);
		const SizeType nMinusOne = sites_ - 1;
		const PsimagLite::String name = "a";

		srep_ = "u0(s" + ttos(nMinusOne) + ",f0|s0)";
		srep_ = name + "0(s" + ttos(nMinusOne) + ",f0|s0)";

		for (SizeType i = 0; i < nMinusOne; ++i) {
			srep_ +=  "u" + ttos(i+1) + "(s" + ttos(i) + ",f" + ttos(i+1) + "|s" + ttos(i+1) + ")";
			srep_ +=  name + ttos(i+1) + "(s" + ttos(i) + ",f" + ttos(i+1) + "|s" + ttos(i+1) + ")";
		}
	}

+50 −6
Original line number Diff line number Diff line
@@ -458,12 +458,13 @@ private:
	{
		bool simplificationHappended = false;
		SizeType ntensors = data_.size();
		const bool mpsMode = (data_[0]->fullName()[0] == 'a');
		for (SizeType i = 0; i < ntensors; ++i) {
			if (data_[i]->isConjugate()) continue;
			SizeType j = findConjugate(i);
			if (j >= data_.size()) continue; // no conjugate
			if (!inputsMatch(i,j)) continue;
			if (!simplify(i,j)) continue;
			if (!inputsMatchShim(mpsMode, i, j)) continue;
			if (!simplify(mpsMode, i, j)) continue;
			simplificationHappended = true;
			break; // only one simplification
		}
@@ -471,12 +472,12 @@ private:
		return simplificationHappended;
	}

	bool simplify(SizeType ind, SizeType jnd)
	bool simplify(bool mpsMode, SizeType ind, SizeType jnd)
	{
		bool simplificationHappended = false;
		SizeType outs = data_[ind]->outs();
		VectorPairSizeType replacements(outs);
		if (!computeReplacements(replacements,ind,jnd))
		VectorPairSizeType replacements((mpsMode) ? 2 : outs);
		if (!computeReplacementsShim(mpsMode, replacements, ind, jnd))
			return simplificationHappended;

		data_[ind]->setAsErased();
@@ -517,6 +518,15 @@ private:
		}
	}

	bool computeReplacementsShim(bool mpsMode,
	                             VectorPairSizeType& replacements,
	                             SizeType ind,
	                             SizeType jnd) const
	{
		return (mpsMode) ? computeReplacementsMPS(replacements, ind, jnd )
		                 : computeReplacements(replacements, ind, jnd);
	}

	bool computeReplacements(VectorPairSizeType& replacements,
	                         SizeType ind,
	                         SizeType jnd) const
@@ -538,6 +548,30 @@ private:
		return true;
	}

	bool computeReplacementsMPS(VectorPairSizeType& replacements,
	                            SizeType ind,
	                            SizeType jnd) const
	{
		for (SizeType leg = 0; leg < 3; ++leg) {
			if (leg == 1) continue;
			if (data_[ind]->legType(leg) != TensorStanzaType::INDEX_TYPE_SUMMED)
				return false;
			if (data_[jnd]->legType(leg) != TensorStanzaType::INDEX_TYPE_SUMMED)
				return false;
			const SizeType s1 = data_[ind]->legTag(leg);
			const SizeType s2 = data_[jnd]->legTag(leg);
			SizeType x = (leg == 0) ? 0 : 1;
			replacements[x] = PairSizeType((s1 < s2) ? s2 : s1,(s1 < s2) ? s1 : s2);
		}

		return true;
	}

	bool inputsMatchShim(bool mpsMode, SizeType ind, SizeType jnd) const
	{
		return (mpsMode) ? inputsMatchMPS(ind, jnd) : inputsMatch(ind, jnd);
	}

	bool inputsMatch(SizeType ind, SizeType jnd) const
	{
		SizeType ins = data_[ind]->ins();
@@ -554,6 +588,16 @@ private:
		return true;
	}

	bool inputsMatchMPS(SizeType ind, SizeType jnd) const
	{
		SizeType ins = data_[ind]->ins();
		if (ins != data_[jnd]->ins()) return false;
		if (ins < 2) return false;
		return (data_[ind]->legTag(1) == data_[jnd]->legTag(1) &&
		        data_[ind]->legType(1) == TensorStanzaType::INDEX_TYPE_SUMMED &&
		        data_[jnd]->legType(1) == TensorStanzaType::INDEX_TYPE_SUMMED);
	}

	void append(const TensorSrep& other)
	{
		SizeType ntensors = data_.size();