Commit 5bb8d28e authored by Nicolas Vasilache
Browse files

[mlir][Linalg] Add tensor support to Linalg EDSC Builders

Summary:
This diff extends the Linalg EDSC builders so we can easily create mixed
tensor/buffer linalg.generic ops. This is expected to be useful for
HLO -> Linalg lowering.

The StructuredIndexed struct is made to derive from ValueHandle and can
now capture a type + indexing expressions. This is used to represent return
tensors.

Pointwise unary and binary builders are extended to allow both output buffers
and return tensors. This has implications on the number of region arguments.

Reviewers: ftynse, hanchung, asaadaldien

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73149
parent f55b033c
Loading
Loading
Loading
Loading
+40 −14
Original line number Diff line number Diff line
@@ -94,37 +94,63 @@ inline StringRef toString(IterType t) {
  llvm_unreachable("Unsupported IterType");
}

/// A StructuredIndexed represents a captured value that can be indexed and
/// passed to the `makeGenericLinalgOp`. It allows writing intuitive index
/// expressions such as:
/// A StructuredIndexed represents an indexable quantity that is either:
/// 1. a captured value, which is suitable for buffer and tensor operands; or
/// 2. a captured type, which is suitable for tensor return values.
///
/// A StructuredIndexed itself is indexed and passed to `makeGenericLinalgOp`.
/// It enables an idiomatic syntax for index expressions such as:
///
/// ```
///      StructuredIndexed A(vA), B(vB), C(vC);
///      StructuredIndexed A(buffer_or_tensor_value), B(buffer_or_tensor_value),
///        C(buffer_value_or_tensor_type);
///      makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
/// ```
struct StructuredIndexed {
  StructuredIndexed(Value v) : value(v) {}
struct StructuredIndexed : public ValueHandle {
  StructuredIndexed(Type type) : ValueHandle(type) {}
  StructuredIndexed(Value value) : ValueHandle(value) {}
  StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {}
  StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
    return StructuredIndexed(value, indexings);
    return StructuredIndexed(*this, indexings);
  }

  operator Value() const /* implicit */ { return value; }
  ArrayRef<AffineExpr> getExprs() { return exprs; }

private:
  StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
      : ValueHandle(t), exprs(indexings.begin(), indexings.end()) {
    assert(t.isa<RankedTensorType>() && "RankedTensor expected");
  }
  StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
      : value(v), exprs(indexings.begin(), indexings.end()) {
    assert(v.getType().isa<MemRefType>() && "MemRefType expected");
      : ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
    assert((v.getType().isa<MemRefType>() ||
            v.getType().isa<RankedTensorType>()) &&
           "MemRef or RankedTensor expected");
  }
  StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
      : StructuredIndexed(v.getValue(), indexings) {}
  StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
      : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}

  Value value;
  SmallVector<AffineExpr, 4> exprs;
};

/// Region builder that does nothing: it ignores `args` and emits no
/// operations.
inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}

/// Build a `linalg.generic` op with the specified `inputs`, `outputs` and
/// `region`.
///
/// `otherValues` and `otherAttributes` may be passed and will be appended as
/// operands and attributes respectively.
///
/// Prerequisites:
/// =============
///
/// 1. `inputs` may contain StructuredIndexed that capture either buffer or
/// tensor values.
/// 2. `outputs` may contain StructuredIndexed that capture either buffer values
/// or tensor types. If both buffer values and tensor types are present, then
/// all buffer values must appear before any tensor type. Without this
/// restriction output tensor results would need to be reordered, which would
/// result in surprising behavior when combined with region definition.
Operation *makeGenericLinalgOp(
    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
    ArrayRef<StructuredIndexed> outputs,
@@ -189,7 +215,7 @@ Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
                                StructuredIndexed O);

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = max(I!, I2)`. The client is responsible for specifying the
/// computes `O = max(I1, I2)`. The client is responsible for specifying the
/// proper indexings when creating the StructuredIndexed.
Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
                                StructuredIndexed O);
+1 −0
Original line number Diff line number Diff line
@@ -339,6 +339,7 @@ public:

  /// Implicit conversion useful for automatic conversion to Container<Value>.
  operator Value() const { return getValue(); }
  /// Implicit conversion to Type; forwards to getType().
  operator Type() const { return getType(); }
  /// Implicit conversion to bool; forwards to hasValue().
  operator bool() const { return hasValue(); }

  /// Generic mlir::Op create. This is the key to being extensible to the whole
+26 −3
Original line number Diff line number Diff line
@@ -131,6 +131,10 @@ Operation *mlir::edsc::makeGenericLinalgOp(
    ArrayRef<StructuredIndexed> outputs,
    function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
    ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
  for (unsigned i = 0, e = outputs.size(); i + 1 < e; ++i)
    assert(!(outputs[i].getType().isa<RankedTensorType>() &&
             outputs[i + 1].getType().isa<MemRefType>()) &&
           "output tensors must be passed after output buffers");
  auto &builder = edsc::ScopedContext::getBuilder();
  auto *ctx = builder.getContext();
  unsigned nInputs = inputs.size();
@@ -154,7 +158,11 @@ Operation *mlir::edsc::makeGenericLinalgOp(
  SmallVector<Value, 4> values;
  values.reserve(nViews);
  values.append(inputs.begin(), inputs.end());
  values.append(outputs.begin(), outputs.end());
  std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(values),
               [](StructuredIndexed s) { return s.hasValue(); });
  SmallVector<Type, 4> types;
  std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(types),
               [](StructuredIndexed s) { return !s.hasValue(); });

  auto iteratorStrTypes = functional::map(toString, iteratorTypes);
  // clang-format off
@@ -162,7 +170,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
      edsc::ScopedContext::getBuilder()
          .create<linalg::GenericOp>(
              edsc::ScopedContext::getLocation(),
              ArrayRef<Type>{}, // TODO(ntv): support tensors
              types,
              values,
              IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
              IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
@@ -210,6 +218,14 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                             StructuredIndexed O) {
  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
                                           edsc::IterType::Parallel);
  if (O.getType().isa<RankedTensorType>()) {
    auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
      assert(args.size() == 1 && "expected 1 block arguments");
      ValueHandle a(args[0]);
      linalg_yield(unaryOp(a));
    };
    return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
  }
  auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
    assert(args.size() == 2 && "expected 2 block arguments");
    ValueHandle a(args[0]);
@@ -220,7 +236,6 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,

Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
                                                  StructuredIndexed O) {
  ;
  using edsc::intrinsics::tanh;
  UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
  return linalg_pointwise(unOp, I, O);
@@ -233,6 +248,14 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
                                             StructuredIndexed O) {
  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
                                           edsc::IterType::Parallel);
  if (O.getType().isa<RankedTensorType>()) {
    auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
      assert(args.size() == 2 && "expected 2 block arguments");
      ValueHandle a(args[0]), b(args[1]);
      linalg_yield(binaryOp(a, b));
    };
    return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
  }
  auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
    assert(args.size() == 3 && "expected 3 block arguments");
    ValueHandle a(args[0]), b(args[1]);
+43 −0
Original line number Diff line number Diff line
@@ -871,6 +871,49 @@ TEST_FUNC(linalg_pointwise_test) {
  f.erase();
}

// clang-format off
// CHECK-LABEL: func @linalg_pointwise_mixed_tensors
//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
//       CHECK:       addf
//       CHECK:     }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
//       CHECK:       cmpf "ogt"
//       CHECK:       select
//       CHECK:   }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
//       CHECK:   linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
// CHECK-SAME:      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME:      iterator_types = ["parallel", "parallel"]}
//       CHECK:     tanh
//       CHECK:   }: tensor<?x?xf32> -> tensor<?x?xf32>
// clang-format on
// Exercises the pointwise Linalg builders with mixed operands: a tensor value
// and a memref value as inputs, and a tensor *type* (not a value) as output.
// The function is printed to stdout so the FileCheck directives that precede
// this test can match the three emitted linalg.generic ops (add, max, tanh).
TEST_FUNC(linalg_pointwise_mixed_tensors_test) {
  using namespace edsc;
  using namespace edsc::ops;

  // 2-D f32 memref and tensor types; the -1 sizes print as `?` in the expected
  // output (e.g. tensor<?x?xf32>).
  auto f32Type = FloatType::getF32(&globalContext());
  auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
  auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
  // NOTE(review): the empty list is presumably the result types and the second
  // list the argument types (arguments 0 and 1 are read below) - confirm
  // against makeFunction.
  auto f = makeFunction("linalg_pointwise_mixed_tensors", {},
                        {tensorType, memrefType});

  // The scoped context must be live while the builders below run.
  // NOTE(review): assumes ops are inserted into f's body - confirm.
  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());
  // A is the tensor argument, B is the memref argument.
  ValueHandle A(f.getArgument(0)), B(f.getArgument(1));
  AffineExpr i, j;
  bindDims(&globalContext(), i, j);
  // SA and SB capture values; SC captures only a type and therefore stands for
  // a returned tensor rather than an output buffer.
  StructuredIndexed SA(A), SB(B), SC(tensorType);
  // Statement order matters: each call emits one op and the expected output
  // lists them as add, then max, then tanh.
  linalg_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
  linalg_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
  linalg_pointwise_tanh(SA({i, j}), SC({i, j}));

  // Print the IR for verification, then erase the function.
  f.print(llvm::outs());
  f.erase();
}

// clang-format off
// CHECK-LABEL: func @linalg_matmul
//       CHECK:   linalg.generic {args_in = 2 : i64, args_out = 1 : i64,