Commit a9d9aadc authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][Linalg] NFC - Cleanup Linalg Declarative Transformations

Summary:
This is part of an ongoing cleanup and uniformization work.

This diff performs 3 types of cleanups:
1. Uniformize transformation names.
2. Replace all pattern operands that need not be captured by `$_`
3. Replace all usage of pattern captured op by the normalized `op` name (instead of positional parameters such as `$0`)

Reviewers: ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72081
parent 87a004d0
Loading
Loading
Loading
Loading
+15 −15
Original line number Diff line number Diff line
@@ -18,24 +18,24 @@ include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
include "mlir/Dialect/AffineOps/AffineOps.td"

def HasNoLinalgTransformMarker : CPred<[{
  !$0.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
  !op.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
}]>;

class HasLinalgTransformMarker<string str> : CPred<[{
  $0.getAttrOfType<StringAttr>(
  op.getAttrOfType<StringAttr>(
    LinalgTransforms::kLinalgTransformMarker) &&
  $0.getAttrOfType<StringAttr>(
  op.getAttrOfType<StringAttr>(
    LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;

class IsProducedByOpOfType<string str> :
  CPred<"isProducedByOpOfType<" # str # ">($0, $1)">;
  CPred<"isProducedByOpOfType<" # str # ">(op, $0)">;

class AffineMapDomainHasDim<int n> : CPred<[{
  $0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
  op.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
  cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;

class HasOperandsOfType<string type>: CPred<[{
    llvm::any_of($0.getOperands(),
    llvm::any_of(op.getOperands(),
        [](Value v) {
          return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
        })
@@ -50,7 +50,7 @@ class HasOperandsOfType<string type>: CPred<[{
// patterns.
class TileAndFuseLinalgOp<
    list<int> sizes, list<int> operandIndices, string value> : NativeCodeCall<
  "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" #
  "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
  StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
      " \"" # value # "\")))" #
  "  return matchFailure();">;
@@ -67,7 +67,7 @@ class TileAndFuseLinalgOp<
// of elements as `sizes`.
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
  NativeCodeCall<
    "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
    "if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
    StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
    StrJoinInt<permutation>.result # "})))" #
    "  return matchFailure();">;
@@ -76,18 +76,18 @@ class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
// Linalg to loop patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToLoops<string OpType> : NativeCodeCall<
  "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " #
  "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
  "  return matchFailure();">;

class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
  "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
  "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
  "  return matchFailure();">;

//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
  "if (failed(vectorizeGenericOp($_builder, $0))) " #
class VectorizeGenericLinalgOp<string OpType> : NativeCodeCall<
  "if (failed(vectorizeGenericLinalgOp($_builder, op))) " #
  "  return matchFailure();">;

//===----------------------------------------------------------------------===//
@@ -95,14 +95,14 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
//===----------------------------------------------------------------------===//
class PermuteGenericLinalgOp<list<int> permutation, string value> :
  NativeCodeCall<
    "if (failed(permuteGenericLinalgOp($_builder, $0, {" #
    "if (failed(permuteGenericLinalgOp($_builder, op, {" #
    StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
    "  return matchFailure();">;

//===----------------------------------------------------------------------===//
// Linalg promote subview operands.
//===----------------------------------------------------------------------===//
class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall<
  "if (failed(linalgOpPromoteSubviews($_builder, $0))) " #
class PromoteSubviewsLinalgOp<string OpType> : NativeCodeCall<
  "if (failed(promoteSubviewsLinalgOp($_builder, op))) " #
  "  return matchFailure();">;
#endif // LINALG_TRANSFORMS
+3 −2
Original line number Diff line number Diff line
@@ -79,7 +79,8 @@ template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);

/// Rewrite a linalg.generic into a suitable vector.contraction op.
LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
LogicalResult vectorizeGenericLinalgOp(PatternRewriter &rewriter,
                                       Operation *op);

/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permutated according to `permutation`.
@@ -88,7 +89,7 @@ LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
                                     StringRef linalgMarker);

/// Promote std.subviews feeding linalg operations
LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op);
LogicalResult promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op);

} // namespace linalg
} // namespace mlir
+3 −3
Original line number Diff line number Diff line
@@ -153,7 +153,7 @@ static bool isMatmul(linalg::GenericOp genericOp) {
         genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
}

LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
                                                     Operation *op) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                       "]: Rewrite linalg op as vector.contract: "
@@ -223,7 +223,7 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
  return success();
}

LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
LogicalResult mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
                                                    Operation *op) {
  LinalgOp linOp = dyn_cast<LinalgOp>(op);
  SetVector<Value> subViews;
+59 −59
Original line number Diff line number Diff line
@@ -19,11 +19,11 @@ include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
//===----------------------------------------------------------------------===//
// Test Linalg fusion patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$consumer $A, $B, $C),
          (TileAndFuseLinalgOp<[100, 150], [0], "L1"> $consumer),
def : Pat<(MatmulOp:$op $A, $_, $_),
          (TileAndFuseLinalgOp<[100, 150], [0], "L1">),
          [
            (Constraint<HasNoLinalgTransformMarker> $consumer),
            (Constraint<IsProducedByOpOfType<"MatmulOp">> $consumer, $A),
            (Constraint<HasNoLinalgTransformMarker>),
            (Constraint<IsProducedByOpOfType<"MatmulOp">> $A),
          ],
          // In the buffer world there is no use-def chains or dags so benefits
          // cannot be computed automatically from the length of the matched
@@ -36,91 +36,91 @@ def : Pat<(MatmulOp:$consumer $A, $B, $C),
//===----------------------------------------------------------------------===//
// Linalg tiling patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[2000, 3000, 4000], "L3"> $op),
def : Pat<(MatmulOp:$op $_, $_, $_),
          (TileLinalgOp<[2000, 3000, 4000], "L3">),
          [(Constraint<Or<[HasNoLinalgTransformMarker,
                           HasLinalgTransformMarker<"MEM">]>> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[200, 300, 400], "L2"> $op),
          [(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[20, 30, 40], "L1"> $op),
          [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[2, 3, 4], "REG"> $op),
          [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
                           HasLinalgTransformMarker<"MEM">]>>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
          (TileLinalgOp<[200, 300, 400], "L2">),
          [(Constraint<HasLinalgTransformMarker<"L3">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
          (TileLinalgOp<[20, 30, 40], "L1">),
          [(Constraint<HasLinalgTransformMarker<"L2">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
          (TileLinalgOp<[2, 3, 4], "REG">),
          [(Constraint<HasLinalgTransformMarker<"L1">>)]>;

def : Pattern<(MatvecOp:$op $A, $b, $c),
              [(TileLinalgOp<[5, 6], "L1"> $op)],
              [(Constraint<HasNoLinalgTransformMarker> $op)]>;
def : Pattern<(MatvecOp:$op $_, $_, $_),
              [(TileLinalgOp<[5, 6], "L1">)],
              [(Constraint<HasNoLinalgTransformMarker>)]>;

def : Pattern<(DotOp:$op $a, $b, $c),
              [(TileLinalgOp<[8000], "L1"> $op)],
def : Pattern<(DotOp:$op $_, $_, $_),
              [(TileLinalgOp<[8000], "L1">)],
              [(Constraint<Or<[HasNoLinalgTransformMarker,
                               HasLinalgTransformMarker<"MEM">,
                               HasLinalgTransformMarker<"L3">,
                               HasLinalgTransformMarker<"L2">]>> $op)]>;
def : Pattern<(DotOp:$op $a, $b, $c),
              [(TileLinalgOp<[8], "REG"> $op)],
              [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
                               HasLinalgTransformMarker<"L2">]>>)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
              [(TileLinalgOp<[8], "REG">)],
              [(Constraint<HasLinalgTransformMarker<"L1">>)]>;

//===----------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]> $op),
          [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]> $op),
          [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">> $op)]>;
def : Pat<(MatmulOp:$op $A, $B, $C),
          (TileLinalgOp<[20, 30, 40], "REG__with_perm__"> $op),
          [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
          (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>),
          [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
          (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>),
          [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
          (TileLinalgOp<[20, 30, 40], "REG__with_perm__">),
          [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;


def : Pattern<(MatvecOp:$op $A, $b, $c),
              [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]> $op)],
              [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
def : Pattern<(MatvecOp:$op $_, $_, $_),
              [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)],
              [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;

def : Pattern<(DotOp:$op $a, $b, $c),
              [(TileLinalgOp<[8000], "L1__with_perm__"> $op)],
              [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
def : Pattern<(DotOp:$op $a, $b, $c),
              [(TileLinalgOp<[8], "REG__with_perm__"> $op)],
              [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
              [(TileLinalgOp<[8000], "L1__with_perm__">)],
              [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
              [(TileLinalgOp<[8], "REG__with_perm__">)],
              [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;

//===----------------------------------------------------------------------===//
// Linalg to loops patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(DotOp:$op $a, $b, $c),
              [(LinalgOpToLoops<"DotOp"> $op)],
              [(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
def : Pattern<(DotOp:$op $_, $_, $_),
              [(LinalgOpToLoops<"DotOp">)],
              [(Constraint<HasLinalgTransformMarker<"REG">>)]>;

//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
              [(LinalgOpToVectorContraction<"GenericOp"> $op)],
              [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
              [(VectorizeGenericLinalgOp<"GenericOp">)],
              [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">>)]>;

//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
              (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
              (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
              [(Constraint<And<[HasNoLinalgTransformMarker,
                           AffineMapDomainHasDim<3>]>> $op)]>;
                           AffineMapDomainHasDim<3>]>>)]>;

def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
              (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
              (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
              [(Constraint<And<[HasNoLinalgTransformMarker,
                           AffineMapDomainHasDim<3>]>> $op)]>;
                           AffineMapDomainHasDim<3>]>>)]>;

//===----------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $A, $B, $C),
          (LinalgOpPromoteSubviews<"MatmulOp"> $op),
          [(Constraint<HasOperandsOfType<"SubViewOp">> $op),
          (Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
          (PromoteSubviewsLinalgOp<"MatmulOp">),
          [(Constraint<HasOperandsOfType<"SubViewOp">>),
          (Constraint<HasLinalgTransformMarker<"_promote_views_">>)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS