Commit 74781f6a authored by Alvarez, Gonzalo's avatar Alvarez, Gonzalo
Browse files

CrsMatrix: A cross B added

parent 50c66a8f
Loading
Loading
Loading
Loading
+55 −0
Original line number Diff line number Diff line
@@ -901,6 +901,61 @@ externalProduct(CrsMatrix<T>& B,
	if (nrow_B != 0) B.checkValidity();
}

//-------

/** If order==false then
		creates C such that C_{i1+j1*nout,i2+j2*nout)=A(j1,j2)B_{i1,i2}
		if order==true then
		creates C such that C_{i1+j1*na,i2+j2*na)=A(i1,i2)B_{j1,j2}
		where na=rank(A) and nout = rank(B)
	  */

template<typename T,typename VectorLikeType>
typename EnableIf<IsVectorLike<VectorLikeType>::True &&
Loki::TypeTraits<typename VectorLikeType::value_type>::isFloat,
void>::Type
externalProduct(CrsMatrix<T>& C,
                const CrsMatrix<T>& A,
                const CrsMatrix<T>& B,
                const VectorLikeType& signs,
                bool order,
                const PsimagLite::Vector<SizeType>::Type& permutationFull)
{
    const SizeType nfull = permutationFull.size();

    Vector<SizeType>::Type perm(nfull);
    for (SizeType i = 0; i < nfull; ++i) perm[permutationFull[i]] = i;

    const SizeType nout = B.rows();
    const SizeType na = A.rows();
    const SizeType noutOrNa = (!order) ? nout : na;
    const CrsMatrix<T>& AorB = (!order) ? A : B;
    const CrsMatrix<T>& BorA = (!order) ? B : A;
    assert(A.rows() == A.cols());
    assert(B.rows() == B.cols());
    assert(nout*na == nfull);
    assert(signs.size() == noutOrNa);
    C.resize(nfull, nfull);
    SizeType counter = 0;
    for (SizeType i = 0; i < nfull; ++i) {
        C.setRow(i, counter);
        const SizeType ind = perm[i];
        div_t q = div(ind, noutOrNa);
        for (int k1 = BorA.getRowPtr(q.rem); k1 < BorA.getRowPtr(q.rem + 1); ++k1) {
            const SizeType col1 = BorA.getCol(k1);
            for (int k2 = AorB.getRowPtr(q.quot); k2 < AorB.getRowPtr(q.quot + 1); ++k2) {
                const SizeType col2 = AorB.getCol(k2);
                SizeType j = permutationFull[col1 + col2*noutOrNa];
                C.pushCol(j);
                C.pushValue(BorA.getValue(k1) * AorB.getValue(k2) * signs[q.rem]);
                ++counter;
            }
        }
    }

    C.setRow(nfull, counter);
    C.checkValidity();
}

template<typename T>
void printFullMatrix(const CrsMatrix<T>& s,