Commit ed9e52f3 authored by Alex Zinenko's avatar Alex Zinenko
Browse files

[mlir][python] Usability improvements for Python bindings

Provide a couple of quality-of-life usability improvements for Python bindings,
in particular:

  * give access to the list of types for the list of op results or block
    arguments, similarly to ValueRange->TypeRange,

  * allow for constructing empty dictionary arrays,

  * support construction of array attributes by concatenating an existing
    attribute with a Python list of attributes.

All these are required for the upcoming customization of builtin and standard
ops.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110946
parent c7bd6435
Loading
Loading
Loading
Loading
+37 −19
Original line number Diff line number Diff line
@@ -18,7 +18,6 @@ using namespace mlir;
using namespace mlir::python;

using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;

namespace {
@@ -44,6 +43,24 @@ public:
  }
};

template <typename T>
static T pyTryCast(py::handle object) {
  try {
    return object.cast<T>();
  } catch (py::cast_error &err) {
    std::string msg =
        std::string(
            "Invalid attribute when attempting to create an ArrayAttribute (") +
        err.what() + ")";
    throw py::cast_error(msg);
  } catch (py::reference_cast_error &err) {
    std::string msg = std::string("Invalid attribute (None?) when attempting "
                                  "to create an ArrayAttribute (") +
                      err.what() + ")";
    throw py::cast_error(msg);
  }
}

class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -76,6 +93,10 @@ public:
    int nextIndex = 0;
  };

  PyAttribute getItem(intptr_t i) {
    return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
  }

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
@@ -83,21 +104,7 @@ public:
          SmallVector<MlirAttribute> mlirAttributes;
          mlirAttributes.reserve(py::len(attributes));
          for (auto attribute : attributes) {
            try {
              mlirAttributes.push_back(attribute.cast<PyAttribute>());
            } catch (py::cast_error &err) {
              std::string msg = std::string("Invalid attribute when attempting "
                                            "to create an ArrayAttribute (") +
                                err.what() + ")";
              throw py::cast_error(msg);
            } catch (py::reference_cast_error &err) {
              // This exception seems thrown when the value is "None".
              std::string msg =
                  std::string("Invalid attribute (None?) when attempting to "
                              "create an ArrayAttribute (") +
                  err.what() + ")";
              throw py::cast_error(msg);
            }
            mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
          }
          MlirAttribute attr = mlirArrayAttrGet(
              context->get(), mlirAttributes.size(), mlirAttributes.data());
@@ -109,8 +116,7 @@ public:
          [](PyArrayAttribute &arr, intptr_t i) {
            if (i >= mlirArrayAttrGetNumElements(arr))
              throw py::index_error("ArrayAttribute index out of range");
            return PyAttribute(arr.getContext(),
                               mlirArrayAttrGetElement(arr, i));
            return arr.getItem(i);
          })
        .def("__len__",
             [](const PyArrayAttribute &arr) {
@@ -119,6 +125,18 @@ public:
        .def("__iter__", [](const PyArrayAttribute &arr) {
          return PyArrayAttributeIterator(arr);
        });
    c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
      std::vector<MlirAttribute> attributes;
      intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
      attributes.reserve(numOldElements + py::len(extras));
      for (intptr_t i = 0; i < numOldElements; ++i)
        attributes.push_back(arr.getItem(i));
      for (py::handle attr : extras)
        attributes.push_back(pyTryCast<PyAttribute>(attr));
      MlirAttribute arrayAttr = mlirArrayAttrGet(
          arr.getContext()->get(), attributes.size(), attributes.data());
      return PyArrayAttribute(arr.getContext(), arrayAttr);
    });
  }
};

@@ -602,7 +620,7 @@ public:
                                    mlirNamedAttributes.data());
          return PyDictAttribute(context->getRef(), attr);
        },
        py::arg("value"), py::arg("context") = py::none(),
        py::arg("value") = py::dict(), py::arg("context") = py::none(),
        "Gets an uniqued dict attribute");
    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
      MlirAttribute attr =
+25 −0
Original line number Diff line number Diff line
@@ -1590,6 +1590,19 @@ public:
  }
};

/// Returns the list of types of the values held by container.
template <typename Container>
static std::vector<PyType> getValueTypes(Container &container,
                                         PyMlirContextRef &context) {
  std::vector<PyType> result;
  result.reserve(container.getNumElements());
  for (int i = 0, e = container.getNumElements(); i < e; ++i) {
    result.push_back(
        PyType(context, mlirValueGetType(container.getElement(i).get())));
  }
  return result;
}

/// A list of block arguments. Internally, these are stored as consecutive
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
@@ -1625,6 +1638,12 @@ public:
    return PyBlockArgumentList(operation, block, startIndex, length, step);
  }

  static void bindDerived(ClassTy &c) {
    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
      return getValueTypes(self, self.operation->getContext());
    });
  }

private:
  PyOperationRef operation;
  MlirBlock block;
@@ -1712,6 +1731,12 @@ public:
    return PyOpResultList(operation, startIndex, length, step);
  }

  static void bindDerived(ClassTy &c) {
    c.def_property_readonly("types", [](PyOpResultList &self) {
      return getValueTypes(self, self.operation->getContext());
    });
  }

private:
  PyOperationRef operation;
};
+9 −0
Original line number Diff line number Diff line
@@ -343,6 +343,9 @@ def testDictAttr():
    else:
      assert False, "expected IndexError on accessing an out-of-bounds attribute"

    # CHECK "empty: {}"
    print("empty: ", DictAttr.get())


# CHECK-LABEL: TEST: testTypeAttr
@run
@@ -404,3 +407,9 @@ def testArrayAttr():
    except RuntimeError as e:
      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
      print("Error: ", e)

  with Context():
    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
    array = array + [StringAttr.get("c")]
    # CHECK: concat: ["a", "b", "c"]
    print("concat: ", array)
+12 −0
Original line number Diff line number Diff line
@@ -145,6 +145,12 @@ def testBlockArgumentList():
    print("Length: ",
          len(entry_block.arguments[:2] + entry_block.arguments[1:]))

    # CHECK: Type: i8
    # CHECK: Type: i16
    # CHECK: Type: i24
    for t in entry_block.arguments.types:
      print("Type: ", t)


run(testBlockArgumentList)

@@ -380,6 +386,12 @@ def testOperationResultList():
  for res in call.results:
    print(f"Result {res.result_number}, type {res.type}")

  # CHECK: Result type i32
  # CHECK: Result type f64
  # CHECK: Result type index
  for t in call.results.types:
    print(f"Result type {t}")


run(testOperationResultList)