Loading mlir/include/mlir-c/Bindings/Python/Interop.h +21 −0 Original line number Diff line number Diff line Loading @@ -26,12 +26,14 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Pass.h" #define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr" #define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr" #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr" #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr" #define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr" #define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr" #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr" #define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr" Loading Loading @@ -240,6 +242,25 @@ static inline MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule) { return affineMap; } /** Creates a capsule object encapsulating the raw C-API MlirIntegerSet. * The returned capsule does not extend or affect ownership of any Python * objects that reference the set in any way. */ static inline PyObject * mlirPythonIntegerSetToCapsule(MlirIntegerSet integerSet) { return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(integerSet), MLIR_PYTHON_CAPSULE_INTEGER_SET, NULL); } /** Extracts an MlirIntegerSet from a capsule as produced from * mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then * a null set is returned (as checked via mlirIntegerSetIsNull). In such a * case, the Python APIs will have already set an error. */ static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) { void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_INTEGER_SET); MlirIntegerSet integerSet = {ptr}; return integerSet; } #ifdef __cplusplus } #endif Loading mlir/lib/Bindings/Python/IRModules.cpp +220 −18 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Registration.h" #include "llvm/ADT/SmallVector.h" #include <pybind11/stl.h> Loading Loading @@ -3331,6 +3332,102 @@ PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { rawAffineMap); } //------------------------------------------------------------------------------ // PyIntegerSet and utilities. //------------------------------------------------------------------------------ class PyIntegerSetConstraint { public: PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} PyAffineExpr getExpr() { return PyAffineExpr(set.getContext(), mlirIntegerSetGetConstraint(set, pos)); } bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } static void bind(py::module &m) { py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint") .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); } private: PyIntegerSet set; intptr_t pos; }; class PyIntegerSetConstraintList : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { public: static constexpr const char *pyClassName = "IntegerSetConstraintList"; PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) : Sliceable(startIndex, length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, step), set(set) {} intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } PyIntegerSetConstraint getElement(intptr_t pos) { return PyIntegerSetConstraint(set, pos); } PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, intptr_t step) { return PyIntegerSetConstraintList(set, startIndex, length, step); } private: PyIntegerSet set; }; bool PyIntegerSet::operator==(const PyIntegerSet &other) { return mlirIntegerSetEqual(integerSet, other.integerSet); } py::object PyIntegerSet::getCapsule() { return py::reinterpret_steal<py::object>( mlirPythonIntegerSetToCapsule(*this)); } PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); if (mlirIntegerSetIsNull(rawIntegerSet)) throw py::error_already_set(); return PyIntegerSet( PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), rawIntegerSet); } /// Attempts to populate `result` with the content of `list` casted to the /// appropriate type (Python and C types are provided as template arguments). /// Throws errors in case of failure, using "action" to describe what the caller /// was attempting to do. template <typename PyType, typename CType> static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, StringRef action) { result.reserve(py::len(list)); for (py::handle item : list) { try { result.push_back(item.cast<PyType>()); } catch (py::cast_error &err) { std::string msg = (llvm::Twine("Invalid expression when ") + action + " (" + err.what() + ")") .str(); throw py::cast_error(msg); } catch (py::reference_cast_error &err) { std::string msg = (llvm::Twine("Invalid expression (None?) when ") + action + " (" + err.what() + ")") .str(); throw py::cast_error(msg); } } } //------------------------------------------------------------------------------ // Populates the pybind11 IR submodule. //------------------------------------------------------------------------------ Loading Loading @@ -4152,24 +4249,8 @@ void mlir::python::populateIRSubmodule(py::module &m) { [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, DefaultingPyMlirContext context) { SmallVector<MlirAffineExpr> affineExprs; affineExprs.reserve(py::len(exprs)); for (py::handle expr : exprs) { try { affineExprs.push_back(expr.cast<PyAffineExpr>()); } catch (py::cast_error &err) { std::string msg = std::string("Invalid expression when attempting to create " "an AffineMap (") + err.what() + ")"; throw py::cast_error(msg); } catch (py::reference_cast_error &err) { std::string msg = std::string("Invalid expression (None?) when attempting to " "create an AffineMap (") + err.what() + ")"; throw py::cast_error(msg); } } pyListToVector<PyAffineExpr, MlirAffineExpr>( exprs, affineExprs, "attempting to create an AffineMap"); MlirAffineMap map = mlirAffineMapGet(context->get(), dimCount, symbolCount, affineExprs.size(), affineExprs.data()); Loading Loading @@ -4275,4 +4356,125 @@ void mlir::python::populateIRSubmodule(py::module &m) { return PyAffineMapExprList(self); }); PyAffineMapExprList::bind(m); //---------------------------------------------------------------------------- // Mapping of PyIntegerSet. //---------------------------------------------------------------------------- py::class_<PyIntegerSet>(m, "IntegerSet") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) .def("__eq__", [](PyIntegerSet &self, PyIntegerSet &other) { return self == other; }) .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) .def("__str__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; mlirIntegerSetPrint(self, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }) .def("__repr__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; printAccum.parts.append("IntegerSet("); mlirIntegerSetPrint(self, printAccum.getCallback(), printAccum.getUserData()); printAccum.parts.append(")"); return printAccum.join(); }) .def_property_readonly( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) .def( "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, kDumpDocstring) .def_static( "get", [](intptr_t numDims, intptr_t numSymbols, py::list exprs, std::vector<bool> eqFlags, DefaultingPyMlirContext context) { if (exprs.size() != eqFlags.size()) throw py::value_error( "Expected the number of constraints to match " "that of equality flags"); if (exprs.empty()) throw py::value_error("Expected non-empty list of constraints"); // Copy over to a SmallVector because std::vector has a // specialization for booleans that packs data and does not // expose a `bool *`. SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); SmallVector<MlirAffineExpr> affineExprs; pyListToVector<PyAffineExpr>(exprs, affineExprs, "attempting to create an IntegerSet"); MlirIntegerSet set = mlirIntegerSetGet( context->get(), numDims, numSymbols, exprs.size(), affineExprs.data(), flags.data()); return PyIntegerSet(context->getRef(), set); }, py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), py::arg("eq_flags"), py::arg("context") = py::none()) .def_static( "get_empty", [](intptr_t numDims, intptr_t numSymbols, DefaultingPyMlirContext context) { MlirIntegerSet set = mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); return PyIntegerSet(context->getRef(), set); }, py::arg("num_dims"), py::arg("num_symbols"), py::arg("context") = py::none()) .def("get_replaced", [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, intptr_t numResultDims, intptr_t numResultSymbols) { if (static_cast<intptr_t>(dimExprs.size()) != mlirIntegerSetGetNumDims(self)) throw py::value_error( "Expected the number of dimension replacement expressions " "to match that of dimensions"); if (static_cast<intptr_t>(symbolExprs.size()) != mlirIntegerSetGetNumSymbols(self)) throw py::value_error( "Expected the number of symbol replacement expressions " "to match that of symbols"); SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; pyListToVector<PyAffineExpr>( dimExprs, dimAffineExprs, "attempting to create an IntegerSet by replacing dimensions"); pyListToVector<PyAffineExpr>( symbolExprs, symbolAffineExprs, "attempting to create an IntegerSet by replacing symbols"); MlirIntegerSet set = mlirIntegerSetReplaceGet( self, dimAffineExprs.data(), symbolAffineExprs.data(), numResultDims, numResultSymbols); return PyIntegerSet(self.getContext(), set); }) .def_property_readonly("is_canonical_empty", [](PyIntegerSet &self) { return mlirIntegerSetIsCanonicalEmpty(self); }) .def_property_readonly( "n_dims", [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) .def_property_readonly( "n_symbols", [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) .def_property_readonly( "n_inputs", [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) .def_property_readonly("n_equalities", [](PyIntegerSet &self) { return mlirIntegerSetGetNumEqualities(self); }) .def_property_readonly("n_inequalities", [](PyIntegerSet &self) { return mlirIntegerSetGetNumInequalities(self); }) .def_property_readonly("constraints", [](PyIntegerSet &self) { return PyIntegerSetConstraintList(self); }); PyIntegerSetConstraint::bind(m); PyIntegerSetConstraintList::bind(m); } mlir/lib/Bindings/Python/IRModules.h +21 −0 Original line number Diff line number Diff line Loading @@ -16,6 +16,7 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "llvm/ADT/DenseMap.h" namespace mlir { Loading Loading @@ -726,6 +727,26 @@ private: MlirAffineMap affineMap; }; class PyIntegerSet : public BaseContextObject { public: PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} bool operator==(const PyIntegerSet &other); operator MlirIntegerSet() const { return integerSet; } MlirIntegerSet get() const { return integerSet; } /// Gets a capsule wrapping the void* within the MlirIntegerSet. pybind11::object getCapsule(); /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. /// Note that PyIntegerSet instances may be uniqued, so the returned object /// may be a pre-existing object. Integer sets are owned by the context. static PyIntegerSet createFromCapsule(pybind11::object capsule); private: MlirIntegerSet integerSet; }; void populateIRSubmodule(pybind11::module &m); } // namespace python Loading mlir/test/Bindings/Python/ir_integer_set.py 0 → 100644 +128 −0 Original line number Diff line number Diff line # RUN: %PYTHON %s | FileCheck %s import gc from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() gc.collect() assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testIntegerSetCapsule def testIntegerSetCapsule(): with Context() as ctx: is1 = IntegerSet.get_empty(1, 1, ctx) capsule = is1._CAPIPtr # CHECK: mlir.ir.IntegerSet._CAPIPtr print(capsule) is2 = IntegerSet._CAPICreate(capsule) assert is1 == is2 assert is2.context is ctx run(testIntegerSetCapsule) # CHECK-LABEL: TEST: testIntegerSetGet def testIntegerSetGet(): with Context(): d0 = AffineDimExpr.get(0) d1 = AffineDimExpr.get(1) s0 = AffineSymbolExpr.get(0) c42 = AffineConstantExpr.get(42) # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0) set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False]) print(set0) # CHECK: (d0)[s0] : (1 == 0) set1 = IntegerSet.get_empty(1, 1) print(set1) # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0) set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2) print(set2) try: IntegerSet.get(2, 1, [], []) except ValueError as e: # CHECK: Expected non-empty list of constraints print(e) try: IntegerSet.get(2, 1, [d0 - d1], [True, False]) except ValueError as e: # CHECK: Expected the number of constraints to match that of equality flags print(e) try: IntegerSet.get(2, 1, [0], [True]) except RuntimeError as e: # CHECK: Invalid expression when attempting to create an IntegerSet print(e) try: IntegerSet.get(2, 1, [None], [True]) except RuntimeError as e: # CHECK: Invalid expression (None?) when attempting to create an IntegerSet print(e) try: set0.get_replaced([d0], [s0], 1, 1) except ValueError as e: # CHECK: Expected the number of dimension replacement expressions to match that of dimensions print(e) try: set0.get_replaced([d0, d1], [s0, s0], 1, 1) except ValueError as e: # CHECK: Expected the number of symbol replacement expressions to match that of symbols print(e) try: set0.get_replaced([d0, 1], [s0], 1, 1) except RuntimeError as e: # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions print(e) try: set0.get_replaced([d0, d1], [None], 1, 1) except RuntimeError as e: # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols print(e) run(testIntegerSetGet) # CHECK-LABEL: TEST: testIntegerSetProperties def testIntegerSetProperties(): with Context(): d0 = AffineDimExpr.get(0) d1 = AffineDimExpr.get(1) s0 = AffineSymbolExpr.get(0) c42 = AffineConstantExpr.get(42) set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False]) # CHECK: 2 print(set0.n_dims) # CHECK: 1 print(set0.n_symbols) # CHECK: 3 print(set0.n_inputs) # CHECK: 1 print(set0.n_equalities) # CHECK: 2 print(set0.n_inequalities) # CHECK: 3 print(len(set0.constraints)) # CHECK-DAG: d0 - d1 == 0 # CHECK-DAG: s0 - 42 >= 0 # CHECK-DAG: -d0 + s0 >= 0 for cstr in set0.constraints: print(cstr.expr, end='') print(" == 0" if cstr.is_eq else " >= 0") run(testIntegerSetProperties) Loading
mlir/include/mlir-c/Bindings/Python/Interop.h +21 −0 Original line number Diff line number Diff line Loading @@ -26,12 +26,14 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Pass.h" #define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr" #define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr" #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr" #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr" #define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr" #define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr" #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr" #define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr" Loading Loading @@ -240,6 +242,25 @@ static inline MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule) { return affineMap; } /** Creates a capsule object encapsulating the raw C-API MlirIntegerSet. * The returned capsule does not extend or affect ownership of any Python * objects that reference the set in any way. */ static inline PyObject * mlirPythonIntegerSetToCapsule(MlirIntegerSet integerSet) { return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(integerSet), MLIR_PYTHON_CAPSULE_INTEGER_SET, NULL); } /** Extracts an MlirIntegerSet from a capsule as produced from * mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then * a null set is returned (as checked via mlirIntegerSetIsNull). In such a * case, the Python APIs will have already set an error. */ static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) { void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_INTEGER_SET); MlirIntegerSet integerSet = {ptr}; return integerSet; } #ifdef __cplusplus } #endif Loading
mlir/lib/Bindings/Python/IRModules.cpp +220 −18 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Registration.h" #include "llvm/ADT/SmallVector.h" #include <pybind11/stl.h> Loading Loading @@ -3331,6 +3332,102 @@ PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { rawAffineMap); } //------------------------------------------------------------------------------ // PyIntegerSet and utilities. //------------------------------------------------------------------------------ class PyIntegerSetConstraint { public: PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} PyAffineExpr getExpr() { return PyAffineExpr(set.getContext(), mlirIntegerSetGetConstraint(set, pos)); } bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } static void bind(py::module &m) { py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint") .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); } private: PyIntegerSet set; intptr_t pos; }; class PyIntegerSetConstraintList : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { public: static constexpr const char *pyClassName = "IntegerSetConstraintList"; PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) : Sliceable(startIndex, length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, step), set(set) {} intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } PyIntegerSetConstraint getElement(intptr_t pos) { return PyIntegerSetConstraint(set, pos); } PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, intptr_t step) { return PyIntegerSetConstraintList(set, startIndex, length, step); } private: PyIntegerSet set; }; bool PyIntegerSet::operator==(const PyIntegerSet &other) { return mlirIntegerSetEqual(integerSet, other.integerSet); } py::object PyIntegerSet::getCapsule() { return py::reinterpret_steal<py::object>( mlirPythonIntegerSetToCapsule(*this)); } PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); if (mlirIntegerSetIsNull(rawIntegerSet)) throw py::error_already_set(); return PyIntegerSet( PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), rawIntegerSet); } /// Attempts to populate `result` with the content of `list` casted to the /// appropriate type (Python and C types are provided as template arguments). /// Throws errors in case of failure, using "action" to describe what the caller /// was attempting to do. template <typename PyType, typename CType> static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, StringRef action) { result.reserve(py::len(list)); for (py::handle item : list) { try { result.push_back(item.cast<PyType>()); } catch (py::cast_error &err) { std::string msg = (llvm::Twine("Invalid expression when ") + action + " (" + err.what() + ")") .str(); throw py::cast_error(msg); } catch (py::reference_cast_error &err) { std::string msg = (llvm::Twine("Invalid expression (None?) when ") + action + " (" + err.what() + ")") .str(); throw py::cast_error(msg); } } } //------------------------------------------------------------------------------ // Populates the pybind11 IR submodule. //------------------------------------------------------------------------------ Loading Loading @@ -4152,24 +4249,8 @@ void mlir::python::populateIRSubmodule(py::module &m) { [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, DefaultingPyMlirContext context) { SmallVector<MlirAffineExpr> affineExprs; affineExprs.reserve(py::len(exprs)); for (py::handle expr : exprs) { try { affineExprs.push_back(expr.cast<PyAffineExpr>()); } catch (py::cast_error &err) { std::string msg = std::string("Invalid expression when attempting to create " "an AffineMap (") + err.what() + ")"; throw py::cast_error(msg); } catch (py::reference_cast_error &err) { std::string msg = std::string("Invalid expression (None?) when attempting to " "create an AffineMap (") + err.what() + ")"; throw py::cast_error(msg); } } pyListToVector<PyAffineExpr, MlirAffineExpr>( exprs, affineExprs, "attempting to create an AffineMap"); MlirAffineMap map = mlirAffineMapGet(context->get(), dimCount, symbolCount, affineExprs.size(), affineExprs.data()); Loading Loading @@ -4275,4 +4356,125 @@ void mlir::python::populateIRSubmodule(py::module &m) { return PyAffineMapExprList(self); }); PyAffineMapExprList::bind(m); //---------------------------------------------------------------------------- // Mapping of PyIntegerSet. //---------------------------------------------------------------------------- py::class_<PyIntegerSet>(m, "IntegerSet") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) .def("__eq__", [](PyIntegerSet &self, PyIntegerSet &other) { return self == other; }) .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) .def("__str__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; mlirIntegerSetPrint(self, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }) .def("__repr__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; printAccum.parts.append("IntegerSet("); mlirIntegerSetPrint(self, printAccum.getCallback(), printAccum.getUserData()); printAccum.parts.append(")"); return printAccum.join(); }) .def_property_readonly( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) .def( "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, kDumpDocstring) .def_static( "get", [](intptr_t numDims, intptr_t numSymbols, py::list exprs, std::vector<bool> eqFlags, DefaultingPyMlirContext context) { if (exprs.size() != eqFlags.size()) throw py::value_error( "Expected the number of constraints to match " "that of equality flags"); if (exprs.empty()) throw py::value_error("Expected non-empty list of constraints"); // Copy over to a SmallVector because std::vector has a // specialization for booleans that packs data and does not // expose a `bool *`. SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); SmallVector<MlirAffineExpr> affineExprs; pyListToVector<PyAffineExpr>(exprs, affineExprs, "attempting to create an IntegerSet"); MlirIntegerSet set = mlirIntegerSetGet( context->get(), numDims, numSymbols, exprs.size(), affineExprs.data(), flags.data()); return PyIntegerSet(context->getRef(), set); }, py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), py::arg("eq_flags"), py::arg("context") = py::none()) .def_static( "get_empty", [](intptr_t numDims, intptr_t numSymbols, DefaultingPyMlirContext context) { MlirIntegerSet set = mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); return PyIntegerSet(context->getRef(), set); }, py::arg("num_dims"), py::arg("num_symbols"), py::arg("context") = py::none()) .def("get_replaced", [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, intptr_t numResultDims, intptr_t numResultSymbols) { if (static_cast<intptr_t>(dimExprs.size()) != mlirIntegerSetGetNumDims(self)) throw py::value_error( "Expected the number of dimension replacement expressions " "to match that of dimensions"); if (static_cast<intptr_t>(symbolExprs.size()) != mlirIntegerSetGetNumSymbols(self)) throw py::value_error( "Expected the number of symbol replacement expressions " "to match that of symbols"); SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; pyListToVector<PyAffineExpr>( dimExprs, dimAffineExprs, "attempting to create an IntegerSet by replacing dimensions"); pyListToVector<PyAffineExpr>( symbolExprs, symbolAffineExprs, "attempting to create an IntegerSet by replacing symbols"); MlirIntegerSet set = mlirIntegerSetReplaceGet( self, dimAffineExprs.data(), symbolAffineExprs.data(), numResultDims, numResultSymbols); return PyIntegerSet(self.getContext(), set); }) .def_property_readonly("is_canonical_empty", [](PyIntegerSet &self) { return mlirIntegerSetIsCanonicalEmpty(self); }) .def_property_readonly( "n_dims", [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) .def_property_readonly( "n_symbols", [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) .def_property_readonly( "n_inputs", [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) .def_property_readonly("n_equalities", [](PyIntegerSet &self) { return mlirIntegerSetGetNumEqualities(self); }) .def_property_readonly("n_inequalities", [](PyIntegerSet &self) { return mlirIntegerSetGetNumInequalities(self); }) .def_property_readonly("constraints", [](PyIntegerSet &self) { return PyIntegerSetConstraintList(self); }); PyIntegerSetConstraint::bind(m); PyIntegerSetConstraintList::bind(m); }
mlir/lib/Bindings/Python/IRModules.h +21 −0 Original line number Diff line number Diff line Loading @@ -16,6 +16,7 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "llvm/ADT/DenseMap.h" namespace mlir { Loading Loading @@ -726,6 +727,26 @@ private: MlirAffineMap affineMap; }; class PyIntegerSet : public BaseContextObject { public: PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} bool operator==(const PyIntegerSet &other); operator MlirIntegerSet() const { return integerSet; } MlirIntegerSet get() const { return integerSet; } /// Gets a capsule wrapping the void* within the MlirIntegerSet. pybind11::object getCapsule(); /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. /// Note that PyIntegerSet instances may be uniqued, so the returned object /// may be a pre-existing object. Integer sets are owned by the context. static PyIntegerSet createFromCapsule(pybind11::object capsule); private: MlirIntegerSet integerSet; }; void populateIRSubmodule(pybind11::module &m); } // namespace python Loading
mlir/test/Bindings/Python/ir_integer_set.py 0 → 100644 +128 −0 Original line number Diff line number Diff line # RUN: %PYTHON %s | FileCheck %s import gc from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() gc.collect() assert Context._get_live_count() == 0 # CHECK-LABEL: TEST: testIntegerSetCapsule def testIntegerSetCapsule(): with Context() as ctx: is1 = IntegerSet.get_empty(1, 1, ctx) capsule = is1._CAPIPtr # CHECK: mlir.ir.IntegerSet._CAPIPtr print(capsule) is2 = IntegerSet._CAPICreate(capsule) assert is1 == is2 assert is2.context is ctx run(testIntegerSetCapsule) # CHECK-LABEL: TEST: testIntegerSetGet def testIntegerSetGet(): with Context(): d0 = AffineDimExpr.get(0) d1 = AffineDimExpr.get(1) s0 = AffineSymbolExpr.get(0) c42 = AffineConstantExpr.get(42) # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0) set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False]) print(set0) # CHECK: (d0)[s0] : (1 == 0) set1 = IntegerSet.get_empty(1, 1) print(set1) # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0) set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2) print(set2) try: IntegerSet.get(2, 1, [], []) except ValueError as e: # CHECK: Expected non-empty list of constraints print(e) try: IntegerSet.get(2, 1, [d0 - d1], [True, False]) except ValueError as e: # CHECK: Expected the number of constraints to match that of equality flags print(e) try: IntegerSet.get(2, 1, [0], [True]) except RuntimeError as e: # CHECK: Invalid expression when attempting to create an IntegerSet print(e) try: IntegerSet.get(2, 1, [None], [True]) except RuntimeError as e: # CHECK: Invalid expression (None?) when attempting to create an IntegerSet print(e) try: set0.get_replaced([d0], [s0], 1, 1) except ValueError as e: # CHECK: Expected the number of dimension replacement expressions to match that of dimensions print(e) try: set0.get_replaced([d0, d1], [s0, s0], 1, 1) except ValueError as e: # CHECK: Expected the number of symbol replacement expressions to match that of symbols print(e) try: set0.get_replaced([d0, 1], [s0], 1, 1) except RuntimeError as e: # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions print(e) try: set0.get_replaced([d0, d1], [None], 1, 1) except RuntimeError as e: # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols print(e) run(testIntegerSetGet) # CHECK-LABEL: TEST: testIntegerSetProperties def testIntegerSetProperties(): with Context(): d0 = AffineDimExpr.get(0) d1 = AffineDimExpr.get(1) s0 = AffineSymbolExpr.get(0) c42 = AffineConstantExpr.get(42) set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False]) # CHECK: 2 print(set0.n_dims) # CHECK: 1 print(set0.n_symbols) # CHECK: 3 print(set0.n_inputs) # CHECK: 1 print(set0.n_equalities) # CHECK: 2 print(set0.n_inequalities) # CHECK: 3 print(len(set0.constraints)) # CHECK-DAG: d0 - d1 == 0 # CHECK-DAG: s0 - 42 >= 0 # CHECK-DAG: -d0 + s0 >= 0 for cstr in set0.constraints: print(cstr.expr, end='') print(" == 0" if cstr.is_eq else " >= 0") run(testIntegerSetProperties)