Loading mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +6 −0 Original line number Diff line number Diff line Loading @@ -361,6 +361,12 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> { LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Skip the pattern if the op has any tensor with special encoding. if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) { auto tensorType = type.dyn_cast<RankedTensorType>(); return tensorType && tensorType.getEncoding() != nullptr; })) return failure(); MLIRContext *context = rewriter.getContext(); Location loc = genericOp.getLoc(); Loading mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +31 −0 Original line number Diff line number Diff line Loading @@ -796,3 +796,34 @@ func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: me // CHECK: linalg.yield %[[ARG]] : f32 // CHECK: } // CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32> // ----- // Negative test for case with tensor encoding. #matvec = { indexing_maps = [ affine_map<(i,j) -> (i,j)>, // A affine_map<(i,j) -> (j)>, // b affine_map<(i,j) -> (i)> // x (out) ], iterator_types = ["parallel", "reduction"] } #CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }> func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> { %0 = linalg.init_tensor [8] : tensor<8xf32> %1 = linalg.generic #matvec ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>) outs(%0: tensor<8xf32>) { ^bb(%a: f32, %b: f32, %x: f32): %m = mulf %a, %b : f32 %add = addf %x, %m : f32 linalg.yield %add : f32 } -> tensor<8xf32> return %1: tensor<8xf32> } // CHECK-LABEL: func @sparse_case // CHECK-NEXT: linalg.init_tensor // CHECK-NEXT: linalg.generic Loading
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +6 −0 Original line number Diff line number Diff line Loading @@ -361,6 +361,12 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> { LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Skip the pattern if the op has any tensor with special encoding. if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) { auto tensorType = type.dyn_cast<RankedTensorType>(); return tensorType && tensorType.getEncoding() != nullptr; })) return failure(); MLIRContext *context = rewriter.getContext(); Location loc = genericOp.getLoc(); Loading
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +31 −0 Original line number Diff line number Diff line Loading @@ -796,3 +796,34 @@ func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: me // CHECK: linalg.yield %[[ARG]] : f32 // CHECK: } // CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32> // ----- // Negative test for case with tensor encoding. #matvec = { indexing_maps = [ affine_map<(i,j) -> (i,j)>, // A affine_map<(i,j) -> (j)>, // b affine_map<(i,j) -> (i)> // x (out) ], iterator_types = ["parallel", "reduction"] } #CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }> func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> { %0 = linalg.init_tensor [8] : tensor<8xf32> %1 = linalg.generic #matvec ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>) outs(%0: tensor<8xf32>) { ^bb(%a: f32, %b: f32, %x: f32): %m = mulf %a, %b : f32 %add = addf %x, %m : f32 linalg.yield %add : f32 } -> tensor<8xf32> return %1: tensor<8xf32> } // CHECK-LABEL: func @sparse_case // CHECK-NEXT: linalg.init_tensor // CHECK-NEXT: linalg.generic