Commit 08f0cb77 authored by thomasraoux's avatar thomasraoux
Browse files

[mlir] Prevent crash in DropUnitDim pattern due to tensor with encoding

Differential Revision: https://reviews.llvm.org/D109984
parent d13d9da1
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -361,6 +361,12 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

+31 −0
Original line number Diff line number Diff line
@@ -796,3 +796,34 @@ func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: me
// CHECK:       linalg.yield %[[ARG]] : f32
// CHECK:      }
// CHECK:      return %[[ARG2]] : memref<?x1x?x1x?xf32>

// -----

// Negative test for case with tensor encoding.
#matvec = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>, // A
    affine_map<(i,j) -> (j)>,   // b
    affine_map<(i,j) -> (i)>    // x (out)
  ],
  iterator_types = ["parallel", "reduction"]
}

#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }>

func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
    %0 = linalg.init_tensor [8] : tensor<8xf32>
    %1 = linalg.generic #matvec
      ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>)
      outs(%0: tensor<8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %m = mulf %a, %b : f32
        %add = addf %x, %m : f32
        linalg.yield %add : f32
    } -> tensor<8xf32>
    return %1: tensor<8xf32>
}

// CHECK-LABEL: func @sparse_case
//  CHECK-NEXT:   linalg.init_tensor
//  CHECK-NEXT:   linalg.generic