Commit 1e25109f authored by Ahmed Taei's avatar Ahmed Taei
Browse files

Canonicalize static alloc followed by memref_cast and std.view

Summary: Rewrite alloc, memref_cast, std.view into allo, std.view by droping memref_cast.

Reviewers: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72379
parent 31992a69
Loading
Loading
Loading
Loading
+21 −1
Original line number Diff line number Diff line
@@ -2527,11 +2527,31 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(ViewOp viewOp,
                                     PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    MemRefCastOp memrefCastOp =
        dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
    if (!memrefCastOp)
      return matchFailure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
    if (!allocOp)
      return matchFailure();
    rewriter.replaceOpWithNewOp<ViewOp>(memrefOperand, viewOp, viewOp.getType(),
                                        allocOperand, viewOp.operands());
    return matchSuccess();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ViewOpShapeFolder>(context);
  results.insert<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
+8 −3
Original line number Diff line number Diff line
@@ -695,6 +695,7 @@ func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>,

// CHECK-LABEL: func @view
func @view(%arg0 : index) {
  // CHECK: %[[ALLOC_MEM:.*]] = alloc() : memref<2048xi8>
  %0 = alloc() : memref<2048xi8>
  %c0 = constant 0 : index
  %c7 = constant 7 : index
@@ -730,11 +731,15 @@ func @view(%arg0 : index) {

  // Test: preserve an existing static dim size while folding a dynamic
  // dimension and offset.
  // CHECK: std.view %0[][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
  %5 = view %0[%c15][%c7]
    : memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>
  // CHECK: std.view %[[ALLOC_MEM]][][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
  %5 = view %0[%c15][%c7] : memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>
  load %5[%c0, %c0] : memref<?x4xf32, #TEST_VIEW_MAP2>

  // Test: folding static alloc and memref_cast into a view.
  // CHECK: std.view %0[][%c15, %c7] : memref<2048xi8> to memref<?x?xf32>
  %6 = memref_cast %0 : memref<2048xi8> to memref<?xi8>
  %7 = view %6[%c15][%c7] : memref<?xi8> to memref<?x?xf32>
  load %7[%c0, %c0] : memref<?x?xf32>
  return
}