Commit f7137da1 authored by Matthias Springer's avatar Matthias Springer
Browse files

[mlir][linalg] Fix dim(iter_arg) canonicalization

Run a small analysis to see if the runtime type of the iter_arg is changing. Fold only if the runtime type stays the same. (Same as `DimOfIterArgFolder` in SCF.)

Differential Revision: https://reviews.llvm.org/D109299
parent 9da62d3e
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
@@ -2299,10 +2300,47 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
///   tensor.dim %y, %c0 : tensor<...>
/// }
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the yielded value (in case of outputs) does not change with loop iterations.
template <typename OpTy>
struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  /// Conservative analysis that decides whether the loop preserves the shape
  /// of its `arg`-th output: i.e., whether the value yielded at position `arg`
  /// is guaranteed to have the same runtime type as the corresponding output
  /// block argument. Returns false whenever the chain of producers cannot be
  /// followed. Note: only simple producer patterns are recognized; extend as
  /// needed.
  static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
    auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
    // A loop that yields nothing either has no outputs or is the
    // "memref-based version"; both cases trivially preserve shapes.
    if (yieldOp.values().empty())
      return true;
    assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
           "arg is out of bounds");
    // Walk backwards through the producers of the yielded value.
    Value tracked = yieldOp.values()[arg];
    while (tracked) {
      // Reached the matching output block argument: shape is preserved.
      if (tracked == loopOp.getRegionOutputArgs()[arg])
        return true;
      OpResult asResult = tracked.dyn_cast<OpResult>();
      if (!asResult)
        return false;

      Operation *producer = asResult.getOwner();
      if (auto insertOp = dyn_cast<tensor::InsertSliceOp>(producer)) {
        // An insert_slice result has the type of its destination.
        tracked = insertOp.dest();
      } else if (auto innerLoop = dyn_cast<TiledLoopOp>(producer)) {
        // Recurse into a nested tiled loop: if it preserves the shape of this
        // result, keep following the corresponding output operand.
        unsigned resultIdx = asResult.getResultNumber();
        tracked = isShapePreserving(innerLoop, resultIdx)
                      ? innerLoop.outputs()[resultIdx]
                      : Value();
      } else {
        // Unknown producer: give up conservatively.
        tracked = Value();
      }
    }
    return false;
  }

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const final {
    auto src = dimOp.source().template dyn_cast<BlockArgument>();
@@ -2312,6 +2350,12 @@ struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
        dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
    if (!loopOp)
      return failure();
    unsigned numLoops = loopOp.getNumLoops();
    unsigned numInputArgs = loopOp.getRegionInputArgs().size();
    if (src.getArgNumber() >= numInputArgs + numLoops &&
        !isShapePreserving(loopOp,
                           src.getArgNumber() - numInputArgs - numLoops))
      return failure();

    auto inputArgs = loopOp.getRegionInputArgs();
    auto it1 = llvm::find(inputArgs, src);
+28 −0
Original line number Diff line number Diff line
@@ -904,6 +904,34 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {

// -----

// Negative test for the dim(out_arg) folding pattern: the value yielded for
// %out1 is an extract_slice of the filled tensor, so the runtime shape of the
// iter_arg may change across iterations. The inner tensor.dim of %out1 must
// therefore NOT be rewritten into a dim of %arg2 — the CHECK lines below pin
// that the tensor.dim on the output bbarg survives canonicalization.
// CHECK-LABEL: func @dim_of_tiled_loop_input_no_canonicalize(
//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
//       CHECK:   %[[c0:.*]] = constant 0 : index
//       CHECK:   linalg.tiled_loop {{.*}} outs (%[[o:.*]] =
//       CHECK:     %[[dim:.*]] = tensor.dim %[[o]], %[[c0]]
//       CHECK:     index_cast %[[dim]]
func @dim_of_tiled_loop_input_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
    -> tensor<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
      to (%d0, %d1) step (%c1, %c1)
      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
      outs (%out1 = %arg2 : tensor<?x?xf32>) {
    // This dim must stay: the loop is not shape-preserving for %out1.
    %inner_dim = tensor.dim %out1, %c0 : tensor<?x?xf32>
    %cast1 = std.index_cast %inner_dim : index to i32
    %cast2 = std.sitofp %cast1 : i32 to f32
    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
    // Yielding a slice (size %s x %s) — the yielded type's runtime shape may
    // differ from %out1's, which is what blocks the canonicalization.
    %slice = tensor.extract_slice %fill[0, 0][%s, %s][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
    linalg.yield %slice : tensor<?x?xf32>
  }
  return %r : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @dim_of_tiled_loop_input(
//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
//       CHECK:   %[[c0:.*]] = constant 0 : index