Commit 731b140a authored by Denis Khalikov's avatar Denis Khalikov Committed by Lei Zhang
Browse files

[mlir][spirv] Add GroupNonUniform arithmetic operations.

Add GroupNonUniform arithmetic operations: FAdd, FMul, IMul.
Unify parser, printer, verifier for GroupNonUniform arithmetic
operations.

Differential Revision: https://reviews.llvm.org/D73491
parent cb74d2e1
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -3158,6 +3158,9 @@ def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 33
def SPV_OC_OpGroupNonUniformElect      : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
def SPV_OC_OpGroupNonUniformBallot     : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
def SPV_OC_OpGroupNonUniformIAdd       : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>;
def SPV_OC_OpGroupNonUniformFAdd       : I32EnumAttrCase<"OpGroupNonUniformFAdd", 350>;
def SPV_OC_OpGroupNonUniformIMul       : I32EnumAttrCase<"OpGroupNonUniformIMul", 351>;
def SPV_OC_OpGroupNonUniformFMul       : I32EnumAttrCase<"OpGroupNonUniformFMul", 352>;
def SPV_OC_OpSubgroupBallotKHR         : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;

def SPV_OpcodeAttr :
@@ -3205,7 +3208,9 @@ def SPV_OpcodeAttr :
      SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
      SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
      SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpSubgroupBallotKHR
      SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
      SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
      SPV_OC_OpSubgroupBallotKHR
    ]>;

// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
+179 −16
Original line number Diff line number Diff line
@@ -14,6 +14,35 @@
#ifndef SPIRV_NON_UNIFORM_OPS
#define SPIRV_NON_UNIFORM_OPS

class SPV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
      list<OpTrait> traits = []> : SPV_Op<mnemonic, traits> {

  let availability = [
    MinVersion<SPV_V_1_3>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[SPV_C_GroupNonUniformArithmetic,
                SPV_C_GroupNonUniformClustered,
                SPV_C_GroupNonUniformPartitionedNV]>
  ];

  let arguments = (ins
    SPV_ScopeAttr:$execution_scope,
    SPV_GroupOperationAttr:$group_operation,
    SPV_ScalarOrVectorOf<type>:$value,
    SPV_Optional<SPV_Integer>:$cluster_size
  );

  let results = (outs
    SPV_ScalarOrVectorOf<type>:$result
  );

  let parser = [{ return parseGroupNonUniformArithmeticOp(parser, result); }];
  let printer = [{ printGroupNonUniformArithmeticOp(getOperation(), p); }];
  let verifier = [{ return ::verifyGroupNonUniformArithmeticOp(getOperation()); }];

}

// -----

def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
@@ -120,7 +149,110 @@ def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> {

// -----

def SPV_GroupNonUniformIAddOp : SPV_Op<"GroupNonUniformIAdd", []> {
def SPV_GroupNonUniformFAddOp :
    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformFAdd", SPV_Float, []> {
  let summary = [{
    A floating point add group operation of all Value operands contributed
    by active invocations in the group.
  }];

  let description = [{
    Result Type  must be a scalar or vector of floating-point type.

    Execution must be Workgroup or Subgroup Scope.

    The identity I for Operation is 0. If Operation is ClusteredReduce,
    ClusterSize must be specified.

     The type of Value must be the same as Result Type.  The method used to
    perform the group operation on the contributed Value(s) from active
    invocations is implementation defined.

    ClusterSize is the size of cluster to use. ClusterSize must be a scalar
    of integer type, whose Signedness operand is 0. ClusterSize must come
    from a constant instruction. ClusterSize must be at least 1, and must be
    a power of 2. If ClusterSize is greater than the declared SubGroupSize,
    executing this instruction results in undefined behavior.

    ### Custom assembly form

    ```
    scope ::= `"Workgroup"` | `"Subgroup"`
    operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
    float-scalar-vector-type ::= float-type |
                                 `vector<` integer-literal `x` float-type `>`
    non-uniform-fadd-op ::= ssa-id `=` `spv.GroupNonUniformFAdd` scope operation
                            ssa-use ( `cluster_size` `(` ssa_use `)` )?
                            `:` float-scalar-vector-type
    ```

    For example:

    ```
    %four = spv.constant 4 : i32
    %scalar = ... : f32
    %vector = ... : vector<4xf32>
    %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %scalar : f32
    %1 = spv.GroupNonUniformFAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xf32>
    ```
  }];
}

// -----

def SPV_GroupNonUniformFMulOp :
    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformFMul", SPV_Float, []> {
  let summary = [{
    A floating point multiply group operation of all Value operands
    contributed by active invocations in the group.
  }];

  let description = [{
    Result Type  must be a scalar or vector of floating-point type.

    Execution must be Workgroup or Subgroup Scope.

    The identity I for Operation is 1. If Operation is ClusteredReduce,
    ClusterSize must be specified.

     The type of Value must be the same as Result Type.  The method used to
    perform the group operation on the contributed Value(s) from active
    invocations is implementation defined.

    ClusterSize is the size of cluster to use. ClusterSize must be a scalar
    of integer type, whose Signedness operand is 0. ClusterSize must come
    from a constant instruction. ClusterSize must be at least 1, and must be
    a power of 2. If ClusterSize is greater than the declared SubGroupSize,
    executing this instruction results in undefined behavior.

    ### Custom assembly form

    ```
    scope ::= `"Workgroup"` | `"Subgroup"`
    operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
    float-scalar-vector-type ::= float-type |
                                 `vector<` integer-literal `x` float-type `>`
    non-uniform-fmul-op ::= ssa-id `=` `spv.GroupNonUniformFMul` scope operation
                            ssa-use ( `cluster_size` `(` ssa_use `)` )?
                            `:` float-scalar-vector-type
    ```

    For example:

    ```
    %four = spv.constant 4 : i32
    %scalar = ... : f32
    %vector = ... : vector<4xf32>
    %0 = spv.GroupNonUniformFMul "Workgroup" "Reduce" %scalar : f32
    %1 = spv.GroupNonUniformFMul "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xf32>
    ```
  }];
}

// -----

def SPV_GroupNonUniformIAddOp :
    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformIAdd", SPV_Integer, []> {
  let summary = [{
    An integer add group operation of all Value operands contributed active
    by invocations in the group.
@@ -164,24 +296,55 @@ def SPV_GroupNonUniformIAddOp : SPV_Op<"GroupNonUniformIAdd", []> {
    %1 = spv.GroupNonUniformIAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
    ```
  }];
}

  let availability = [
    MinVersion<SPV_V_1_3>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformClustered, SPV_C_GroupNonUniformPartitionedNV]>
  ];
// -----

  let arguments = (ins
    SPV_ScopeAttr:$execution_scope,
    SPV_GroupOperationAttr:$group_operation,
    SPV_ScalarOrVectorOf<SPV_Integer>:$value,
    SPV_Optional<SPV_Integer>:$cluster_size
  );
def SPV_GroupNonUniformIMulOp :
    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformIMul", SPV_Integer, []> {
  let summary = [{
    An integer multiply group operation of all Value operands contributed by
    active invocations in the group.
  }];

  let results = (outs
    SPV_ScalarOrVectorOf<SPV_Integer>:$result
  );
  let description = [{
    Result Type  must be a scalar or vector of integer type.

    Execution must be Workgroup or Subgroup Scope.

    The identity I for Operation is 1. If Operation is ClusteredReduce,
    ClusterSize must be specified.

     The type of Value must be the same as Result Type.

    ClusterSize is the size of cluster to use. ClusterSize must be a scalar
    of integer type, whose Signedness operand is 0. ClusterSize must come
    from a constant instruction. ClusterSize must be at least 1, and must be
    a power of 2. If ClusterSize is greater than the declared SubGroupSize,
    executing this instruction results in undefined behavior.

    ### Custom assembly form

    ```
    scope ::= `"Workgroup"` | `"Subgroup"`
    operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
    integer-scalar-vector-type ::= integer-type |
                                 `vector<` integer-literal `x` integer-type `>`
    non-uniform-imul-op ::= ssa-id `=` `spv.GroupNonUniformIMul` scope operation
                            ssa-use ( `cluster_size` `(` ssa_use `)` )?
                            `:` integer-scalar-vector-type
    ```

    For example:

    ```
    %four = spv.constant 4 : i32
    %scalar = ... : i32
    %vector = ... : vector<4xi32>
    %0 = spv.GroupNonUniformIMul "Workgroup" "Reduce" %scalar : i32
    %1 = spv.GroupNonUniformIMul "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
    ```
  }];
}

// -----
+82 −76
Original line number Diff line number Diff line
@@ -588,6 +588,88 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
  return success();
}

static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
                                                    OperationState &state) {
  spirv::Scope executionScope;
  spirv::GroupOperation groupOperation;
  OpAsmParser::OperandType valueInfo;
  if (parseEnumAttribute(executionScope, parser, state,
                         kExecutionScopeAttrName) ||
      parseEnumAttribute(groupOperation, parser, state,
                         kGroupOperationAttrName) ||
      parser.parseOperand(valueInfo))
    return failure();

  Optional<OpAsmParser::OperandType> clusterSizeInfo;
  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
    clusterSizeInfo = OpAsmParser::OperandType();
    if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
        parser.parseRParen())
      return failure();
  }

  Type resultType;
  if (parser.parseColonType(resultType))
    return failure();

  if (parser.resolveOperand(valueInfo, resultType, state.operands))
    return failure();

  if (clusterSizeInfo.hasValue()) {
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
      return failure();
  }

  return parser.addTypeToList(resultType, state.types);
}

static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                             OpAsmPrinter &printer) {
  printer << groupOp->getName() << " \""
          << stringifyScope(static_cast<spirv::Scope>(
                 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
                     .getInt()))
          << "\" \""
          << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
                 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
                     .getInt()))
          << "\" " << groupOp->getOperand(0);

  if (groupOp->getNumOperands() > 1)
    printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
  printer << " : " << groupOp->getResult(0).getType();
}

static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
  spirv::Scope scope = static_cast<spirv::Scope>(
      groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
    return groupOp->emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");

  spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
      groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
  if (operation == spirv::GroupOperation::ClusteredReduce &&
      groupOp->getNumOperands() == 1)
    return groupOp->emitOpError("cluster size operand must be provided for "
                                "'ClusteredReduce' group operation");
  if (groupOp->getNumOperands() > 1) {
    Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
    int32_t clusterSize = 0;

    // TODO(antiagainst): support specialization constant here.
    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
      return groupOp->emitOpError(
          "cluster size operand must come from a constant op");

    if (!llvm::isPowerOf2_32(clusterSize))
      return groupOp->emitOpError(
          "cluster size operand must be a power of two");
  }
  return success();
}

// Parses an op that has no inputs and no outputs.
static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
  if (parser.parseOptionalAttrDict(state.attributes))
@@ -1939,83 +2021,7 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
  return success();
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformIAddOp
//===----------------------------------------------------------------------===//

static ParseResult parseGroupNonUniformIAddOp(OpAsmParser &parser,
                                              OperationState &state) {
  spirv::Scope executionScope;
  spirv::GroupOperation groupOperation;
  OpAsmParser::OperandType valueInfo;
  if (parseEnumAttribute(executionScope, parser, state,
                         kExecutionScopeAttrName) ||
      parseEnumAttribute(groupOperation, parser, state,
                         kGroupOperationAttrName) ||
      parser.parseOperand(valueInfo))
    return failure();

  Optional<OpAsmParser::OperandType> clusterSizeInfo;
  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
    clusterSizeInfo = OpAsmParser::OperandType();
    if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
        parser.parseRParen())
      return failure();
  }

  Type resultType;
  if (parser.parseColonType(resultType))
    return failure();

  if (parser.resolveOperand(valueInfo, resultType, state.operands))
    return failure();

  if (clusterSizeInfo.hasValue()) {
    Type i32Type = parser.getBuilder().getIntegerType(32);
    if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
      return failure();
  }

  return parser.addTypeToList(resultType, state.types);
}

static void print(spirv::GroupNonUniformIAddOp groupOp, OpAsmPrinter &printer) {
  printer << spirv::GroupNonUniformIAddOp::getOperationName() << " \""
          << stringifyScope(groupOp.execution_scope()) << "\" \""
          << stringifyGroupOperation(groupOp.group_operation()) << "\" "
          << groupOp.value();
  if (!groupOp.cluster_size().empty())
    printer << " " << kClusterSize << '(' << groupOp.cluster_size() << ')';
  printer << " : " << groupOp.getType();
}

static LogicalResult verify(spirv::GroupNonUniformIAddOp groupOp) {
  spirv::Scope scope = groupOp.execution_scope();
  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
    return groupOp.emitOpError(
        "execution scope must be 'Workgroup' or 'Subgroup'");

  spirv::GroupOperation operation = groupOp.group_operation();
  if (operation == spirv::GroupOperation::ClusteredReduce &&
      groupOp.cluster_size().empty())
    return groupOp.emitOpError("cluster size operand must be provided for "
                               "'ClusteredReduce' group operation");

  if (!groupOp.cluster_size().empty()) {
    Operation *sizeOp = (*groupOp.cluster_size().begin()).getDefiningOp();
    int32_t clusterSize = 0;

    // TODO(antiagainst): support specialization constant here.
    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
      return groupOp.emitOpError(
          "cluster size operand must come from a constant op");

    if (!llvm::isPowerOf2_32(clusterSize))
      return groupOp.emitOpError("cluster size operand must be a power of two");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spv.IAdd
+22 −0
Original line number Diff line number Diff line
@@ -15,6 +15,20 @@ spv.module "Logical" "GLSL450" {
    spv.ReturnValue %0: i1
  }

  // CHECK-LABEL: @group_non_uniform_fadd_reduce
  func @group_non_uniform_fadd_reduce(%val: f32) -> f32 {
    // CHECK: %{{.+}} = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %{{.+}} : f32
    %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %val : f32
    spv.ReturnValue %0: f32
  }

  // CHECK-LABEL: @group_non_uniform_fmul_reduce
  func @group_non_uniform_fmul_reduce(%val: f32) -> f32 {
    // CHECK: %{{.+}} = spv.GroupNonUniformFMul "Workgroup" "Reduce" %{{.+}} : f32
    %0 = spv.GroupNonUniformFMul "Workgroup" "Reduce" %val : f32
    spv.ReturnValue %0: f32
  }

  // CHECK-LABEL: @group_non_uniform_iadd_reduce
  func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
    // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32
@@ -29,4 +43,12 @@ spv.module "Logical" "GLSL450" {
    %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
    spv.ReturnValue %0: vector<2xi32>
  }

  // CHECK-LABEL: @group_non_uniform_imul_reduce
  func @group_non_uniform_imul_reduce(%val: i32) -> i32 {
    // CHECK: %{{.+}} = spv.GroupNonUniformIMul "Workgroup" "Reduce" %{{.+}} : i32
    %0 = spv.GroupNonUniformIMul "Workgroup" "Reduce" %val : i32
    spv.ReturnValue %0: i32
  }

}
+61 −0
Original line number Diff line number Diff line
@@ -41,6 +41,46 @@ func @group_non_uniform_elect() -> i1 {

// -----

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformFAdd
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @group_non_uniform_fadd_reduce
func @group_non_uniform_fadd_reduce(%val: f32) -> f32 {
  // CHECK: %{{.+}} = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %{{.+}} : f32
  %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %val : f32
  return %0: f32
}

// CHECK-LABEL: @group_non_uniform_fadd_clustered_reduce
func @group_non_uniform_fadd_clustered_reduce(%val: vector<2xf32>) -> vector<2xf32> {
  %four = spv.constant 4 : i32
  // CHECK: %{{.+}} = spv.GroupNonUniformFAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xf32>
  %0 = spv.GroupNonUniformFAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xf32>
  return %0: vector<2xf32>
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformFMul
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @group_non_uniform_fmul_reduce
func @group_non_uniform_fmul_reduce(%val: f32) -> f32 {
  // CHECK: %{{.+}} = spv.GroupNonUniformFMul "Workgroup" "Reduce" %{{.+}} : f32
  %0 = spv.GroupNonUniformFMul "Workgroup" "Reduce" %val : f32
  return %0: f32
}

// CHECK-LABEL: @group_non_uniform_fmul_clustered_reduce
func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vector<2xf32> {
  %four = spv.constant 4 : i32
  // CHECK: %{{.+}} = spv.GroupNonUniformFMul "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xf32>
  %0 = spv.GroupNonUniformFMul "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xf32>
  return %0: vector<2xf32>
}

// -----

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformIAdd
//===----------------------------------------------------------------------===//
@@ -92,3 +132,24 @@ func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi
  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%five) : vector<2xi32>
  return %0: vector<2xi32>
}

// -----

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformIMul
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @group_non_uniform_imul_reduce
func @group_non_uniform_imul_reduce(%val: i32) -> i32 {
  // CHECK: %{{.+}} = spv.GroupNonUniformIMul "Workgroup" "Reduce" %{{.+}} : i32
  %0 = spv.GroupNonUniformIMul "Workgroup" "Reduce" %val : i32
  return %0: i32
}

// CHECK-LABEL: @group_non_uniform_imul_clustered_reduce
func @group_non_uniform_imul_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
  %four = spv.constant 4 : i32
  // CHECK: %{{.+}} = spv.GroupNonUniformIMul "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32>
  %0 = spv.GroupNonUniformIMul "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
  return %0: vector<2xi32>
}