Commit e3373c6c authored by Matthias Springer's avatar Matthias Springer
Browse files

[mlir][memref] Fix crash in SubViewReturnTypeCanonicalizer

`SubViewReturnTypeCanonicalizer` is used by `OpWithOffsetSizesAndStridesConstantArgumentFolder`, which folds constant SSA value (dynamic) sizes into static sizes. The previous implementation crashed when a dynamic size was folded into a static `1` dimension, which was then mistaken as a rank reduction.

Differential Revision: https://reviews.llvm.org/D158721
parent 742fa941
Loading
Loading
Loading
Loading
+33 −31
Original line number Diff line number Diff line
@@ -31,23 +31,17 @@ namespace {
namespace saturated_arith {
struct Wrapper {
  static Wrapper stride(int64_t v) {
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
                                                    : Wrapper{false, v};
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
  }
  static Wrapper offset(int64_t v) {
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
                                                    : Wrapper{false, v};
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
  }
  static Wrapper size(int64_t v) {
    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
  }
  int64_t asOffset() {
    return saturated ? ShapedType::kDynamic : v;
  }
  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
  int64_t asStride() {
    return saturated ? ShapedType::kDynamic : v;
  }
  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
  bool operator==(Wrapper other) {
    return (saturated && other.saturated) ||
           (!saturated && !other.saturated && v == other.v);
@@ -731,8 +725,7 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) &&
          !ShapedType::isDynamic(st))
      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
        return false;
  }

@@ -765,8 +758,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) ||
                ShapedType::isDynamic(b) || a == b);
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
@@ -1889,8 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = getStaticOffsets().front();
  if (!ShapedType::isDynamic(resultOffset) &&
      !ShapedType::isDynamic(expectedOffset) &&
      resultOffset != expectedOffset)
      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << expectedOffset << " instead of " << resultOffset;

@@ -2944,18 +2935,6 @@ static MemRefType getCanonicalSubViewResultType(
                         nonRankReducedType.getMemorySpace());
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type. Additionally, reduce the rank of the inferred
/// result type if `currentResultType` is lower rank than `sourceType`.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType sourceType,
    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
    ArrayRef<OpFoldResult> mixedStrides) {
  return getCanonicalSubViewResultType(currentResultType, sourceType,
                                       sourceType, mixedOffsets, mixedSizes,
                                       mixedStrides);
}

Value mlir::memref::createCanonicalRankReducingSubViewOp(
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
@@ -3108,9 +3087,32 @@ struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
                                         mixedOffsets, mixedSizes,
                                         mixedStrides);
    // Infer a memref type without taking into account any rank reductions.
    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));

    // Directly return the non-rank reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.empty())
      return nonReducedType;

    // Take the strides and offset from the non-rank reduced type.
    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};

+16 −1
Original line number Diff line number Diff line
@@ -931,7 +931,7 @@ func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32

// -----

// CHECK-lABEL: func @ub_negative_alloc_size
// CHECK-LABEL: func private @ub_negative_alloc_size
func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
  %idx1 = index.constant 1
  %c-2 = arith.constant -2 : index
@@ -940,3 +940,18 @@ func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
  %alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
  return %alloc : memref<?x?x?xi1>
}

// -----

// CHECK-LABEL: func @subview_rank_reduction(
//  CHECK-SAME:     %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
    -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
  %c1 = arith.constant 1 : index
  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
      : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
  // CHECK: return %[[cast]]
  return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}