Commit bbf4436a authored by Tobias Gysi's avatar Tobias Gysi
Browse files

[mlir][linalg] Remove the StructuredOp capture mechanism.

After https://reviews.llvm.org/D104109, structured ops support scalar inputs. As a result, the capture mechanism meant to pass non-shaped parameters got redundant. The patch removes the capture semantics after the FillOp migrated to use scalar operands https://reviews.llvm.org/D104121.

Differential Revision: https://reviews.llvm.org/D104785
parent a1c0f09a
Loading
Loading
Loading
Loading
+1 −3
Original line number Diff line number Diff line
@@ -18,11 +18,9 @@ extern "C" {
#endif

/// Apply the special region builder for the builtin named Linalg op.
/// The list of `capture` MlirValue is passed as-is to the region builder.
/// Assert that `op` is a builtin named Linalg op.
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op,
                                   intptr_t n, MlirValue const *mlirCaptures);
mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

+1 −1
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ def Linalg_Dialect : Dialect {
      kInplaceableAttrName = "linalg.inplaceable";

    using RegionBuilderFunType =
      llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &, ValueRange)>;
      llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &)>;
    RegionBuilderFunType getRegionBuilder(StringRef name) {
      return namedStructuredOpRegionBuilders.lookup(name);
    }
+1 −1
Original line number Diff line number Diff line
@@ -901,7 +901,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        Returns a null function if this named op does not define a region
        builder.
      }],
      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ValueRange)>",
      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &)>",
      /*methodName=*/"getRegionBuilder",
      (ins),
      [{ return ConcreteOp::getRegionBuilder(); }]
+6 −12
Original line number Diff line number Diff line
@@ -153,10 +153,8 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
    Value getSource() { return input();}
    Value getTarget() { return output(); }

    static void regionBuilder(
      ImplicitLocOpBuilder &b, Block &block, ValueRange captures);
    static std::function<
      void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
    getRegionBuilder() {
      return &regionBuilder;
    }
@@ -200,10 +198,8 @@ def FillOp : LinalgStructured_Op<"fill", []> {
          extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
    }

    static void regionBuilder(
      ImplicitLocOpBuilder &b, Block &block, ValueRange captures);
    static std::function<
      void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
    getRegionBuilder() {
      return &regionBuilder;
    }
@@ -291,8 +287,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
      return padding().getValue().getValue<int64_t>({i, 1});
    }

    static std::function<
      void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
    getRegionBuilder() {
      return nullptr;
    }
@@ -533,8 +528,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
        library_call()->str() : "op_has_no_registered_library_name";
    }

    static std::function<
      void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)>
    static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
    getRegionBuilder() {
      return nullptr;
    }
+3 −8
Original line number Diff line number Diff line
@@ -21,15 +21,10 @@ using namespace mlir::python;
void mlir::python::populateDialectLinalgSubmodule(py::module m) {
  m.def(
      "fill_builtin_region",
      [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) {
        llvm::SmallVector<MlirValue, 4> mlirOperands;
        mlirOperands.reserve(captures.size());
        for (auto v : captures)
          mlirOperands.push_back(py::cast<PyValue *>(v)->get());
        mlirLinalgFillBuiltinNamedOpRegion(
            dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data());
      [](PyDialectDescriptor &dialect, PyOperation &op) {
        mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
      },
      py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(),
      py::arg("dialect"), py::arg("op"),
      "Fill the region for `op`, which is assumed to be a builtin named Linalg "
      "op.");
}
Loading