Commit 31995031 authored by Shawn Yang's avatar Shawn Yang
Browse files

changing row-col major function signature(reversed parameter)

parent 07153cf1
Loading
Loading
Loading
Loading
+75 −22
Original line number Diff line number Diff line
@@ -837,8 +837,12 @@ int NdCopy(const char *in, const Dims &inStart, const Dims &inCount,
    // padding
    else
    {
        Dims revInCount(inCount);
        Dims revOutCount(outCount);
        //        Dims revInCount(inCount);
        //        Dims revOutCount(outCount);
        //
        // col-major ==> col-major mode
        if (!inIsRowMajor && !outIsRowMajor)
        {
            GetInEnd(inEnd, inStart, inCount);
            GetOutEnd(outEnd, outStart, outCount);
            GetOvlpStart(ovlpStart, inStart, outStart);
@@ -846,38 +850,87 @@ int NdCopy(const char *in, const Dims &inStart, const Dims &inCount,
            GetOvlpCount(ovlpCount, ovlpStart, ovlpEnd);
            if (!HasOvlp(ovlpStart, ovlpEnd))
                return 1; // no overlap found
        // col-major ==> col-major mode
        if (!inIsRowMajor && !outIsRowMajor)
        {
            // reverse the inCount, calculate inStride with it and reverse the
            // inStride
            std::reverse(revInCount.begin(), revInCount.end());
            GetIoStrides(inStride, revInCount, sizeof(T));
            std::reverse(inStride.begin(), inStride.end());
            // reverse the outCount, calculate outStride with it and reverse the
            // outStride
            std::reverse(revOutCount.begin(), revOutCount.end());
            GetIoStrides(outStride, revOutCount, sizeof(T));
            std::reverse(outStride.begin(), outStride.end());

            //            std::reverse(revInCount.begin(), revInCount.end());
            GetIoStrides(inStride, inCount, sizeof(T));
            //            std::reverse(inStride.begin(), inStride.end());

            //            std::reverse(revOutCount.begin(), revOutCount.end());
            GetIoStrides(outStride, outCount, sizeof(T));
            //            std::reverse(outStride.begin(), outStride.end());

            //            Dims revOvlpStart(ovlpStart);
            //            std::reverse(revOvlpStart.begin(),
            //            revOvlpStart.end());
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inStart, ovlpStart);
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outStart, ovlpStart);
        }
        // row-major ==> col-major mode
        else if (inIsRowMajor && !outIsRowMajor)
        {
            Dims revOutStart(outStart);
            Dims revOutCount(outCount);
            std::reverse(revOutStart.begin(), revOutStart.end());
            std::reverse(revOutCount.begin(), revOutCount.end());

            GetInEnd(inEnd, inStart, inCount);
            GetOutEnd(outEnd, revOutStart, revOutCount);
            GetOvlpStart(ovlpStart, inStart, revOutStart);
            GetOvlpEnd(ovlpEnd, inEnd, outEnd);
            GetOvlpCount(ovlpCount, ovlpStart, ovlpEnd);
            if (!HasOvlp(ovlpStart, ovlpEnd))
                return 1; // no overlap found

            // get normal order inStride
            GetIoStrides(inStride, inCount, sizeof(T));

            // calulate reversed order outStride
            std::reverse(revOutCount.begin(), revOutCount.end());
            GetIoStrides(outStride, revOutCount, sizeof(T));
            // reverse outStride so that outStride aligns to inStride
            std::reverse(outStride.begin(), outStride.end());

            // get normal order inOvlpStart
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inStart, ovlpStart);

            // get reversed order outOvlpStart
            Dims revOvlpStart(ovlpStart);
            std::reverse(revOvlpStart.begin(), revOvlpStart.end());
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outStart, revOvlpStart);
        }
        // col-major ==> row-major mode
        else if (!inIsRowMajor && outIsRowMajor)
        {
            Dims revInStart(inStart);
            Dims revInCount(inCount);
            std::reverse(revInStart.begin(), revInStart.end());
            std::reverse(revInCount.begin(), revInCount.end());

            GetInEnd(inEnd, revInStart, revInCount);
            GetOutEnd(outEnd, outStart, outCount);
            GetOvlpStart(ovlpStart, revInStart, outStart);
            GetOvlpEnd(ovlpEnd, inEnd, outEnd);
            GetOvlpCount(ovlpCount, ovlpStart, ovlpEnd);
            if (!HasOvlp(ovlpStart, ovlpEnd))
                return 1; // no overlap found

            // get normal order outStride
            GetIoStrides(outStride, outCount, sizeof(T));

            // calculate reversed inStride
            std::reverse(revInCount.begin(), revInCount.end());
            GetIoStrides(inStride, revInCount, sizeof(T));
            // reverse inStride so that inStride aligns to outStride
            std::reverse(inStride.begin(), inStride.end());
            GetIoStrides(outStride, outCount, sizeof(T));
        }
        GetRltvOvlpStartPos(inRltvOvlpStartPos, inStart, ovlpStart);

            // get reversed order inOvlpStart
            Dims revOvlpStart(ovlpStart);
            std::reverse(revOvlpStart.begin(), revOvlpStart.end());
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inStart, revOvlpStart);
            // get normal order outOvlpStart
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outStart, ovlpStart);
        }

        inOvlpBase = in;
        outOvlpBase = out;
        // Same Endian"