Commit fd226c9b authored by Stella Laurenzo's avatar Stella Laurenzo
Browse files

[mlir][Python] Roll up of python API fixes.

* As discussed, fixes the ordering or (operands, results) -> (results, operands) in various `create` like methods.
* Fixes a syntax error in an ODS accessor method.
* Removes the linalg example in favor of a test case that exercises the same.
* Fixes FuncOp visibility to properly use None instead of the empty string and defaults it to None.
* Implements what was documented for requiring that trailing __init__ args `loc` and `ip` are keyword only.
* Adds a check to `InsertionPoint.insert` so that if attempting to insert past the terminator, an exception is raised telling you what to do instead. Previously, this would crash downstream (i.e. when trying to print the resultant module).
* Renames `_ods_build_default` -> `build_generic` and documents it.
* Removes `result` from the list of prohibited words and for single-result ops, defaults to naming the result `result`, thereby matching expectations and what is already implemented on the base class.
* This was intended to be a relatively small set of changes to be inlined with the broader support for ODS generating the most specific builder, but it spidered out once actually testing various combinations, so rolling up separately.

Differential Revision: https://reviews.llvm.org/D95320
parent 78d41a12
Loading
Loading
Loading
Loading
+11 −6
Original line number Diff line number Diff line
@@ -439,8 +439,9 @@ defaults on `OpView`):
#### Builders

Presently, only a single, default builder is mapped to the `__init__` method.
Generalizing this facility is under active development. It currently accepts
arguments:
The intent is that this `__init__` method represents the *most specific* of
the builders typically generated for C++; however currently it is just the
generic form below.

* One argument for each declared result:
  * For single-valued results: Each will accept an `mlir.ir.Type`.
@@ -453,7 +454,11 @@ arguments:
  * `loc`: An explicit `mlir.ir.Location` to use. Defaults to the location
    bound to the thread (i.e. `with Location.unknown():`) or an error if none
    is bound nor specified.
  * `context`: An explicit `mlir.ir.Context` to use. Default to the context
    bound to the thread (i.e. `with Context():` or implicitly via `Location` or
    `InsertionPoint` context managers) or an error if none is bound nor
    specified.
  * `ip`: An explicit `mlir.ir.InsertionPoint` to use. Default to the insertion
    point bound to the thread (i.e. `with InsertionPoint(...):`).

In addition, each `OpView` inherits a `build_generic` method which allows
construction via a (nested in the case of variadic) sequence of `results` and
`operands`. This can be used to get some default construction semantics for
operations that are otherwise unsupported in Python, at the expense of having
a very generic signature.

mlir/examples/python/.style.yapf

deleted100644 → 0
+0 −4
Original line number Diff line number Diff line
[style]
  based_on_style = google
  column_limit = 80
  indent_width = 2
+0 −81
Original line number Diff line number Diff line
#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This is a work in progress example to do end2end build and code generation
# of a small linalg program with configuration options. It is currently non
# functional and is being used to elaborate the APIs.

from typing import Tuple

from mlir.ir import *
from mlir.dialects import linalg
from mlir.dialects import std


# TODO: This should be in the core API.
def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
  """Creates a |func| op.
    TODO: This should really be in the MLIR API.
    Returns:
      (operation, entry_block)
    """
  attrs = {
      "type": TypeAttr.get(func_type),
      "sym_name": StringAttr.get(name),
  }
  op = Operation.create("func", regions=1, attributes=attrs)
  body_region = op.regions[0]
  entry_block = body_region.blocks.append(*func_type.inputs)
  return op, entry_block


def build_matmul_buffers_func(func_name, m, k, n, dtype):
  lhs_type = MemRefType.get([m, k], dtype)
  rhs_type = MemRefType.get([k, n], dtype)
  result_type = MemRefType.get([m, n], dtype)
  # TODO: There should be a one-liner for this.
  func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
  _, entry = FuncOp(func_name, func_type)
  lhs, rhs, result = entry.arguments
  with InsertionPoint(entry):
    op = linalg.MatmulOp([lhs, rhs], [result])
    # TODO: Implement support for SingleBlockImplicitTerminator
    block = op.regions[0].blocks.append()
    with InsertionPoint(block):
        linalg.YieldOp(values=[])

    std.ReturnOp([])


def build_matmul_tensors_func(func_name, m, k, n, dtype):
  lhs_type = RankedTensorType.get([m, k], dtype)
  rhs_type = RankedTensorType.get([k, n], dtype)
  result_type = RankedTensorType.get([m, n], dtype)
  # TODO: There should be a one-liner for this.
  func_type = FunctionType.get([lhs_type, rhs_type], [result_type])
  _, entry = FuncOp(func_name, func_type)
  lhs, rhs = entry.arguments
  with InsertionPoint(entry):
    op = linalg.MatmulOp([lhs, rhs], results=[result_type])
    # TODO: Implement support for SingleBlockImplicitTerminator
    block = op.regions[0].blocks.append()
    with InsertionPoint(block):
        linalg.YieldOp(values=[])
    std.ReturnOp([op.result])


def run():
  with Context() as c, Location.unknown():
    module = Module.create()
    # TODO: This at_block_terminator vs default construct distinction feels
    # wrong and is error-prone.
    with InsertionPoint.at_block_terminator(module.body):
      build_matmul_buffers_func('main_buffers', 18, 32, 96, F32Type.get())
      build_matmul_tensors_func('main_tensors', 18, 32, 96, F32Type.get())

    print(module)


if __name__ == '__main__':
  run()
+26 −16
Original line number Diff line number Diff line
@@ -891,8 +891,8 @@ PyBlock PyOperation::getBlock() {
}

py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<std::vector<PyType *>> results,
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
@@ -1039,8 +1039,8 @@ py::object PyOperation::createOpView() {
//------------------------------------------------------------------------------

py::object
PyOpView::odsBuildDefault(py::object cls, py::list operandList,
                          py::list resultTypeList,
PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
                       py::list operandList,
                       llvm::Optional<py::dict> attributes,
                       llvm::Optional<std::vector<PyBlock *>> successors,
                       llvm::Optional<int> regions,
@@ -1288,8 +1288,9 @@ PyOpView::odsBuildDefault(py::object cls, py::list operandList,
  }

  // Delegate to create.
  return PyOperation::create(std::move(name), /*operands=*/std::move(operands),
  return PyOperation::create(std::move(name),
                             /*results=*/std::move(resultTypes),
                             /*operands=*/std::move(operands),
                             /*attributes=*/std::move(attributes),
                             /*successors=*/std::move(successors),
                             /*regions=*/*regions, location, maybeIp);
@@ -1357,6 +1358,16 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) {
    // Insert before operation.
    (*refOperation)->checkValid();
    beforeOp = (*refOperation)->get();
  } else {
    // Insert at end (before null) is only valid if the block does not
    // already end in a known terminator (violating this will cause assertion
    // failures later).
    if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
      throw py::index_error("Cannot insert operation at the end of a block "
                            "that already has a terminator. Did you mean to "
                            "use 'InsertionPoint.at_block_terminator(block)' "
                            "versus 'InsertionPoint(block)'?");
    }
  }
  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
  operation.setAttached();
@@ -3646,8 +3657,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {

  py::class_<PyOperation, PyOperationBase>(m, "Operation")
      .def_static("create", &PyOperation::create, py::arg("name"),
                  py::arg("operands") = py::none(),
                  py::arg("results") = py::none(),
                  py::arg("operands") = py::none(),
                  py::arg("attributes") = py::none(),
                  py::arg("successors") = py::none(), py::arg("regions") = 0,
                  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
@@ -3681,12 +3692,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
  opViewClass.attr("_ods_build_default") = classmethod(
      &PyOpView::odsBuildDefault, py::arg("cls"),
      py::arg("operands") = py::none(), py::arg("results") = py::none(),
      py::arg("attributes") = py::none(), py::arg("successors") = py::none(),
      py::arg("regions") = py::none(), py::arg("loc") = py::none(),
      py::arg("ip") = py::none(),
  opViewClass.attr("build_generic") = classmethod(
      &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
      py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
      py::arg("successors") = py::none(), py::arg("regions") = py::none(),
      py::arg("loc") = py::none(), py::arg("ip") = py::none(),
      "Builds a specific, generated OpView based on class level attributes.");

  //----------------------------------------------------------------------------
+8 −8
Original line number Diff line number Diff line
@@ -455,8 +455,8 @@ public:

  /// Creates an operation. See corresponding python docstring.
  static pybind11::object
  create(std::string name, llvm::Optional<std::vector<PyValue *>> operands,
         llvm::Optional<std::vector<PyType *>> results,
  create(std::string name, llvm::Optional<std::vector<PyType *>> results,
         llvm::Optional<std::vector<PyValue *>> operands,
         llvm::Optional<pybind11::dict> attributes,
         llvm::Optional<std::vector<PyBlock *>> successors, int regions,
         DefaultingPyLocation location, pybind11::object ip);
@@ -498,8 +498,8 @@ public:
  pybind11::object getOperationObject() { return operationObject; }

  static pybind11::object
  odsBuildDefault(pybind11::object cls, pybind11::list operandList,
                  pybind11::list resultTypeList,
  buildGeneric(pybind11::object cls, pybind11::list resultTypeList,
               pybind11::list operandList,
               llvm::Optional<pybind11::dict> attributes,
               llvm::Optional<std::vector<PyBlock *>> successors,
               llvm::Optional<int> regions, DefaultingPyLocation location,
Loading