Commit ea1e3369 authored by Nicolas Vasilache

[mlir][Linalg] Introduce folding patterns to remove certain MemRefCastOp

Summary:
Canonicalization and folding patterns in StandardOps may interfere with the needs
of Linalg. This revision introduces Linalg-specific foldings that remove a
memref_cast when its dynamically shaped result can be proven to be static.

Very concretely, the new folding rewrites a `memref_cast` whose result feeds a
Linalg op, turning:

```mlir
  %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
  %2 = linalg.slice %1 ... : memref<?x?xf32> ...
  // or
  %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
         to memref<?x?xf32>
  linalg.generic(%1 ...) : memref<?x?xf32> ...
```

into

```mlir
  %2 = linalg.slice %0 ... : memref<8x16xf32> ...
  // or
  linalg.generic(%0 ...) : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>> ...
```
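
For contrast, a minimal sketch (not part of this revision) of a cast that must
be left in place, because the source shape cannot be proven static:

```mlir
  // Illustrative only: the source has a dynamic size, so no static shape
  // can be recovered and the cast survives canonicalization.
  %1 = memref_cast %0 : memref<?x16xf32> to memref<?x?xf32>
  linalg.generic(%1 ...) : memref<?x?xf32> ...
```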

Reviewers: ftynse, aartbik, jsetoain, tetuante, asaadaldien

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73565
parent 02adfb51
+6 −0
@@ -117,6 +117,8 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
    static StringRef getReassociationAttrName() { return "reassociation"; }
    MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
  }];

  let hasFolder = 1;
}

def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
@@ -188,6 +190,8 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
      return res;
    }
  }];

  let hasFolder = 1;
}

def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
@@ -222,6 +226,8 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
    static StringRef getPermutationAttrName() { return "permutation"; }
    ShapedType getShapedType() { return view().getType().cast<ShapedType>(); }
  }];

  let hasFolder = 1;
}

def Linalg_YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
+19 −0
@@ -270,6 +270,8 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
    }
  }];
  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
@@ -287,6 +289,8 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
    }
  }];
  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
@@ -302,6 +306,8 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
        StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
    }
  }];

  let hasFolder = 1;
}

def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
@@ -319,6 +325,8 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
      return ArrayAttr::get(iters, ctx);
    }
  }];

  let hasFolder = 1;
}

def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
@@ -337,6 +345,8 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
      return ArrayAttr::get(iters, ctx);
    }
  }];

  let hasFolder = 1;
}

def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
@@ -406,7 +416,10 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
        .cast<IntegerAttr>().getValue().getSExtValue();
    }
  }];

  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

def LinalgOperand: Type<
@@ -583,7 +596,10 @@ def GenericOp : GenericOpBase<"generic"> {
    tensor SSA values are expected to be useful and will be added in the near
    future.
  }];

  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
@@ -710,7 +726,10 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
    tensor SSA values are expected to be useful and will be added in the near
    future.
  }];

  let verifier = [{ return ::verify(*this); }];

  let hasFolder = 1;
}

#endif // LINALG_STRUCTURED_OPS
+135 −0
@@ -12,6 +12,7 @@

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -31,6 +32,89 @@
using namespace mlir;
using namespace mlir::linalg;

/// Determines whether the given MemRefCastOp can be folded away into the
/// parent Linalg op, turning:
///
/// ```mlir
///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = linalg.slice %1 ... : memref<?x?xf32> ...
///   // or
///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   linalg.generic(%1 ...) : memref<?x?xf32> ...
/// ```
///
/// into
///
/// ```mlir
///   %2 = linalg.slice %0 ... : memref<8x16xf32> ...
///   // or
///   linalg.generic(%0 ...) : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>> ...
/// ```
///
static bool canFold(MemRefCastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // If we don't have MemRefType as source and destination, bail out.
  if (!sourceType || !resultType)
    return false;

  // If resultType has a map, it needs to be the same as the source type to
  // canonicalize.
  if (!resultType.getAffineMaps().empty() &&
      sourceType.getAffineMaps() != resultType.getAffineMaps())
    return false;

  // Ensure that:
  //   1. source is static
  //   2. source and target have the same rank (will be extended when needed)
  //   3. if result is partially static, ensure sizes match.
  if (!sourceType.hasStaticShape() ||
      sourceType.getRank() != resultType.getRank())
    return false;

  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto sourceSize = std::get<0>(it);
    auto resultSize = std::get<1>(it);
    if (ShapedType::isDynamic(resultSize))
      continue;
    if (sourceSize != resultSize)
      return false;
  }

  // If source has a map, it can only canonicalize if it is the canonical
  // strided layout map.
  if (sourceType.getAffineMaps().empty())
    return true;

  int64_t offset;
  SmallVector<int64_t, 4> strides;
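  // A non-empty source layout is assumed to be strided here, so computing its
  // strides and offset must succeed.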
  auto res = getStridesAndOffset(sourceType, strides, offset);
  (void)res;
  assert(succeeded(res));
  auto stridedMap =
      makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
  AffineMap sourceMap = sourceType.getAffineMaps().front();
  return sourceMap == stridedMap;
}
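
// For intuition, using the 8x16 example from the comment above: the source
// type has strides [16, 1] and offset 0, and makeStridedLinearLayoutMap
// rebuilds exactly (i, j) -> (16 * i + j), i.e. the row-major layout of
// memref<8x16xf32>. Such a source map is canonical, so the cast
//   %1 = memref_cast %0
//       : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>> to memref<?x?xf32>
// is foldable.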

/// This is a common class used for patterns of the form
/// ```
///    someop(memrefcast) -> someop
/// ```
/// It folds the source of any memref_cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
    if (castOp && canFold(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

///////////////////// Operations defined with Tablegen /////////////////////////
// For such operations that do not correspond to library calls (i.e. defined in
// LinalgOps.td), we define an overloaded `print` function and a
@@ -1077,3 +1161,54 @@ ArrayAttr mlir::linalg::MatmulOp::indexing_maps() {
ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
  return getIndexingMaps(getOperation());
}

// TODO(ntv, rriddle): Consider making all this boilerplate easy to autogenerate
// with Tablegen. This seems a desirable property in the context of OpInterfaces
// where a Linalg "named" op **isa** LinalgOp.
LogicalResult ConvOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult CopyOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult DotOp::fold(ArrayRef<Attribute>,
                          SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult FillOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
                                     SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
+20 −0
// RUN: mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: func @memref_cast(
func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c8 = constant 8 : index
  %c16 = constant 16 : index
  %1 = alloc (%b) : memref<?xi8>
  %2 = view %1[][] : memref<?xi8> to memref<16x16xf32>
  %3 = memref_cast %2 : memref<16x16xf32> to memref<?x?xf32>
  %r0 = linalg.range %c0:%c8:%c1 : !linalg.range

  // CHECK:  linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
  %4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>

  // CHECK:  linalg.matmul{{.*}}: memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>
  linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  return %4: memref<?x?xf32>
}
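
// A hypothetical negative check (sketch only; the function name and shapes
// are illustrative, not part of this change): a cast whose source shape is
// not fully static cannot be folded, so it must survive canonicalization.
// CHECK-LABEL: func @memref_cast_no_fold(
func @memref_cast_no_fold(%arg0: memref<?x16xf32>) {
  // CHECK: memref_cast
  %0 = memref_cast %arg0 : memref<?x16xf32> to memref<?x?xf32>
  // CHECK: linalg.matmul{{.*}}: memref<?x?xf32>
  linalg.matmul(%0, %0, %0) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  return
}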