Commit a3dd4e77 authored by Ahmed S. Taei's avatar Ahmed S. Taei
Browse files

Drop transfer_read inner most unit dimensions

Add a pattern to take a rank-reducing subview and drop inner most
contiguous unit dim.
This is useful when lowering vector to backends with 1d vector types.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D111561
parent b046eb19
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -97,6 +97,13 @@ struct UnrollVectorOptions {
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollVectorOptions &options);

/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops to operate on the largest contigious vector.
/// These patterns are useful when lowering to dialects with 1d vector type
/// such as llvm and it will result fewer memory reads.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
    RewritePatternSet &patterns);

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
+80 −0
Original line number Diff line number Diff line
@@ -3529,6 +3529,80 @@ private:
  const bool enableIndexOptimizations;
};

// Drop inner most contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto srcType = readOp.source().getType().cast<MemRefType>();
    if (!srcType || !srcType.hasStaticShape())
      return failure();

    if (!readOp.permutation_map().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    SmallVector<int64_t> srcStrides;
    int64_t srcOffset;
    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
      return failure();

    size_t dimsToDrop = 0;
    for (size_t i = 1; i < srcStrides.size(); ++i) {
      int dim = srcType.getRank() - i - 1;
      if (srcStrides[dim] == 1) {
        dimsToDrop++;
      } else {
        break;
      }
    }
    if (dimsToDrop == 0)
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType());

    MemRefType resultMemrefType;
    if (srcType.getLayout().getAffineMap().isIdentity()) {
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          {}, srcType.getMemorySpaceAsInt());
    } else {
      AffineMap map = srcType.getLayout().getAffineMap();
      int numResultDims = map.getNumDims() - dimsToDrop;
      int numSymbols = map.getNumSymbols();
      for (size_t i = 0; i < dimsToDrop; ++i) {
        int dim = srcType.getRank() - i - 1;
        map = map.replace(rewriter.getAffineDimExpr(dim),
                          rewriter.getAffineConstantExpr(0), numResultDims,
                          numSymbols);
      }
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          map, srcType.getMemorySpaceAsInt());
    }

    auto loc = readOp.getLoc();
    SmallVector<int64_t> offsets(srcType.getRank(), 0);
    SmallVector<int64_t> strides(srcType.getRank(), 1);
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
        strides);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.indices().drop_back(dimsToDrop));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);

    return success();
  }
};

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool enableIndexOptimizations) {
  patterns.add<VectorCreateMaskOpConversion,
@@ -3617,3 +3691,9 @@ void mlir::vector::populateVectorUnrollPatterns(
               UnrollContractionPattern, UnrollElementwisePattern>(
      patterns.getContext(), options);
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns) {
  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
}
+33 −0
Original line number Diff line number Diff line
// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s

#map1 = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2 + d3)>
func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, #map1>) -> vector<1x8x1xf32>{
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, #map1>, vector<1x8x1xf32>
  return %0 : vector<1x8x1xf32>
}
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2 + d3)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 3072 + s0 + d1 * 8 + d2)>
//      CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, #[[MAP0]]>
//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME:    memref<1x1x8x1xf32, #[[MAP0]]> to memref<1x1x8xf32, #[[MAP1]]>
//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
// CHECK-SAME:    memref<1x1x8xf32, #[[MAP1]]>, vector<1x8xf32>
//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
//      CHECK:   return %[[RESULT]]

// -----

func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32
  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<8x1xf32>
  return %1 : vector<8x1xf32>
}
//      CHECK: func @contiguous_inner_most_dim(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index, %[[J:.+]]: index) -> vector<8x1xf32>
//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME:     memref<16x1xf32> to memref<16xf32>
//      CHECK:   %[[V:.+]] = vector.transfer_read %[[SRC_0]]
//      CHECK:   %[[RESULT]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32>
//      CHECK:   return %[[RESULT]]
+29 −0
Original line number Diff line number Diff line
@@ -466,6 +466,33 @@ struct TestVectorMultiReductionLoweringPatterns
  }
};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         FunctionPass> {
  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, AffineDialect>();
  }

  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }

  StringRef getDescription() const final {
    return "Test conversion patterns that reducedes the rank of the vector "
           "transfer memory and vector operands.";
  }

  void runOnFunction() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};

} // end anonymous namespace

namespace mlir {
@@ -490,6 +517,8 @@ void registerTestVectorConversions() {
  PassRegistration<TestVectorTransferLoweringPatterns>();

  PassRegistration<TestVectorMultiReductionLoweringPatterns>();

  PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
}
} // namespace test
} // namespace mlir