Commit b208e5bc authored by Alex Zinenko's avatar Alex Zinenko
Browse files

[mlir] Add Python bindings for IntegerSet

This follows up on the introduction of C API for the same object and is similar
to AffineExpr and AffineMap.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D95437
parent 00773ef7
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -26,12 +26,14 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Pass.h"

#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
@@ -240,6 +242,25 @@ static inline MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule) {
  return affineMap;
}

/** Creates a capsule object encapsulating the raw C-API MlirIntegerSet.
 * The returned capsule does not extend or affect ownership of any Python
 * objects that reference the set in any way. */
static inline PyObject *
mlirPythonIntegerSetToCapsule(MlirIntegerSet integerSet) {
  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(integerSet),
                       MLIR_PYTHON_CAPSULE_INTEGER_SET, NULL);
}

/** Extracts an MlirIntegerSet from a capsule as produced from
 * mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then
 * a null set is returned (as checked via mlirIntegerSetIsNull). In such a
 * case, the Python APIs will have already set an error. */
static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) {
  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_INTEGER_SET);
  MlirIntegerSet integerSet = {ptr};
  return integerSet;
}

#ifdef __cplusplus
}
#endif
+220 −18
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Registration.h"
#include "llvm/ADT/SmallVector.h"
#include <pybind11/stl.h>
@@ -3331,6 +3332,102 @@ PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
      rawAffineMap);
}

//------------------------------------------------------------------------------
// PyIntegerSet and utilities.
//------------------------------------------------------------------------------

class PyIntegerSetConstraint {
public:
  PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}

  PyAffineExpr getExpr() {
    return PyAffineExpr(set.getContext(),
                        mlirIntegerSetGetConstraint(set, pos));
  }

  bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }

  static void bind(py::module &m) {
    py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
        .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
        .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
  }

private:
  PyIntegerSet set;
  intptr_t pos;
};

class PyIntegerSetConstraintList
    : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
public:
  static constexpr const char *pyClassName = "IntegerSetConstraintList";

  PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
                             intptr_t length = -1, intptr_t step = 1)
      : Sliceable(startIndex,
                  length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
                  step),
        set(set) {}

  intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }

  PyIntegerSetConstraint getElement(intptr_t pos) {
    return PyIntegerSetConstraint(set, pos);
  }

  PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
                                   intptr_t step) {
    return PyIntegerSetConstraintList(set, startIndex, length, step);
  }

private:
  PyIntegerSet set;
};

bool PyIntegerSet::operator==(const PyIntegerSet &other) {
  return mlirIntegerSetEqual(integerSet, other.integerSet);
}

py::object PyIntegerSet::getCapsule() {
  return py::reinterpret_steal<py::object>(
      mlirPythonIntegerSetToCapsule(*this));
}

PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
  MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
  if (mlirIntegerSetIsNull(rawIntegerSet))
    throw py::error_already_set();
  return PyIntegerSet(
      PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
      rawIntegerSet);
}

/// Attempts to populate `result` with the content of `list` casted to the
/// appropriate type (Python and C types are provided as template arguments).
/// Throws errors in case of failure, using "action" to describe what the caller
/// was attempting to do.
template <typename PyType, typename CType>
static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
                           StringRef action) {
  result.reserve(py::len(list));
  for (py::handle item : list) {
    try {
      result.push_back(item.cast<PyType>());
    } catch (py::cast_error &err) {
      std::string msg = (llvm::Twine("Invalid expression when ") + action +
                         " (" + err.what() + ")")
                            .str();
      throw py::cast_error(msg);
    } catch (py::reference_cast_error &err) {
      std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
                         action + " (" + err.what() + ")")
                            .str();
      throw py::cast_error(msg);
    }
  }
}

//------------------------------------------------------------------------------
// Populates the pybind11 IR submodule.
//------------------------------------------------------------------------------
@@ -4152,24 +4249,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
          [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
             DefaultingPyMlirContext context) {
            SmallVector<MlirAffineExpr> affineExprs;
            affineExprs.reserve(py::len(exprs));
            for (py::handle expr : exprs) {
              try {
                affineExprs.push_back(expr.cast<PyAffineExpr>());
              } catch (py::cast_error &err) {
                std::string msg =
                    std::string("Invalid expression when attempting to create "
                                "an AffineMap (") +
                    err.what() + ")";
                throw py::cast_error(msg);
              } catch (py::reference_cast_error &err) {
                std::string msg =
                    std::string("Invalid expression (None?) when attempting to "
                                "create an AffineMap (") +
                    err.what() + ")";
                throw py::cast_error(msg);
              }
            }
            pyListToVector<PyAffineExpr, MlirAffineExpr>(
                exprs, affineExprs, "attempting to create an AffineMap");
            MlirAffineMap map =
                mlirAffineMapGet(context->get(), dimCount, symbolCount,
                                 affineExprs.size(), affineExprs.data());
@@ -4275,4 +4356,125 @@ void mlir::python::populateIRSubmodule(py::module &m) {
        return PyAffineMapExprList(self);
      });
  PyAffineMapExprList::bind(m);

  //----------------------------------------------------------------------------
  // Mapping of PyIntegerSet.
  //----------------------------------------------------------------------------
  py::class_<PyIntegerSet>(m, "IntegerSet")
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             &PyIntegerSet::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
      .def("__eq__", [](PyIntegerSet &self,
                        PyIntegerSet &other) { return self == other; })
      .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
      .def("__str__",
           [](PyIntegerSet &self) {
             PyPrintAccumulator printAccum;
             mlirIntegerSetPrint(self, printAccum.getCallback(),
                                 printAccum.getUserData());
             return printAccum.join();
           })
      .def("__repr__",
           [](PyIntegerSet &self) {
             PyPrintAccumulator printAccum;
             printAccum.parts.append("IntegerSet(");
             mlirIntegerSetPrint(self, printAccum.getCallback(),
                                 printAccum.getUserData());
             printAccum.parts.append(")");
             return printAccum.join();
           })
      .def_property_readonly(
          "context",
          [](PyIntegerSet &self) { return self.getContext().getObject(); })
      .def(
          "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
          kDumpDocstring)
      .def_static(
          "get",
          [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
             std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
            if (exprs.size() != eqFlags.size())
              throw py::value_error(
                  "Expected the number of constraints to match "
                  "that of equality flags");
            if (exprs.empty())
              throw py::value_error("Expected non-empty list of constraints");

            // Copy over to a SmallVector because std::vector has a
            // specialization for booleans that packs data and does not
            // expose a `bool *`.
            SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());

            SmallVector<MlirAffineExpr> affineExprs;
            pyListToVector<PyAffineExpr>(exprs, affineExprs,
                                         "attempting to create an IntegerSet");
            MlirIntegerSet set = mlirIntegerSetGet(
                context->get(), numDims, numSymbols, exprs.size(),
                affineExprs.data(), flags.data());
            return PyIntegerSet(context->getRef(), set);
          },
          py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
          py::arg("eq_flags"), py::arg("context") = py::none())
      .def_static(
          "get_empty",
          [](intptr_t numDims, intptr_t numSymbols,
             DefaultingPyMlirContext context) {
            MlirIntegerSet set =
                mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
            return PyIntegerSet(context->getRef(), set);
          },
          py::arg("num_dims"), py::arg("num_symbols"),
          py::arg("context") = py::none())
      .def("get_replaced",
           [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
              intptr_t numResultDims, intptr_t numResultSymbols) {
             if (static_cast<intptr_t>(dimExprs.size()) !=
                 mlirIntegerSetGetNumDims(self))
               throw py::value_error(
                   "Expected the number of dimension replacement expressions "
                   "to match that of dimensions");
             if (static_cast<intptr_t>(symbolExprs.size()) !=
                 mlirIntegerSetGetNumSymbols(self))
               throw py::value_error(
                   "Expected the number of symbol replacement expressions "
                   "to match that of symbols");

             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
             pyListToVector<PyAffineExpr>(
                 dimExprs, dimAffineExprs,
                 "attempting to create an IntegerSet by replacing dimensions");
             pyListToVector<PyAffineExpr>(
                 symbolExprs, symbolAffineExprs,
                 "attempting to create an IntegerSet by replacing symbols");
             MlirIntegerSet set = mlirIntegerSetReplaceGet(
                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
                 numResultDims, numResultSymbols);
             return PyIntegerSet(self.getContext(), set);
           })
      .def_property_readonly("is_canonical_empty",
                             [](PyIntegerSet &self) {
                               return mlirIntegerSetIsCanonicalEmpty(self);
                             })
      .def_property_readonly(
          "n_dims",
          [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
      .def_property_readonly(
          "n_symbols",
          [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
      .def_property_readonly(
          "n_inputs",
          [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
      .def_property_readonly("n_equalities",
                             [](PyIntegerSet &self) {
                               return mlirIntegerSetGetNumEqualities(self);
                             })
      .def_property_readonly("n_inequalities",
                             [](PyIntegerSet &self) {
                               return mlirIntegerSetGetNumInequalities(self);
                             })
      .def_property_readonly("constraints", [](PyIntegerSet &self) {
        return PyIntegerSetConstraintList(self);
      });
  PyIntegerSetConstraint::bind(m);
  PyIntegerSetConstraintList::bind(m);
}
+21 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "llvm/ADT/DenseMap.h"

namespace mlir {
@@ -726,6 +727,26 @@ private:
  MlirAffineMap affineMap;
};

class PyIntegerSet : public BaseContextObject {
public:
  PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
      : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
  bool operator==(const PyIntegerSet &other);
  operator MlirIntegerSet() const { return integerSet; }
  MlirIntegerSet get() const { return integerSet; }

  /// Gets a capsule wrapping the void* within the MlirIntegerSet.
  pybind11::object getCapsule();

  /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
  /// Note that PyIntegerSet instances may be uniqued, so the returned object
  /// may be a pre-existing object. Integer sets are owned by the context.
  static PyIntegerSet createFromCapsule(pybind11::object capsule);

private:
  MlirIntegerSet integerSet;
};

void populateIRSubmodule(pybind11::module &m);

} // namespace python
+128 −0
Original line number Diff line number Diff line
# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *

def run(f):
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0


# CHECK-LABEL: TEST: testIntegerSetCapsule
def testIntegerSetCapsule():
  with Context() as ctx:
    is1 = IntegerSet.get_empty(1, 1, ctx)
  capsule = is1._CAPIPtr
  # CHECK: mlir.ir.IntegerSet._CAPIPtr
  print(capsule)
  is2 = IntegerSet._CAPICreate(capsule)
  assert is1 == is2
  assert is2.context is ctx

run(testIntegerSetCapsule)


# CHECK-LABEL: TEST: testIntegerSetGet
def testIntegerSetGet():
  with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    s0 = AffineSymbolExpr.get(0)
    c42 = AffineConstantExpr.get(42)

    # CHECK: (d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)
    set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42], [True, False])
    print(set0)

    # CHECK: (d0)[s0] : (1 == 0)
    set1 = IntegerSet.get_empty(1, 1)
    print(set1)

    # CHECK: (d0)[s0, s1] : (d0 - s1 == 0, s0 - 42 >= 0)
    set2 = set0.get_replaced([d0, AffineSymbolExpr.get(1)], [s0], 1, 2)
    print(set2)

    try:
      IntegerSet.get(2, 1, [], [])
    except ValueError as e:
      # CHECK: Expected non-empty list of constraints
      print(e)

    try:
      IntegerSet.get(2, 1, [d0 - d1], [True, False])
    except ValueError as e:
      # CHECK: Expected the number of constraints to match that of equality flags
      print(e)

    try:
      IntegerSet.get(2, 1, [0], [True])
    except RuntimeError as e:
      # CHECK: Invalid expression when attempting to create an IntegerSet
      print(e)

    try:
      IntegerSet.get(2, 1, [None], [True])
    except RuntimeError as e:
      # CHECK: Invalid expression (None?) when attempting to create an IntegerSet
      print(e)

    try:
      set0.get_replaced([d0], [s0], 1, 1)
    except ValueError as e:
      # CHECK: Expected the number of dimension replacement expressions to match that of dimensions
      print(e)

    try:
      set0.get_replaced([d0, d1], [s0, s0], 1, 1)
    except ValueError as e:
      # CHECK: Expected the number of symbol replacement expressions to match that of symbols
      print(e)

    try:
      set0.get_replaced([d0, 1], [s0], 1, 1)
    except RuntimeError as e:
      # CHECK: Invalid expression when attempting to create an IntegerSet by replacing dimensions
      print(e)

    try:
      set0.get_replaced([d0, d1], [None], 1, 1)
    except RuntimeError as e:
      # CHECK: Invalid expression (None?) when attempting to create an IntegerSet by replacing symbols
      print(e)

run(testIntegerSetGet)


# CHECK-LABEL: TEST: testIntegerSetProperties
def testIntegerSetProperties():
  with Context():
    d0 = AffineDimExpr.get(0)
    d1 = AffineDimExpr.get(1)
    s0 = AffineSymbolExpr.get(0)
    c42 = AffineConstantExpr.get(42)

    set0 = IntegerSet.get(2, 1, [d0 - d1, s0 - c42, s0 - d0], [True, False, False])
    # CHECK: 2
    print(set0.n_dims)
    # CHECK: 1
    print(set0.n_symbols)
    # CHECK: 3
    print(set0.n_inputs)
    # CHECK: 1
    print(set0.n_equalities)
    # CHECK: 2
    print(set0.n_inequalities)

    # CHECK: 3
    print(len(set0.constraints))

    # CHECK-DAG: d0 - d1 == 0
    # CHECK-DAG: s0 - 42 >= 0
    # CHECK-DAG: -d0 + s0 >= 0
    for cstr in set0.constraints:
      print(cstr.expr, end='')
      print(" == 0" if cstr.is_eq else " >= 0")

run(testIntegerSetProperties)