Commit 58c18850 authored by Alexander Belyaev's avatar Alexander Belyaev
Browse files

[mlir][linalg] Fix `FoldInitTensorWithDimOp` if dim(init_tensor) is static.

It looks like it was a typo. Instead of `*maybeConstantIndex`,
`initTensorOp.getStaticSize(*maybeConstantIndex)` should be used to access the
dim size of the tensor. There is a test for that in `canonicalize.mlir`, but it
was working correctly because `ReplaceStaticShapeDims` was canonicalizing DimOp
before `FoldInitTensorWithDimOp`. So, to make the patterns more "orthogonal",
this case is disabled.

Differential Revision: https://reviews.llvm.org/D109247
parent 915a8bb5
Loading
Loading
Loading
Loading
+3 −6
Original line number Diff line number Diff line
@@ -977,12 +977,9 @@ struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
    auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
    if (!initTensorOp || !maybeConstantIndex)
      return failure();
    if (initTensorOp.isDynamicSize(*maybeConstantIndex)) {
      rewriter.replaceOp(dimOp,
                         initTensorOp.getDynamicSize(*maybeConstantIndex));
      return success();
    }
    rewriter.replaceOpWithNewOp<ConstantIndexOp>(dimOp, *maybeConstantIndex);
    if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};