Loading mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt +6 −0 Original line number Diff line number Diff line Loading @@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td) mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls) mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs) add_public_tablegen_target(MLIRSparseTensorTypesIncGen) set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td) mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls) mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen) add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +1 −0 Original line number Diff line number Diff line Loading @@ -11,6 +11,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" Loading mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h 0 → 100644 +31 −0 Original line number Diff line number Diff line //===- SparseTensorInterfaces.h - sparse tensor operations //interfaces-------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_ #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_ #include "mlir/IR/OpDefinition.h" namespace mlir { class PatternRewriter; namespace sparse_tensor { class StageWithSortSparseOp; namespace detail { LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op, PatternRewriter &rewriter); } // namespace detail } // namespace sparse_tensor } // namespace mlir /// Include the generated interface declarations. #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc" #endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_ mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td 0 → 100644 +45 −0 Original line number Diff line number Diff line //===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES #define SPARSETENSOR_IR_SPARSETENSORINTERFACES include "mlir/IR/OpBase.td" def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> { let description = [{ A stage-with-sort sparse tensor operation is an operation that produces unordered intermediate output. An extra sort is required to obtain the final ordered result. E.g., convert csr -> csc need to be implemented as convert csr -> unordered coo -> sort by column -> csc; and concatenate csr, csc -> csr can be staged into concatenate csr, csr -> unordered coo -> sort by row -> csr. }]; let cppNamespace = "::mlir::sparse_tensor"; let methods = [ InterfaceMethod< /*desc=*/"Return true if the operation needs an extra sort to produce the final result.", /*retTy=*/"bool", /*methodName=*/"needsExtraSort", /*args=*/(ins), /*methodBody=*/"">, InterfaceMethod< /*desc=*/"Stage the operation, return the final result value after staging.", /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"stageWithSort", /*args=*/(ins "::mlir::PatternRewriter &":$rewriter), /*methodBody=*/[{ return detail::stageWithSortImpl($_op, rewriter); }]>, ]; } #endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +13 −5 Original line number Diff line number Diff line Loading @@ -12,6 +12,7 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" Loading Loading @@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria } def SparseTensor_ConvertOp : SparseTensor_Op<"convert", [Pure]>, [Pure, StageWithSortSparseOpInterface]>, Arguments<(ins AnyTensor:$source)>, Results<(outs AnyTensor:$dest)> { string summary = "Converts between different tensor types"; Loading Loading @@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert", }]; let extraClassDeclaration = [{ // Whether the convert can be done by a single step (either a sort or a foreach), // or it would require a tmp buffer (sort, then foreach). bool directConvertable(); // Whether the convert can be done by a single step or it would require // an extra sort. Inherited from StageWithSortSparseOpInterface. bool needsExtraSort(); }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; Loading Loading @@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure] let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; } def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>, def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure, StageWithSortSparseOpInterface]>, Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>, Results<(outs AnyRankedTensor:$result)> { Loading @@ -357,6 +359,12 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>, ``` }]; let extraClassDeclaration = [{ // Whether the concatenate can be done by a single step or it would require // an extra sort. Inherited from StageWithSortSparseOpInterface. bool needsExtraSort(); }]; let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)"; let hasVerifier = 1; } Loading Loading
mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt +6 −0 Original line number Diff line number Diff line Loading @@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td) mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls) mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs) add_public_tablegen_target(MLIRSparseTensorTypesIncGen) set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td) mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls) mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen) add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +1 −0 Original line number Diff line number Diff line Loading @@ -11,6 +11,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" Loading
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h 0 → 100644 +31 −0 Original line number Diff line number Diff line //===- SparseTensorInterfaces.h - sparse tensor operations //interfaces-------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_ #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_ #include "mlir/IR/OpDefinition.h" namespace mlir { class PatternRewriter; namespace sparse_tensor { class StageWithSortSparseOp; namespace detail { LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op, PatternRewriter &rewriter); } // namespace detail } // namespace sparse_tensor } // namespace mlir /// Include the generated interface declarations. #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc" #endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td 0 → 100644 +45 −0 Original line number Diff line number Diff line //===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES #define SPARSETENSOR_IR_SPARSETENSORINTERFACES include "mlir/IR/OpBase.td" def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> { let description = [{ A stage-with-sort sparse tensor operation is an operation that produces unordered intermediate output. An extra sort is required to obtain the final ordered result. E.g., convert csr -> csc need to be implemented as convert csr -> unordered coo -> sort by column -> csc; and concatenate csr, csc -> csr can be staged into concatenate csr, csr -> unordered coo -> sort by row -> csr. }]; let cppNamespace = "::mlir::sparse_tensor"; let methods = [ InterfaceMethod< /*desc=*/"Return true if the operation needs an extra sort to produce the final result.", /*retTy=*/"bool", /*methodName=*/"needsExtraSort", /*args=*/(ins), /*methodBody=*/"">, InterfaceMethod< /*desc=*/"Stage the operation, return the final result value after staging.", /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"stageWithSort", /*args=*/(ins "::mlir::PatternRewriter &":$rewriter), /*methodBody=*/[{ return detail::stageWithSortImpl($_op, rewriter); }]>, ]; } #endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +13 −5 Original line number Diff line number Diff line Loading @@ -12,6 +12,7 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" Loading Loading @@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria } def SparseTensor_ConvertOp : SparseTensor_Op<"convert", [Pure]>, [Pure, StageWithSortSparseOpInterface]>, Arguments<(ins AnyTensor:$source)>, Results<(outs AnyTensor:$dest)> { string summary = "Converts between different tensor types"; Loading Loading @@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert", }]; let extraClassDeclaration = [{ // Whether the convert can be done by a single step (either a sort or a foreach), // or it would require a tmp buffer (sort, then foreach). bool directConvertable(); // Whether the convert can be done by a single step or it would require // an extra sort. Inherited from StageWithSortSparseOpInterface. bool needsExtraSort(); }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; Loading Loading @@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure] let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; } def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>, def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure, StageWithSortSparseOpInterface]>, Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>, Results<(outs AnyRankedTensor:$result)> { Loading @@ -357,6 +359,12 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>, ``` }]; let extraClassDeclaration = [{ // Whether the concatenate can be done by a single step or it would require // an extra sort. Inherited from StageWithSortSparseOpInterface. bool needsExtraSort(); }]; let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)"; let hasVerifier = 1; } Loading