Unverified Commit 932dc9d8 authored by qcolombet's avatar qcolombet Committed by GitHub
Browse files

[mlir][MemRef] Add a pattern to simplify `extract_strided_metadata(ca… (#68291)

…st)`

`expand-strided-metadata` was missing a pattern to get rid of
`memref.cast`.
The pattern is straight foward:
Produce a new `extract_strided_metadata` with the source of the cast and
fold the static information (sizes, strides, offset) along the way.
parent 253ee85f
Loading
Loading
Loading
Loading
+88 −0
Original line number Diff line number Diff line
@@ -870,6 +870,92 @@ class ExtractStridedMetadataOpReinterpretCastFolder
  }
};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(
///                 cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.srcOffset.isDynamic()
///            ? dstTy.srcOffset
///            : extract_strided_metadata(src).offset
/// sizes = for each srcSize in dstTy.srcSizes:
///           !srcSize.isDynamic()
///             ? srcSize
//              : extract_strided_metadata(src).sizes[i]
/// strides = for each srcStride in dstTy.srcStrides:
///             !srcStrides.isDynamic()
///               ? srcStrides
///               : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides or compute them directly from `src`.
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Value source = extractStridedMetadataOp.getSource();
    auto castOp = source.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {castOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(source.getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
                                                          castOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    auto getConstantOrValue = [&rewriter](int64_t constant,
                                          OpFoldResult ofr) -> OpFoldResult {
      return !ShapedType::isDynamic(constant)
                 ? OpFoldResult(rewriter.getIndexAttr(constant))
                 : ofr;
    };

    auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
    assert(sourceStrides.size() == rank && "unexpected number of strides");

    // Register the new offset.
    results[1] =
        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;
    ArrayRef<int64_t> sourceSizes = memrefType.getShape();

    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
      results[strideStartIdx + i] =
          getConstantOrValue(sourceStrides[i], strides[i]);
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset =
///            extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
@@ -911,6 +997,7 @@ void memref::populateExpandStridedMetadataPatterns(
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}
@@ -923,6 +1010,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}
+125 −0
Original line number Diff line number Diff line
@@ -1369,3 +1369,128 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
      memref<i32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// When we apply the transformation the resulting offset, sizes and strides
// should come straight from the inputs of the cast.
// Additionally the folder on extract_strided_metadata should propagate the
// static information.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast
//  CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
//
//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
//       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
func.func @extract_strided_metadata_of_cast(
  %arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  %cast =
    memref.cast %arg :
      memref<3x?xi32, strided<[4, ?], offset: ?>> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
// in the destination type.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
//  CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
//
//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
//   CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
//   CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
//       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
//       CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
func.func @extract_strided_metadata_of_cast_w_csts(
  %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  %cast =
    memref.cast %arg :
      memref<?x?xi32, strided<[?, ?], offset: ?>> to
      memref<4x?xi32, strided<[?, 18], offset: 25>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}
// -----

// Check that we don't simplify extract_strided_metadata of
// cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
//  CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
//
//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
//       CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_cast_unranked(
  %arg : memref<*xi32>)
  -> (memref<i32>, index,
      index, index,
      index, index) {

  %cast =
    memref.cast %arg :
      memref<*xi32> to
      memref<?x?xi32, strided<[?, ?], offset: ?>>

  %base, %base_offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
    -> memref<i32>, index,
       index, index,
       index, index

  return %base, %base_offset,
    %sizes#0, %sizes#1,
    %strides#0, %strides#1 :
      memref<i32>, index,
      index, index,
      index, index
}