Commit fc4db145 authored by Shawn Yang's avatar Shawn Yang
Browse files

adding support for IO memStart, memCount params for reversed major

parent 31995031
Loading
Loading
Loading
Loading
+27 −22
Original line number Diff line number Diff line
@@ -843,6 +843,7 @@ int NdCopy(const char *in, const Dims &inStart, const Dims &inCount,
        // col-major ==> col-major mode
        if (!inIsRowMajor && !outIsRowMajor)
        {

            GetInEnd(inEnd, inStart, inCount);
            GetOutEnd(outEnd, outStart, outCount);
            GetOvlpStart(ovlpStart, inStart, outStart);
@@ -862,16 +863,18 @@ int NdCopy(const char *in, const Dims &inStart, const Dims &inCount,
            //            Dims revOvlpStart(ovlpStart);
            //            std::reverse(revOvlpStart.begin(),
            //            revOvlpStart.end());
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inStart, ovlpStart);
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outStart, ovlpStart);
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inMemStartNC, ovlpStart);
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outMemStartNC, ovlpStart);
        }
            // row-major ==> col-major mode
        else if (inIsRowMajor && !outIsRowMajor)
        {
            Dims revOutStart(outStart);
            Dims revOutCount(outCount);
            std::reverse(revOutStart.begin(), revOutStart.end());
            std::reverse(revOutCount.begin(), revOutCount.end());
//            std::reverse(revOutStart.begin(), revOutStart.end());
//            std::reverse(revOutCount.begin(), revOutCount.end());
            std::reverse(outMemStartNC.begin(), outMemStartNC.end());
            std::reverse(outMemCountNC.begin(), outMemCountNC.end());

            GetInEnd(inEnd, inStart, inCount);
            GetOutEnd(outEnd, revOutStart, revOutCount);
@@ -882,29 +885,31 @@ int NdCopy(const char *in, const Dims &inStart, const Dims &inCount,
                return 1; // no overlap found

            // get normal order inStride
            GetIoStrides(inStride, inCount, sizeof(T));
            GetIoStrides(inStride, inMemCountNC, sizeof(T));

            // calulate reversed order outStride
            std::reverse(revOutCount.begin(), revOutCount.end());
            GetIoStrides(outStride, revOutCount, sizeof(T));
//            std::reverse(revOutCount.begin(), revOutCount.end());
            GetIoStrides(outStride, outMemCountNC, sizeof(T));
            // reverse outStride so that outStride aligns to inStride
            std::reverse(outStride.begin(), outStride.end());

            // get normal order inOvlpStart
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inStart, ovlpStart);
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inMemStartNC, ovlpStart);

            // get reversed order outOvlpStart
            Dims revOvlpStart(ovlpStart);
            std::reverse(revOvlpStart.begin(), revOvlpStart.end());
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outStart, revOvlpStart);
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outMemStartNC, revOvlpStart);
        }
            // col-major ==> row-major mode
        else if (!inIsRowMajor && outIsRowMajor)
        {
            Dims revInStart(inStart);
            Dims revInCount(inCount);
            std::reverse(revInStart.begin(), revInStart.end());
            std::reverse(revInCount.begin(), revInCount.end());
//            std::reverse(revInStart.begin(), revInStart.end());
//            std::reverse(revInCount.begin(), revInCount.end());
            std::reverse(inMemStartNC.begin(), inMemStartNC.end());
            std::reverse(inMemCountNC.begin(), inMemCountNC.end());

            GetInEnd(inEnd, revInStart, revInCount);
            GetOutEnd(outEnd, outStart, outCount);
@@ -915,20 +920,20 @@ int NdCopy(const char *in, const Dims &inStart, const Dims &inCount,
                return 1; // no overlap found

            // get normal order outStride
            GetIoStrides(outStride, outCount, sizeof(T));
            GetIoStrides(outStride, outMemCountNC, sizeof(T));

            // calculate reversed inStride
            std::reverse(revInCount.begin(), revInCount.end());
            GetIoStrides(inStride, revInCount, sizeof(T));
//            std::reverse(revInCount.begin(), revInCount.end());
            GetIoStrides(inStride, inMemCountNC, sizeof(T));
            // reverse inStride so that inStride aligns to outStride
            std::reverse(inStride.begin(), inStride.end());

            // get reversed order inOvlpStart
            Dims revOvlpStart(ovlpStart);
            std::reverse(revOvlpStart.begin(), revOvlpStart.end());
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inStart, revOvlpStart);
            GetRltvOvlpStartPos(inRltvOvlpStartPos, inMemStartNC, revOvlpStart);
            // get normal order outOvlpStart
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outStart, ovlpStart);
            GetRltvOvlpStartPos(outRltvOvlpStartPos, outMemStartNC, ovlpStart);
        }

        inOvlpBase = in;