Unverified Commit 5aa2c65a authored by tyb0807's avatar tyb0807 Committed by GitHub
Browse files

[mlir][MemRef] Add subview folding pattern for vector.maskedload (#71380)

This is required for fixing https://github.com/openxla/iree/issues/15031
parent 2302e4c3
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {

/// Memref-operand accessor overload for `vector::LoadOp`; forwards to the op's
/// `getBase()` accessor.
static Value getMemRefOperand(vector::LoadOp op) {
  Value base = op.getBase();
  return base;
}

/// Memref-operand accessor overload for `vector::MaskedLoadOp`; forwards to the
/// op's `getBase()` accessor.
static Value getMemRefOperand(vector::MaskedLoadOp op) {
  Value base = op.getBase();
  return base;
}

/// Memref-operand accessor overload for `vector::TransferWriteOp`; forwards to
/// the op's `getSource()` accessor.
static Value getMemRefOperand(vector::TransferWriteOp op) {
  Value source = op.getSource();
  return source;
}
@@ -415,6 +417,11 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
@@ -687,6 +694,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
+18 −1
Original line number Diff line number Diff line
@@ -665,3 +665,20 @@ func.func @fold_vector_load(
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
//      CHECK:   vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] :  memref<12x32xf32>, vector<12x32xf32>

// -----

// Test input: a vector.maskedload that reads through a rank-reducing
// memref.subview of %arg0 taken at dynamic offsets (%arg1, %arg2).
// The expected output below has the masked load index %arg0 directly with
// (%arg1, %arg2), keeping the same mask (%arg3) and pass-through (%arg4).
func.func @fold_vector_maskedload(
  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
  return %1 : vector<32xf32>
}

//      CHECK: func @fold_vector_maskedload
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>