Commit 47ec8702 authored by Nicolas Vasilache's avatar Nicolas Vasilache
Browse files

[mlir][Linalg] Revisit 0-D abstraction

This revision takes advantage of the empty AffineMap to specify the
0-D edge case. This allows removing a bunch of annoying corner cases
that ended up impacting users of Linalg.

Differential Revision: https://reviews.llvm.org/D75831
parent 4a0267e3
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -91,7 +91,8 @@ affine-expr ::= `(` affine-expr `)`
              | bare-id
              | `-`? integer-literal

multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
multi-dim-affine-expr ::= `(` `)`
                        | `(` affine-expr (`,` affine-expr)* `)`
```

`ceildiv` is the ceiling function which maps the result of the division of its
+3 −1
Original line number Diff line number Diff line
@@ -184,7 +184,9 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
      MLIRContext *context = getContext();
      auto r_i = getAffineDimExpr(0, context);
      return SmallVector<AffineMap, 8>{
        AffineMap::get(1, 0, {r_i}), AffineMap::get(1, 0, {r_i}), AffineMap()};
        AffineMap::get(1, 0, {r_i}),
        AffineMap::get(1, 0, {r_i}),
        AffineMap::get(1, 0, context)};
    }
  }];

+6 −2
Original line number Diff line number Diff line
@@ -44,6 +44,11 @@ public:
  /// Returns a zero result affine map with no dimensions or symbols: () -> ().
  static AffineMap get(MLIRContext *context);

  /// Returns a zero result affine map with `dimCount` dimensions and
  /// `symbolCount` symbols, e.g.: `(...) -> ()`.
  static AffineMap get(unsigned dimCount, unsigned symbolCount,
                       MLIRContext *context);

  static AffineMap get(unsigned dimCount, unsigned symbolCount,
                       ArrayRef<AffineExpr> results);

@@ -275,8 +280,7 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
namespace llvm {

// AffineExpr hash just like pointers
template <>
struct DenseMapInfo<mlir::AffineMap> {
template <> struct DenseMapInfo<mlir::AffineMap> {
  static mlir::AffineMap getEmptyKey() {
    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
    return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
+2 −8
Original line number Diff line number Diff line
@@ -356,16 +356,10 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
             << idx << " to have " << nLoops
             << " dim(s) to match the number of loops";

    if (m.getNumResults() == 1 && view.getRank() == 0) {
      auto cst = m.getResult(0).template dyn_cast<AffineConstantExpr>();
      if (!cst || cst.getValue() != 0)
        return op.emitOpError("expected indexing_map #")
               << idx << " to be 0 to match 0-D view: " << view;
    } else if (m.getNumResults() != view.getRank()) {
    if (m.getNumResults() != view.getRank())
      return op.emitOpError("expected indexing_map #")
             << idx << " results to match view rank: " << view;
  }
  }

  auto concatMap = concatAffineMaps(indexingMaps);
  auto aggregateMap = inversePermutation(concatMap);
@@ -886,7 +880,7 @@ AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
  if (maybeMap)
    return maybeMap.getValue();
  if (rank == 0)
    return AffineMap();
    return AffineMap::get(context);
  return AffineMap::getMultiDimIdentityMap(rank, context);
}

+15 −31
Original line number Diff line number Diff line
@@ -37,6 +37,8 @@ using edsc::op::operator==;
static SmallVector<ValueHandle, 8>
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
                           ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};
  assert(map.getNumSymbols() == 0);
  assert(map.getNumInputs() == vals.size());
  SmallVector<ValueHandle, 8> res;
@@ -241,26 +243,17 @@ public:

    // 1.a. Emit std_load from input views.
    for (unsigned i = 0; i < nInputs; ++i) {
      Value input = genericOp.getInput(i);
      if (input.getType().cast<ShapedType>().getRank()) {
      ValueHandleArray indexing(makeCanonicalAffineApplies(
          b, loc, genericOp.getInputIndexingMap(i), allIvs));
        indexedValues[i] = std_load(input, indexing);
      } else {
        indexedValues[i] = std_load(input);
      }
      indexedValues[i] = std_load(genericOp.getInput(i), indexing);
    }

    // 1.b. Emit std_load from output views.
    for (unsigned i = 0; i < nOutputs; ++i) {
      Value output = genericOp.getOutputBuffer(i);
      if (output.getType().cast<ShapedType>().getRank()) {
      ValueHandleArray indexing(makeCanonicalAffineApplies(
          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
      indexedValues[nInputs + i] = std_load(output, indexing);
      } else {
        indexedValues[nInputs + i] = std_load(output);
      }
    }

    auto funcOp = genericOp.getFunction();
@@ -272,13 +265,9 @@ public:
      // 3. Emit std_store.
      for (unsigned i = 0; i < nOutputs; ++i) {
        Value output = genericOp.getOutputBuffer(i);
        if (output.getType().cast<ShapedType>().getRank()) {
        ValueHandleArray indexing(makeCanonicalAffineApplies(
            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
        std_store(callOp->getResult(i), output, indexing);
        } else {
          std_store(callOp->getResult(i), output);
        }
      }
      return;
    }
@@ -297,15 +286,10 @@ public:
    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
    assert(yieldOp->getNumOperands() == nOutputs);
    for (unsigned i = 0; i < nOutputs; ++i) {
      Value output = genericOp.getOutputBuffer(i);
      if (output.getType().cast<ShapedType>().getRank()) {
      ValueHandleArray indexing(makeCanonicalAffineApplies(
          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
      std_store(map.lookup(yieldOp->getOperand(i)),
                genericOp.getOutputBuffer(i), indexing);
      } else {
        std_store(map.lookup(yieldOp->getOperand(i)), output);
      }
    }
  }
};
Loading