Commit d6b49937 authored by Adrian Kuegel

[mlir][MemRef] Fix canonicalization of BufferCast(TensorLoad).

CastOp::areCastCompatible does not check whether casts are definitely compatible.
When going from a dynamic to a static offset or stride, the canonicalization
cannot know whether the cast is really compatible. In that case, it can only
canonicalize to an alloc plus copy.

Differential Revision: https://reviews.llvm.org/D107545
parent 4fee756c
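
In IR terms, the change looks roughly like the following sketch, which mirrors the
new test case added below (the function name @dyn_to_static_offset is made up for
illustration; op spellings follow the memref dialect as of this commit). A
buffer_cast whose result type turns a dynamic offset into a static one is only
potentially cast-compatible, so it is no longer folded to a plain memref.cast:

// Before canonicalization: offset goes from ? to 3.
func @dyn_to_static_offset(%arg0: memref<?xf32, offset: ?, strides: [1]>)
    -> memref<?xf32, offset: 3, strides: [1]> {
  %0 = memref.tensor_load %arg0 : memref<?xf32, offset: ?, strides: [1]>
  %1 = memref.buffer_cast %0 : memref<?xf32, offset: 3, strides: [1]>
  return %1 : memref<?xf32, offset: 3, strides: [1]>
}

// After canonicalization: an alloc plus copy instead of a memref.cast
// (this is the shape of the output the new CHECK lines below expect).
func @dyn_to_static_offset(%arg0: memref<?xf32, offset: ?, strides: [1]>)
    -> memref<?xf32, offset: 3, strides: [1]> {
  %c0 = constant 0 : index
  %dim = memref.dim %arg0, %c0 : memref<?xf32, offset: ?, strides: [1]>
  %alloc = memref.alloc(%dim) : memref<?xf32, offset: 3, strides: [1]>
  memref.copy %arg0, %alloc
    : memref<?xf32, offset: ?, strides: [1]> to memref<?xf32, offset: 3, strides: [1]>
  return %alloc : memref<?xf32, offset: 3, strides: [1]>
}

Going the other direction (static offset to dynamic offset) is guaranteed
cast-compatible and still canonicalizes to a single memref.cast, as the updated
existing test shows.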
+45 −3
@@ -319,10 +319,52 @@ struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
     // types. `BufferCastOp::fold` handles the same type case.
     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
       return failure();
-    // If types are not cast-compatible, bail.
+    // If types are definitely not cast-compatible, bail.
     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                    bufferCast.getType()))
       return failure();
-    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
-                                        tensorLoad.memref());
+
+    // We already know that the types are potentially cast-compatible. However
+    // in case the affine maps are different, we may need to use a copy if we go
+    // from dynamic to static offset or stride (the canonicalization cannot know
+    // at this point that it is really cast compatible).
+    auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
+      int64_t sourceOffset, targetOffset;
+      SmallVector<int64_t, 4> sourceStrides, targetStrides;
+      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
+          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+        return false;
+      auto dynamicToStatic = [](int64_t a, int64_t b) {
+        return a == MemRefType::getDynamicStrideOrOffset() &&
+               b != MemRefType::getDynamicStrideOrOffset();
+      };
+      if (dynamicToStatic(sourceOffset, targetOffset))
+        return false;
+      for (auto it : zip(sourceStrides, targetStrides))
+        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
+          return false;
+      return true;
+    };
+
+    auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
+    auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
+    if (tensorLoadType && bufferCastType &&
+        !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
+      MemRefType resultType = bufferCastType;
+      auto loc = bufferCast.getLoc();
+      SmallVector<Value, 4> dynamicOperands;
+      for (int i = 0; i < resultType.getRank(); ++i) {
+        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
+          continue;
+        auto index = rewriter.createOrFold<ConstantIndexOp>(loc, i);
+        Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
+        dynamicOperands.push_back(size);
+      }
+      auto copy =
+          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
+      rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
+      rewriter.replaceOp(bufferCast, {copy});
+    } else
+      rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
+                                          tensorLoad.memref());
     return success();
+32 −3
@@ -46,16 +46,18 @@ func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf3
 // CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
 // CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
 
-// Test case: If the memrefs are cast-compatible, canonicalize.
+// Test case: If the memrefs are definitely cast-compatible, canonicalize to
+//            cast.
 // CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load(
 //  CHECK-SAME:   %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>)
-//  CHEKC-SAME:     -> memref<?xf32, #[[$OFF_UNK]]> {
+//  CHECK-SAME:     -> memref<?xf32, #[[$OFF_UNK]]> {
 //   CHECK-NOT: memref.tensor_load
 //   CHECK-NOT: memref.buffer_cast
 //       CHECK: %[[R:.*]] = memref.cast %[[M]]
 //  CHECK-SAME:   memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
 //       CHECK: return %[[R]]
-func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
+func @canonicalize_buffer_cast_of_tensor_load(
+  %arg0: memref<?xf32, offset: 3, strides: [1]>)
   -> memref<?xf32, offset: ?, strides: [1]>
 {
   %0 = memref.tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
@@ -65,6 +67,33 @@ func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, st
 
 // -----
 
+// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
+
+// Test case: If the memrefs are potentially cast-compatible, canonicalize to
+//            copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_to_copy(
+//  CHECK-SAME:   %[[M:.*]]: memref<?xf32, #[[$OFF_UNK]]>)
+//  CHECK-SAME:     -> memref<?xf32, #[[$OFF_3]]> {
+//   CHECK-NOT: memref.tensor_load
+//   CHECK-NOT: memref.buffer_cast
+//       CHECK: %[[C0:.*]] = constant 0 : index
+//       CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, #[[$OFF_UNK]]>
+//       CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, #[[$OFF_3]]>
+//       CHECK: memref.copy %[[M]], %[[ALLOC]]
+//  CHECK-SAME:   memref<?xf32, #[[$OFF_UNK]]> to memref<?xf32, #[[$OFF_3]]>
+//       CHECK: return %[[ALLOC]]
+func @canonicalize_buffer_cast_of_tensor_load_to_copy(
+  %arg0: memref<?xf32, offset: ?, strides: [1]>)
+  -> memref<?xf32, offset: 3, strides: [1]>
+{
+  %0 = memref.tensor_load %arg0 : memref<?xf32, offset: ?, strides: [1]>
+  %1 = memref.buffer_cast %0 : memref<?xf32, offset: 3, strides: [1]>
+  return %1 : memref<?xf32, offset: 3, strides: [1]>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>