Unverified Commit d4088e7d authored by Yinying Li's avatar Yinying Li Committed by GitHub
Browse files

[mlir][sparse] Populate lvlToDim (#68937)

Updates:
1. Infer lvlToDim from dimToLvl
2. Add more tests for block sparsity
3. Finish TODOs related to lvlToDim, including adding lvlToDim to python
binding

Verification of lvlToDim that user provides will be implemented in the
next PR.
parent 9922aadf
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -51,11 +51,10 @@ MLIR_CAPI_EXPORTED bool
mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);

/// Creates a `sparse_tensor.encoding` attribute with the given parameters.
/// TODO: add a version that supplied lvlToDim when it cannot be inferred
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
    MlirContext ctx, intptr_t lvlRank,
    enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl,
    int posWidth, int crdWidth);
    MlirAffineMap lvlTodim, int posWidth, int crdWidth);

/// Returns the level-rank of the `sparse_tensor.encoding` attribute.
MLIR_CAPI_EXPORTED intptr_t
+13 −0
Original line number Diff line number Diff line
@@ -160,6 +160,19 @@ inline bool hasAnySparseOperandOrResult(Operation *op) {
  return hasAnySparseOperand(op) || hasAnySparseResult(op);
}

//
// Inference.
//

/// Given the dimToLvl map, infers the lvlToDim map, or returns
/// empty Affine map when inference fails.
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);

/// Returns the lvlToDim map for the given dimToLvl map specific
/// to the block sparse cases.
/// Asserts on failure (so only use when known to succeed).
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);

//
// Reordering.
//
+3 −0
Original line number Diff line number Diff line
@@ -307,6 +307,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
                     "AffineMap":$lvlToDim,
                     "unsigned":$posWidth,
                     "unsigned":$crdWidth), [{
      if (!lvlToDim) {
        lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
      }
      return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
        ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
    }]>
+13 −4
Original line number Diff line number Diff line
@@ -41,16 +41,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
      .def_classmethod(
          "get",
          [](py::object cls, std::vector<MlirSparseTensorDimLevelType> lvlTypes,
             std::optional<MlirAffineMap> dimToLvl, int posWidth, int crdWidth,
             std::optional<MlirAffineMap> dimToLvl,
             std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
             MlirContext context) {
            // TODO: provide dimToLvl
            return cls(mlirSparseTensorEncodingAttrGet(
                context, lvlTypes.size(), lvlTypes.data(),
                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth,
                dimToLvl ? *dimToLvl : MlirAffineMap{nullptr},
                lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth,
                crdWidth));
          },
          py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"),
          py::arg("pos_width"), py::arg("crd_width"),
          py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"),
          py::arg("context") = py::none(),
          "Gets a sparse_tensor.encoding from parameters.")
      .def_property_readonly(
@@ -71,6 +72,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
              return {};
            return ret;
          })
      .def_property_readonly(
          "lvl_to_dim",
          [](MlirAttribute self) -> std::optional<MlirAffineMap> {
            MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self);
            if (mlirAffineMapIsNull(ret))
              return {};
            return ret;
          })
      .def_property_readonly("pos_width",
                             mlirSparseTensorEncodingAttrGetPosWidth)
      .def_property_readonly("crd_width",
+3 −4
Original line number Diff line number Diff line
@@ -48,15 +48,14 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
MlirAttribute
mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
                                MlirSparseTensorDimLevelType const *lvlTypes,
                                MlirAffineMap dimToLvl, int posWidth,
                                int crdWidth) {
                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
                                int posWidth, int crdWidth) {
  SmallVector<DimLevelType> cppLvlTypes;
  cppLvlTypes.reserve(lvlRank);
  for (intptr_t l = 0; l < lvlRank; ++l)
    cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
  mlir::AffineMap lvlToDim; // TODO: provide in API
  return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
                                            unwrap(dimToLvl), lvlToDim,
                                            unwrap(dimToLvl), unwrap(lvlToDim),
                                            posWidth, crdWidth));
}

Loading