Commit 2bcb6208 authored by HazemAbdelhafez's avatar HazemAbdelhafez Committed by Lei Zhang
Browse files

[mlir][spirv] Add TransposeOp

Add Transpose operation to SPIRV dialect.

Differential Revision: https://reviews.llvm.org/D82308
parent 090c108d
......@@ -3141,6 +3141,7 @@ def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>
def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>;
def SPV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>;
def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>;
def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>;
def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>;
......@@ -3265,20 +3266,21 @@ def SPV_OpcodeAttr :
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
......
......@@ -45,6 +45,13 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Matrix]>
];
let arguments = (ins
SPV_AnyMatrix:$matrix,
SPV_Float:$scalar
......@@ -72,4 +79,58 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
// -----
def SPV_TransposeOp : SPV_Op<"Transpose", []> {
let summary = "Transpose a matrix.";
let description = [{
Result Type must be an OpTypeMatrix.
Matrix must be an object of type OpTypeMatrix. The number of columns and
the column size of Matrix must be the reverse of those in Result Type.
The types of the scalar components in Matrix and Result Type must be the
same.
Matrix must have of type of OpTypeMatrix.
<!-- End of AutoGen section -->
```
transpose-op ::= ssa-id `=` `spv.Transpose` ssa-use `:` matrix-type `->`
matrix-type
```mlir
#### Example:
```
%0 = spv.Transpose %matrix: !spv.matrix<2 x vector<3xf32>> ->
!spv.matrix<3 x vector<2xf32>>
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Matrix]>
];
let arguments = (ins
SPV_AnyMatrix:$matrix
);
let results = (outs
SPV_AnyMatrix:$result
);
let assemblyFormat = [{
operands attr-dict `:` type($matrix) `->` type($result)
}];
let verifier = [{ return verifyTranspose(*this); }];
}
// -----
#endif // SPIRV_MATRIX_OPS
\ No newline at end of file
......@@ -2815,6 +2815,36 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.Transpose
//===----------------------------------------------------------------------===//
static LogicalResult verifyTranspose(spirv::TransposeOp op) {
auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
// Verify that the input and output matrices have correct shapes.
if (auto inputMatrixColumns =
inputMatrix.getElementType().dyn_cast<VectorType>()) {
if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements())
return op.emitError("input matrix rows count must be equal to "
"output matrix columns count");
if (auto resultMatrixColumns =
resultMatrix.getElementType().dyn_cast<VectorType>()) {
if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements())
return op.emitError("input matrix columns count must be equal "
"to output matrix rows count");
// Verify that the input and output matrices have the same component type
if (inputMatrixColumns.getElementType() !=
resultMatrixColumns.getElementType())
return op.emitError("input and output matrices must have the "
"same component type");
}
}
return success();
}
namespace mlir {
namespace spirv {
......
......@@ -22,6 +22,13 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>>
}
// CHECK-LABEL: @matrix_transpose_1
spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
}
}
// -----
......
......@@ -2,11 +2,25 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @matrix_times_scalar
spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
spv.func @matrix_times_scalar(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_transpose_1
spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
}
// CHECK-LABEL: @matrix_transpose_2
spv.func @matrix_transpose_2(%arg0 : !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None" {
// CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
}
}
// -----
......@@ -37,5 +51,26 @@ func @input_output_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 :
%result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>>
}
// -----
func @transpose_op_shape_mismatch_1(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<3 x vector<3xf32>>
spv.Return
}
// -----
func @transpose_op_shape_mismatch_2(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<2 x vector<4xf32>>
spv.Return
}
// -----
func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () {
// expected-error @+1 {{input and output matrices must have the same component type}}
%result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
spv.Return
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment