Commit 2140a973 authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][Linalg] Extend generic ops to allow tensors

    Summary:
    This diff adds support to allow `linalg.generic` and
    `linalg.indexed_generic` to take tensor input and output
    arguments.

    The subset of output tensor operand types must appear
    verbatim in the result types after an arrow. The parser,
    printer and verifier are extended to accommodate this
    behavior.

    The Linalg operations now support variadic ranked tensor
    return values. This extension exhibited issues with the
    current handling of NativeCall in RewriterGen.cpp. As a
    consequence, an explicit cast to `SmallVector<Value, 4>`
    is added in the proper place to support the new behavior
    (better suggestions are welcome).

    Relevant cleanups and name uniformization are applied.

    Relevant invalid and roundtrip tests are added.

    Reviewers: mehdi_amini, rriddle, jpienaar, antiagainst, ftynse

    Subscribers: burmako, shauheen, llvm-commits

    Tags: #llvm

    Differential Revision: https://reviews.llvm.org/D72022
parent 9d49e5c0
Loading
Loading
Loading
Loading
+10 −8
Original line number Diff line number Diff line
@@ -19,9 +19,10 @@ def Linalg_Dialect : Dialect {
  let name = "linalg";
  let description = [{
    The `linalg` dialect groups together a set of types, operations and
    transformations that are useful to implement a structured abstraction where
    ops can lower to scalar load/store and operations or to more general library
    calls.
    transformations that are useful to implement a structured abstraction on
    buffers and tensors. These abstractions are useful for transformations and
    can lower to scalar load/store and other operations or to more general
    library calls.

    The `linalg` dialect manipulates the following types and operations:

@@ -67,12 +68,13 @@ def Linalg_Dialect : Dialect {
    A set of payload carrying operations that implement the [structured ops](
    https://docs.google.com/presentation/d/1P-j1GrH6Q5gLBjao0afQ-GfvcAeF-QU4GXXeSy0eJ9I/edit#slide=id.p
    )
    abstraction on buffers. `linalg` has `2` generic operations `linalg.generic`
    and `linalg.indexed_generic` for expressing custom operations. This is
    subject to further evolution as transformations and analyses continue to be
    developed.
    abstraction on tensors and buffers. `linalg` has `2` generic operations
    `linalg.generic` and `linalg.indexed_generic` for expressing custom
    operations.
    This is subject to further evolution as transformations and analyses
    continue to be developed.

    Additionally, `linalg` provides some common named operations:
    Additionally, `linalg` provides some commonly named operations:

        * `linalg.copy`,
        * `linalg.fill`,
+9 −8
Original line number Diff line number Diff line
@@ -59,7 +59,8 @@ def Linalg_RangeOp :
}

def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
    Arguments<(ins AnyStridedMemRef:$view, Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
    Arguments<(ins AnyStridedMemRef:$view,
                   Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
    Results<(outs AnyStridedMemRef)> {
  let summary = "Produce a rank-reduced `subview` of a base `view`.";
  let description = [{
@@ -108,11 +109,11 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,

  let extraClassDeclaration = [{
    enum { FirstIndexingOperand = 1 };
    unsigned getRank() { return getViewType().getRank(); }
    Type getElementType() { return getViewType().getElementType(); }
    MemRefType getViewType() { return getType().cast<MemRefType>(); }
    unsigned getRank() { return getShapedType().getRank(); }
    Type getElementType() { return getShapedType().getElementType(); }
    ShapedType getShapedType() { return getType().cast<ShapedType>(); }
    unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
    MemRefType getBaseViewType() { return view()->getType().cast<MemRefType>(); }
    ShapedType getBaseViewType() { return view()->getType().cast<ShapedType>();}

    // Get the underlying indexing at a given rank.
    Value indexing(unsigned rank) { return *(indexings().begin() + rank); }
@@ -131,7 +132,7 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
    Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>,
    Results<(outs AnyStridedMemRef)> {
  let summary = "transpose operation produces a new strided memref (metadata-only)";
  let summary = "`transpose` produces a new strided memref (metadata-only)";
  let description = [{
    The `linalg.transpose` op produces a strided memref whose sizes and strides
    are a permutation of the original `view`. This is a pure metadata
@@ -151,14 +152,14 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
  let verifier = [{
    if (!permutation().isPermutation())
      return emitOpError("expected a permutation map");
    if (permutation().getNumDims() != getViewType().getRank())
    if (permutation().getNumDims() != getShapedType().getRank())
      return emitOpError("expected a permutation map of same rank as the view");
    return success();
  }];

  let extraClassDeclaration = [{
    static StringRef getPermutationAttrName() { return "permutation"; }
    MemRefType getViewType() { return view()->getType().cast<MemRefType>(); }
    ShapedType getShapedType() { return view()->getType().cast<ShapedType>(); }
  }];
}

+77 −11
Original line number Diff line number Diff line
@@ -89,23 +89,32 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
      "Value ", "getOutput", (ins "unsigned":$i)
    >,
    InterfaceMethod<[{
        Query the index of the given input value, or `None` if the value is not
        an input.
        Return the index of the given input value `v`, or `None` if the value is
        not an input.
      }],
      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$view)
      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v)
    >,
    InterfaceMethod<[{
        Query the index of the given view value, or `None` if the value is not
        an view.
        a view.
      }],
      "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view)
    >,
    InterfaceMethod<[{
        Query the type of the input view at the given index.
      }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>,
        Query the type of the input shape at the given index.
      }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
    InterfaceMethod<[{
        Query the type of the output view at the given index.
      }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>,
      }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
    InterfaceMethod<[{
        Query whether the op has only MemRef input and outputs.
      }], "bool", "hasBufferSemantics">,
    InterfaceMethod<[{
        Query the subset of input operands that are of ranked tensor type.
      }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
    InterfaceMethod<[{
        Query the subset of output operands that are of ranked tensor type.
      }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,

    StaticInterfaceMethod<[{
        Create an operation of the current type with the given location,
@@ -340,7 +349,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
    ArrayAttr iterator_types() {
      // Outer parallel loops are always the number of output dimensions; i.e.
      // [ b, xs, q] in the TF notation above.
      unsigned nPar = getOutputViewType(0).getRank();
      unsigned nPar = getOutputShapedType(0).getRank();
      unsigned nRed = getNumInputFeatureDimensions();
      // Window loops are a special kind of reduction that is never tiled or
      // parallelized across; i.e. [zs] in the TF notation above whose number
@@ -374,8 +383,17 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
  let verifier = [{ return ::verify(*this); }];
}

def LinalgOperand: Type<
  Or<[AnyRankedTensor.predicate, AnyStridedMemRef.predicate]>>;

class LinalgOperandOfRank<int rank>: Type<
  And<[
    LinalgOperand.predicate,
    CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
  >>;

class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
  let arguments = (ins Variadic<AnyStridedMemRef>:$views,
  let arguments = (ins Variadic<LinalgOperand>:$views,
                   I64Attr:$args_in,
                   I64Attr:$args_out,
                   AffineMapArrayAttr:$indexing_maps,
@@ -383,6 +401,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
                   OptionalAttr<StrAttr>:$doc,
                   OptionalAttr<FlatSymbolRefAttr>:$fun,
                   OptionalAttr<StrAttr>:$library_call);
  let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
  let regions = (region AnyRegion:$region);
  let extraClassDeclaration = [{
    SmallVector<StringRef, 8> linalgTraitAttrNames() {
@@ -511,6 +530,28 @@ def GenericOp : GenericOpBase<"generic"> {
      }
    }
    ```

    To allow progressive lowering from the value world (a.k.a tensor values) to
    the buffer world (a.k.a memref values), a `linalg.generic` op accepts
    mixing input and output ranked tensor values with input and output memrefs.

    ```mlir
      %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
        tensor<?x?xf32>,
        memref<?x?xf32, stride_specification>,
        tensor<?x?xf32>
        -> (tensor<?x?xf32>)
    ```

    In this case, the number of return values must match the number of output
    tensor arguments. The semantics is that the `linalg.generic` op
    produces (i.e. allocates and fills) its return values.
    Tensor values must be legalized by a buffer allocation pass before most
    transformations can be applied. In particular, transformations that create
    control flow around linalg.generic operations are not expected to mix with
    tensors because SSA values do not escape naturally. Still, transformations
    and rewrites that take advantage of tensor SSA values are expected to be
    useful and will be added in the near future.
  }];
  let verifier = [{ return ::verify(*this); }];
}
@@ -555,9 +596,11 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
    Example:
    Defining a #matmul_trait attribute in MLIR can be done as follows:
      ```mlir
        func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32)
        func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
                  %a: f32, %b: f32, %c: f32)
          -> f32
        {
          "some_optional_condition"(%offset_m, %offset_n, %offset_k)
          %d = mulf %a, %b: f32
          %e = addf %c, %d: f32
          return %e: f32
@@ -587,7 +630,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {

    This may lower to either:
      ```mlir
        call @linalg_matmul(%A, %B, %C) :
        call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
          (memref<?x?xf32, stride_specification>,
           memref<?x?xf32, stride_specification>,
           memref<?x?xf32, stride_specification>)
@@ -609,6 +652,29 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
      }
    }
    ```

    To allow progressive lowering from the value world (a.k.a tensor values) to
    the buffer world (a.k.a memref values), a `linalg.indexed_generic` op
    accepts mixing input and output ranked tensor values with input and output
    memrefs.

    ```mlir
      %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes}
      : tensor<?x?xf32>,
        memref<?x?xf32, stride_specification>,
        tensor<?x?xf32>
        -> (tensor<?x?xf32>)
    ```

    In this case, the number of return values must match the number of output
    tensor arguments. The semantics is that the `linalg.indexed_generic` op
    produces (i.e. allocates and fills) its return values.
    Tensor values must be legalized by a buffer allocation pass before most
    transformations can be applied. In particular, transformations that create
    control flow around linalg.generic operations are not expected to mix with
    tensors because SSA values do not escape naturally. Still, transformations
    and rewrites that take advantage of tensor SSA values are expected to be
    useful and will be added in the near future.
  }];
  let verifier = [{ return ::verify(*this); }];
}
+56 −34
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ namespace OpTrait {
namespace linalg {

/// This class provides the API for ops that are known to have a specified
/// number of inputs, all passed as operands. This is used as a trait like this:
/// number of inputs, all passed as operands. Use as a trait as follows:
///
///   class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> {
///
@@ -34,7 +34,7 @@ public:
};

/// This class provides the API for ops that are known to have a specified
/// number of inputs, all passed as operands. This is used as a trait like this:
/// number of outputs, all passed as operands. Use as a trait as follows:
///
///   class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> {
///
@@ -47,79 +47,101 @@ public:
  };
};

/// This class provides the API for ops that are known to operate on views. This
/// trait must be used in conjunction with an op definition or a trait that
/// provides the methods `getNumInputs` and `getNumOutputs`. This is used as a
/// trait like this:
/// This class provides the API for structured ops that are known to operate on
/// buffers or tensors. This trait must be used in conjunction with an op
/// definition or a trait that provides the methods `getNumInputs` and
/// `getNumOutputs`. Use as a trait as follows:
///
///   class DotOp : public Op<DotOp, OpTrait::ViewTrait> {
///   class DotOp : public Op<DotOp, OpTrait::StructuredOpTraits> {
///
template <typename ConcreteType>
class StructuredOpTraits
    : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
private:
  /// Return the number of input views. For internal use only.
  /// Return the number of inputs. For internal use only.
  unsigned nInputs() {
    return cast<ConcreteType>(this->getOperation()).getNumInputs();
  }
  /// Return the number of input views. For internal use only.
  /// Return the number of outputs. For internal use only.
  unsigned nOutputs() {
    return cast<ConcreteType>(this->getOperation()).getNumOutputs();
  }

public:
  /// Return the `i`-th input view.
  /// Return the `i`-th input value.
  Value getInput(unsigned i) {
    assert(i < nInputs());
    return this->getOperation()->getOperand(i);
  }
  /// Return the index of `view` in the list of input views if found, llvm::None
  /// Return the index of `value` in the list of inputs if found, llvm::None
  /// otherwise.
  Optional<unsigned> getIndexOfInput(Value view) {
    auto it = llvm::find(getInputs(), view);
  Optional<unsigned> getIndexOfInput(Value value) {
    auto it = llvm::find(getInputs(), value);
    if (it != getInputs().end())
      return it - getInputs().begin();
    return llvm::None;
  }
  /// Return the `i`-th input view type.
  MemRefType getInputViewType(unsigned i) {
    return getInput(i)->getType().template cast<MemRefType>();
  /// Return the `i`-th input buffer type.
  ShapedType getInputShapedType(unsigned i) {
    return getInput(i)->getType().template cast<ShapedType>();
  }
  /// Return the range over input views.
  /// Return the range over inputs.
  Operation::operand_range getInputs() {
    auto range = this->getOperation()->getOperands();
    return {range.begin(), range.begin() + nInputs()};
  }
  /// Return the `i`-th output view.
  /// Return the `i`-th output.
  Value getOutput(unsigned i) {
    return this->getOperation()->getOperand(nInputs() + i);
  }
  /// Return the index of `view` in the list of output views if found,
  /// Return the index of `value` in the list of output values if found,
  /// llvm::None otherwise.
  Optional<unsigned> getIndexOfOutput(Value view) {
    auto it = llvm::find(getOutputs(), view);
  Optional<unsigned> getIndexOfOutput(Value value) {
    auto it = llvm::find(getOutputs(), value);
    if (it != getOutputs().end())
      return it - getOutputs().begin();
    return llvm::None;
  }
  /// Return the `i`-th output view type.
  MemRefType getOutputViewType(unsigned i) {
    return getOutput(i)->getType().template cast<MemRefType>();
  }
  /// Return the range over output views.
  /// Return the `i`-th output buffer type.
  ShapedType getOutputShapedType(unsigned i) {
    return getOutput(i)->getType().template cast<ShapedType>();
  }
  /// Query whether the op has only MemRef input and outputs.
  bool hasBufferSemantics() {
    return this->getOperation()->getNumResults() == 0 &&
           llvm::all_of(getInputsAndOutputs(),
                        [](Value v) { return v.getType().isa<MemRefType>(); });
  }
  /// Query the subset of input operands that are of ranked tensor type.
  SmallVector<RankedTensorType, 4> getInputTensorTypes() {
    SmallVector<RankedTensorType, 4> res;
    for (Type type : getInputs().getTypes())
      if (auto t = type.template dyn_cast<RankedTensorType>())
        res.push_back(t);
    return res;
  }
  /// Query the subset of output operands that are of ranked tensor type.
  SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
    SmallVector<RankedTensorType, 4> res;
    for (Type type : getOutputs().getTypes())
      if (auto t = type.template dyn_cast<RankedTensorType>())
        res.push_back(t);
    return res;
  }
  /// Return the range over outputs.
  Operation::operand_range getOutputs() {
    auto range = this->getOperation()->getOperands();
    return {range.begin() + nInputs(),
            range.begin() + getNumInputsAndOutputs()};
  }
  /// Return the number of input and output views.
  /// Return the number of inputs and outputs.
  unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
  /// Return the `i`-th view type.
  MemRefType getViewType(unsigned i) {
    return (i < nInputs()) ? getInputViewType(i)
                           : getOutputViewType(i - nInputs());
  /// Return the `i`-th buffer type.
  ShapedType getShapedType(unsigned i) {
    return (i < nInputs()) ? getInputShapedType(i)
                           : getOutputShapedType(i - nInputs());
  }
  /// Return the range over input and output views.
  /// Return the range over inputs and outputs.
  Operation::operand_range getInputsAndOutputs() {
    auto range = this->getOperation()->getOperands();
    return {range.begin(), range.begin() + getNumInputsAndOutputs()};
@@ -144,8 +166,8 @@ public:
        cast<ConcreteType>(this->getOperation()).iterator_types());
  }
  static LogicalResult verifyTrait(Operation *op) {
    auto nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
    if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
    auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputs();
    if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
      return failure();
    return success();
  }
+16 −12
Original line number Diff line number Diff line
@@ -84,25 +84,29 @@ class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
  "  return matchFailure();">;

//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
// Linalg to vector patterns precondition and DRR.
//===----------------------------------------------------------------------===//
class VectorizeGenericLinalgOp<string OpType> : NativeCodeCall<
  "if (failed(vectorizeGenericLinalgOp($_builder, op))) " #
  "  return matchFailure();">;
def PreconditionVectorizeGenericLinalgOp : CPred<
  "succeeded(vectorizeGenericLinalgOpPrecondition(op))">;
def VectorizeGenericLinalgOp : NativeCodeCall<
  "vectorizeGenericLinalgOp($_builder, op)">;

//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
// Linalg generic permutation patterns precondition and DRR.
//===----------------------------------------------------------------------===//
class PreconditionPermuteGenericLinalgOp<list<int> permutation> : CPred<
  "succeeded(permuteGenericLinalgOpPrecondition(op, {" #
  StrJoinInt<permutation>.result # "}))">;
class PermuteGenericLinalgOp<list<int> permutation, string value> :
  NativeCodeCall<
    "if (failed(permuteGenericLinalgOp($_builder, op, {" #
    StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
    "  return matchFailure();">;
    "permuteGenericLinalgOp($_builder, op, {" # StrJoinInt<permutation>.result #
    "}, \"" # value # "\")">;

//===----------------------------------------------------------------------===//
// Linalg promote subview operands.
// Linalg promote subview operands precondition and DRR.
//===----------------------------------------------------------------------===//
class PromoteSubviewsLinalgOp<string OpType> : NativeCodeCall<
  "if (failed(promoteSubviewsLinalgOp($_builder, op))) " #
  "  return matchFailure();">;
def PreconditionPromoteSubviewsLinalgOp : CPred<
  "succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
def PromoteSubviewsLinalgOp : NativeCodeCall<
  "promoteSubviewsLinalgOp($_builder, op)">;
#endif // LINALG_TRANSFORMS
Loading