Commit 89e19e8e authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][Linalg] Add tensor support to Linalg EDSC Builders

Summary:
This diff extends the Linalg EDSC builders so we can easily create mixed
tensor/buffer linalg.generic ops. This is expected to be useful for
HLO -> Linalg lowering.

The `StructuredIndexed` struct is made to derive from `ValueHandle` and can
now capture a type + indexing expressions. This is used to represent return
tensors.

Pointwise unary and binary builders are extended to allow both output buffers
and return tensors. This has implications on the number of region arguments.

Reviewers: ftynse, herhut, hanchung, asaadaldien, stellaraccident

Reviewed By: asaadaldien

Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72863
parent e03ead67
Loading
Loading
Loading
Loading
+116 −18
Original line number Diff line number Diff line
@@ -110,11 +110,14 @@ struct StructuredIndexed {

  operator Value() const /* implicit */ { return value; }
  ArrayRef<AffineExpr> getExprs() { return exprs; }
  Type getType() { return value.getType(); }

private:
  StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
      : value(v), exprs(indexings.begin(), indexings.end()) {
    assert(v.getType().isa<MemRefType>() && "MemRefType expected");
    assert((v.getType().isa<MemRefType>() ||
            v.getType().isa<RankedTensorType>()) &&
           "MemRef or RankedTensor expected");
  }
  StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
      : StructuredIndexed(v.getValue(), indexings) {}
@@ -125,9 +128,21 @@ private:

inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}

/// Build a `linalg.generic` op with the specified inputs, outputs and region.
///
/// `otherValues` and `otherAttributes` may be passed and will be appended as
/// operands and attributes respectively.
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
Operation *makeGenericLinalgOp(
    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
    ArrayRef<StructuredIndexed> outputs,
    ArrayRef<StructuredIndexed> outputs, ArrayRef<Type> resultTensorTypes = {},
    function_ref<void(ArrayRef<BlockArgument>)> regionBuilder =
        defaultRegionBuilder,
    ArrayRef<Value> otherValues = {}, ArrayRef<Attribute> otherAttributes = {});
@@ -167,32 +182,77 @@ void macRegionBuilder(ArrayRef<BlockArgument> args);
/// with in-place semantics and parallelism.

/// Unary pointwise operation (with broadcast) entry point.
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
using UnaryPointwiseOpBuilder = function_ref<Value(ValueHandle)>;
Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                            StructuredIndexed I, StructuredIndexed O);
                            StructuredIndexed I, StructuredIndexed O,
                            ArrayRef<Type> resultTensorTypes = {});

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = tanh(I)`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O,
                                 ArrayRef<Type> resultTensorTypes = {});

/// Binary pointwise operation (with broadcast) entry point.
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
using BinaryPointwiseOpBuilder = function_ref<Value(ValueHandle, ValueHandle)>;
Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
                            StructuredIndexed I1, StructuredIndexed I2,
                            StructuredIndexed O);
                            StructuredIndexed O,
                            ArrayRef<Type> resultTensorTypes = {});

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = I1 + I2`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2,
                                StructuredIndexed O);
                                StructuredIndexed O,
                                ArrayRef<Type> resultTensorTypes = {});

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = max(I1, I2)`. The client is responsible for specifying the
/// proper indexings when creating the StructuredIndexed.
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
                                StructuredIndexed O);
                                StructuredIndexed O,
                                ArrayRef<Type> resultTensorTypes = {});

// TODO(ntv): Implement more useful pointwise operations on a per-need basis.

@@ -203,11 +263,23 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2,
///    |
///    |  C(m, n) += A(m, k) * B(k, n)
/// ```
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC);
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC,
                         ArrayRef<Type> resultTensorTypes = {});

template <typename Container> Operation *linalg_matmul(Container values) {
template <typename Container>
Operation *linalg_matmul(Container values,
                         ArrayRef<Type> resultTensorTypes = {}) {
  assert(values.size() == 3 && "Expected exactly 3 values");
  return linalg_matmul(values[0], values[1], values[2]);
  assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor");
  return linalg_matmul(values[0], values[1], values[2], resultTensorTypes);
}

/// Build a linalg.generic, under the current ScopedContext, at the current
@@ -231,16 +303,28 @@ template <typename Container> Operation *linalg_matmul(Container values) {
///
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
//
// TODO(ntv) Extend convolution rank with some template magic.
Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO,
                            ArrayRef<Type> resultTensorTypes = {},
                            ArrayRef<int> strides = {},
                            ArrayRef<int> dilations = {});

template <typename Container>
Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
                            ArrayRef<int> dilations = {}) {
Operation *
linalg_conv_nhwc(Container values, ArrayRef<Type> resultTensorTypes = {},
                 ArrayRef<int> strides = {}, ArrayRef<int> dilations = {}) {
  assert(values.size() == 3 && "Expected exactly 3 values");
  return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations);
  assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor");
  return linalg_conv_nhwc(values[0], values[1], values[2], resultTensorTypes,
                          strides, dilations);
}

/// Build a linalg.generic, under the current ScopedContext, at the current
@@ -249,7 +333,7 @@ Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
///    (batch, dm, c, [h, w, ...], [kh, kw, ...]) =
///    |  (par, par, par, [par, par, ...], [red, red, ...])
///    |
///    | O(batch, [h, w, ...], c * depth_multiplier) +=
///    | O(batch, [h, w, ...], c * depthMultiplier) +=
///    |   I(batch,
///    |     [
///    |       stride[0] * h + dilations[0] * kh,
@@ -257,26 +341,40 @@ Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
///          ],
///    |     c)
///    |   *
///    |   W([kh, kw, ...], c, depth_multiplier)
///    |   W([kh, kw, ...], c, depthMultiplier)
/// ```
/// If `dilations` or `strides` are left empty, the default value of `1` is used
/// along each relevant dimension.
///
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
/// This accepts both buffers and tensors as `inputs` but only buffers as
/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in
/// which case, the canonical identity indexing_map is assumed.
//
// TODO(ntv) In the future we may want to relax this identity assumption (e.g.
// for automatic differentiation purposes). In that case we will want to make
// StructuredIndexed work with ValueHandle to encode type or value.
//
// TODO(ntv) Extend convolution rank with some template magic.
Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW,
                                    ValueHandle vO, int depth_multiplier = 1,
                                    ValueHandle vO,
                                    ArrayRef<Type> resultTensorTypes = {},
                                    int depthMultiplier = 1,
                                    ArrayRef<int> strides = {},
                                    ArrayRef<int> dilations = {});

template <typename Container>
Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier,
Operation *linalg_dilated_conv_nhwc(Container values,
                                    ArrayRef<Type> resultTensorTypes = {},
                                    int depthMultiplier = 1,
                                    ArrayRef<int> strides = {},
                                    ArrayRef<int> dilations = {}) {
  assert(values.size() == 3 && "Expected exactly 3 values");
  assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor");
  return linalg_dilated_conv_nhwc(values[0], values[1], values[2],
                                  depth_multiplier, strides, dilations);
                                  resultTensorTypes, depthMultiplier, strides,
                                  dilations);
}

} // namespace ops
+91 −41
Original line number Diff line number Diff line
@@ -128,16 +128,20 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,

Operation *mlir::edsc::makeGenericLinalgOp(
    ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
    ArrayRef<StructuredIndexed> outputs,
    ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Type> resultTensorTypes,
    function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
    ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
  assert(
      llvm::all_of(llvm::make_range(outputBuffers.begin(), outputBuffers.end()),
                   [](Value v) { return v.getType().isa<MemRefType>(); }) &&
      "output operands must all be buffers.");
  auto &builder = edsc::ScopedContext::getBuilder();
  auto *ctx = builder.getContext();
  unsigned nInputs = inputs.size();
  unsigned nOutputs = outputs.size();
  unsigned nOutputs = outputBuffers.size() + resultTensorTypes.size();
  unsigned maxPos = 0;
  getMaxDimIndex(inputs, maxPos);
  getMaxDimIndex(outputs, maxPos);
  getMaxDimIndex(outputBuffers, maxPos);
  // maxPos is 0 indexed, need to turn this into a count (i.e. +1)
  unsigned nDims = maxPos + 1;

@@ -146,7 +150,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
  for (auto in : inputs)
    maps.push_back(
        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
  for (auto out : outputs)
  for (auto out : outputBuffers)
    maps.push_back(
        AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));

@@ -154,7 +158,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
  SmallVector<Value, 4> values;
  values.reserve(nViews);
  values.append(inputs.begin(), inputs.end());
  values.append(outputs.begin(), outputs.end());
  values.append(outputBuffers.begin(), outputBuffers.end());

  auto iteratorStrTypes = functional::map(toString, iteratorTypes);
  // clang-format off
@@ -162,7 +166,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
      edsc::ScopedContext::getBuilder()
          .create<linalg::GenericOp>(
              edsc::ScopedContext::getLocation(),
              ArrayRef<Type>{}, // TODO(ntv): support tensors
              resultTensorTypes,
              values,
              IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
              IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
@@ -207,7 +211,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {

Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                             StructuredIndexed I,
                                             StructuredIndexed O) {
                                             StructuredIndexed O,
                                             ArrayRef<Type> resultTensorTypes) {
  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
                                           edsc::IterType::Parallel);
  auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
@@ -215,22 +220,30 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
    ValueHandle a(args[0]);
    linalg_yield(unaryOp(a));
  };
  return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);

  // Distinguish between tensor and buffer semantics.
  if (O.getType().isa<MemRefType>()) {
    assert(resultTensorTypes.empty());
    return makeGenericLinalgOp(iterTypes, {I}, {O}, {}, fun);
  }
  return makeGenericLinalgOp(iterTypes, {I, O}, {}, resultTensorTypes, fun);
}

Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I,
                                                  StructuredIndexed O) {
Operation *
mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O,
                                       ArrayRef<Type> resultTensorTypes) {
  ;
  using edsc::intrinsics::tanh;
  UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
  return linalg_pointwise(unOp, I, O);
  return linalg_pointwise(unOp, I, O, resultTensorTypes);
}

/// Binary pointwise operation (with broadcast) entry point.
Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
                                             StructuredIndexed I1,
                                             StructuredIndexed I2,
                                             StructuredIndexed O) {
                                             StructuredIndexed O,
                                             ArrayRef<Type> resultTensorTypes) {
  SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
                                           edsc::IterType::Parallel);
  auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
@@ -238,45 +251,62 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
    ValueHandle a(args[0]), b(args[1]);
    linalg_yield(binaryOp(a, b));
  };
  return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
  // Distinguish between tensor and buffer semantics.
  if (O.getType().isa<MemRefType>()) {
    assert(resultTensorTypes.empty());
    return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, {}, fun);
  }
  return makeGenericLinalgOp(iterTypes, {I1, I2, O}, {}, resultTensorTypes,
                             fun);
}

Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
                                                 StructuredIndexed I2,
                                                 StructuredIndexed O) {
Operation *
mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1,
                                      StructuredIndexed I2, StructuredIndexed O,
                                      ArrayRef<Type> resultTensorTypes) {
  using edsc::op::operator+;
  BinaryPointwiseOpBuilder binOp(
      [](ValueHandle a, ValueHandle b) -> Value { return a + b; });
  return linalg_pointwise(binOp, I1, I2, O);
  return linalg_pointwise(binOp, I1, I2, O, resultTensorTypes);
}

Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
                                                 StructuredIndexed I2,
                                                 StructuredIndexed O) {
Operation *
mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
                                      StructuredIndexed I2, StructuredIndexed O,
                                      ArrayRef<Type> resultTensorTypes) {
  BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
    using edsc::intrinsics::select;
    using edsc::op::operator>;
    return select(a > b, a, b).getValue();
  });
  return linalg_pointwise(binOp, I1, I2, O);
  return linalg_pointwise(binOp, I1, I2, O, resultTensorTypes);
}

Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
                                          ValueHandle vC) {
  // clang-format off
                                          ValueHandle vC,
                                          ArrayRef<Type> resultTensorTypes) {
  AffineExpr m, n, k;
  bindDims(ScopedContext::getContext(), m, n, k);
  StructuredIndexed A(vA), B(vB), C(vC);

  assert(!C.getType().isa<MemRefType>() || resultTensorTypes.empty());
  StructuredIndexed allIndexed[3]{A({m, k}), B({k, n}), C({m, n})};
  ArrayRef<StructuredIndexed> inputs =
      (C.getType().isa<MemRefType>())
          ? ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 2}
          : ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 3};
  ArrayRef<StructuredIndexed> outputs =
      (C.getType().isa<MemRefType>())
          ? ArrayRef<StructuredIndexed>{allIndexed + 2, allIndexed + 3}
          : ArrayRef<StructuredIndexed>{};
  return makeGenericLinalgOp(
    {IterType::Parallel, IterType::Parallel, IterType::Reduction},
    {A({m, k}), B({k, n})},
    {C({m, n})},
    macRegionBuilder);
  // clang-format on
      {IterType::Parallel, IterType::Parallel, IterType::Reduction}, inputs,
      outputs, resultTensorTypes, macRegionBuilder);
}

Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
                                             ValueHandle vO,
                                             ArrayRef<Type> resultTensorTypes,
                                             ArrayRef<int> strides,
                                             ArrayRef<int> dilations) {
  MLIRContext *ctx = ScopedContext::getContext();
@@ -294,23 +324,33 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
  bindDims(ctx, b, f, h, w, kh, kw, c);
  unsigned numDims = c.cast<AffineDimExpr>().getPosition() + 1;
  StructuredIndexed I(vI), W(vW), O(vO);
  // clang-format off
  return makeGenericLinalgOp(
    {par, par, par, par, red, red, red}, {
      I({b,

  assert(!O.getType().isa<MemRefType>() || resultTensorTypes.empty());
  // Roundtrip to flattened form to serve as canonicalization and ensure
  // consistent ordering of subexpressions.
  // clang-format off
  StructuredIndexed allIndexed[3] = {
      I({b,
         simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
         simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
         c}),
      W({kh, kw, c, f})}, {
      O({b, h, w, f})},
    macRegionBuilder);
      W({kh, kw, c, f}),
      O({b, h, w, f})};
  // clang-format on
  auto inputs = (O.getType().isa<MemRefType>())
                    ? ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 2}
                    : ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 3};
  ArrayRef<StructuredIndexed> outputs =
      (O.getType().isa<MemRefType>())
          ? ArrayRef<StructuredIndexed>{allIndexed + 2, allIndexed + 3}
          : ArrayRef<StructuredIndexed>{};
  return makeGenericLinalgOp({par, par, par, par, red, red, red}, inputs,
                             outputs, resultTensorTypes, macRegionBuilder);
}

Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
    ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier,
    ValueHandle vI, ValueHandle vW, ValueHandle vO,
    ArrayRef<Type> resultTensorTypes, int depthMultiplier,
    ArrayRef<int> strides, ArrayRef<int> dilations) {
  MLIRContext *ctx = ScopedContext::getContext();
  // TODO(ntv) some template magic to make everything rank-polymorphic.
@@ -328,16 +368,26 @@ Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
  bindDims(ctx, b, dm, c, h, w, kh, kw);
  unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
  StructuredIndexed I(vI), W(vW), O(vO);
  return makeGenericLinalgOp(
    {par, par, par, par, par, red, red}, {
  // Roundtrip to flattened form to serve as canonicalization and ensure
  // consistent ordering of subexpressions.
  // clang-format off
  StructuredIndexed allIndexed[3] = {
      I({b,
         // Roundtrip to flattened form to serve as canonicalization and ensure
         // consistent ordering of subexpressions.
         simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
         simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
         c}),
      W({kh, kw, c, dm})}, {
      O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
    macRegionBuilder);
      W({kh, kw, c, dm}),
      O({b, h, w, simplifyAffineExpr(c * depthMultiplier + dm, numDims, 0)})};
  // clang-format on
  auto inputs = (O.getType().isa<MemRefType>())
                    ? ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 2}
                    : ArrayRef<StructuredIndexed>{allIndexed, allIndexed + 3};
  ArrayRef<StructuredIndexed> outputs =
      (O.getType().isa<MemRefType>())
          ? ArrayRef<StructuredIndexed>{allIndexed + 2, allIndexed + 3}
          : ArrayRef<StructuredIndexed>{};
  return makeGenericLinalgOp({par, par, par, par, par, red, red}, inputs,
                             outputs, resultTensorTypes, macRegionBuilder);
}
+39 −5

File changed.

Preview size limit exceeded, changes collapsed.