Unverified Commit 761c9dd9 authored by Peiming Liu's avatar Peiming Liu Committed by GitHub
Browse files

[mlir][sparse] implementating stageSparseOpPass as an interface (#69022)

parent a22a1fe1
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td)
mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls)
mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRSparseTensorTypesIncGen)

set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td)
mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen)
add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)
+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+31 −0
Original line number Diff line number Diff line
//===- SparseTensorInterfaces.h - sparse tensor operations
//interfaces-------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_

#include "mlir/IR/OpDefinition.h"

namespace mlir {
class PatternRewriter;

namespace sparse_tensor {
class StageWithSortSparseOp;

namespace detail {
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
                                PatternRewriter &rewriter);
} // namespace detail
} // namespace sparse_tensor
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc"

#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
+45 −0
Original line number Diff line number Diff line
//===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES
#define SPARSETENSOR_IR_SPARSETENSORINTERFACES

include "mlir/IR/OpBase.td"

def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
  let description = [{
    A stage-with-sort sparse tensor operation is an operation that produces
    unordered intermediate output. An extra sort is required to obtain the final
    ordered result.

    E.g., convert csr -> csc need to be implemented as
          convert csr -> unordered coo -> sort by column -> csc; and
          concatenate csr, csc -> csr can be staged into
          concatenate csr, csr -> unordered coo -> sort by row -> csr.
  }];
  let cppNamespace = "::mlir::sparse_tensor";
  let methods = [
    InterfaceMethod<
    /*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
    /*retTy=*/"bool",
    /*methodName=*/"needsExtraSort",
    /*args=*/(ins),
    /*methodBody=*/"">,
    InterfaceMethod<
    /*desc=*/"Stage the operation, return the final result value after staging.",
    /*retTy=*/"::mlir::LogicalResult",
    /*methodName=*/"stageWithSort",
    /*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
    /*methodBody=*/[{
        return detail::stageWithSortImpl($_op, rewriter);
    }]>,
  ];
}


#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
+13 −5
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@
include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

@@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
}

def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
  [Pure]>,
  [Pure, StageWithSortSparseOpInterface]>,
    Arguments<(ins AnyTensor:$source)>,
    Results<(outs AnyTensor:$dest)> {
  string summary = "Converts between different tensor types";
@@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
  }];

  let extraClassDeclaration = [{
     // Whether the convert can be done by a single step (either a sort or a foreach),
     // or it would require a tmp buffer (sort, then foreach).
     bool directConvertable();
     // Whether the convert can be done by a single step or it would require
     // an extra sort. Inherited from StageWithSortSparseOpInterface.
     bool needsExtraSort();
  }];

  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
@@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure]
  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
}

def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
                                 [Pure, StageWithSortSparseOpInterface]>,
    Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>,
    Results<(outs AnyRankedTensor:$result)> {

@@ -357,6 +359,12 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
     ```
   }];

  let extraClassDeclaration = [{
     // Whether the concatenate can be done by a single step or it would require
     // an extra sort. Inherited from StageWithSortSparseOpInterface.
     bool needsExtraSort();
  }];

  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
  let hasVerifier = 1;
}
Loading