Unverified Commit 25073147 authored by Chaitanya's avatar Chaitanya Committed by GitHub
Browse files

[OpenMP][mlir] Add DynGroupPrivateClause in omp dialect (#153562)

- The `dyn_groupprivate` clause allows to dynamically allocate
group-private memory in OpenMP parallel regions, specifically for
`target` and `teams` directives.
- This clause enables runtime-sized private memory allocation and
applicable to target and teams ops.

This PR enables dyn_groupprivate clause in openmp mlir dialect and adds
it to Teams and Target ops. Also includes parser, printer and
verification for clause.
parent 8b258206
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -313,6 +313,7 @@ private:
      targetOp.setDependIteratedKindsAttr(nullptr);
      targetOp.getDeviceMutable().clear();
      targetOp.getIfExprMutable().clear();
      targetOp.getDynGroupprivateSizeMutable().clear();

      // TODO: Clear some of these operands rather than rewriting them,
      // depending on whether they are needed by device codegen once support for
+14 −6
Original line number Diff line number Diff line
@@ -761,7 +761,9 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
      targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
      targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
      targetOp.getDynGroupprivateFallbackAttr(),
      targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
@@ -1482,8 +1484,10 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
      targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
      targetOp.getIfExpr(), targetOp.getInReductionVars(),
      targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
      targetOp.getDynGroupprivateFallbackAttr(),
      targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
      preHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
      targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
@@ -1573,7 +1577,9 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
      targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
      targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
      targetOp.getDynGroupprivateFallbackAttr(),
      targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
      isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
      targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
@@ -1654,8 +1660,10 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
      targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(),
      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
      targetOp.getIfExpr(), targetOp.getInReductionVars(),
      targetOp.getDevice(), targetOp.getDynGroupprivateAccessGroupAttr(),
      targetOp.getDynGroupprivateFallbackAttr(),
      targetOp.getDynGroupprivateSize(), targetOp.getHasDeviceAddrVars(),
      postHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
      targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
      targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
      targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+35 −0
Original line number Diff line number Diff line
@@ -1845,4 +1845,39 @@ class OpenMP_UniformClauseSkip<

def OpenMP_UniformClause : OpenMP_UniformClauseSkip<>;

//===----------------------------------------------------------------------===//
// V6.1 `dyn_groupprivate` clause
//===----------------------------------------------------------------------===//

class OpenMP_DynGroupprivateClauseSkip<
    bit traits = false, bit arguments = false, bit assemblyFormat = false,
    bit description = false, bit extraClassDeclaration = false
  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                    extraClassDeclaration> {

  let arguments = (ins
    OptionalAttr<AccessGroupModifierAttr>:$dyn_groupprivate_access_group,
    OptionalAttr<FallbackModifierAttr>:$dyn_groupprivate_fallback,
    Optional<AnyInteger>:$dyn_groupprivate_size
  );

  let description = [{
    The `dyn_groupprivate_access_group` attribute specifies the access group
    modifier for the dynamically allocated group-private memory. The
    `dyn_groupprivate_fallback` attribute specifies the fallback behavior when
    allocation fails. The `dyn_groupprivate_size` operand specifies the size in
    bytes to allocate.
  }];

  let optAssemblyFormat = [{
    `dyn_groupprivate` `(`
      custom<DynGroupprivateClause>($dyn_groupprivate_access_group,
      $dyn_groupprivate_fallback,
      $dyn_groupprivate_size, type($dyn_groupprivate_size))
    `)`
  }];
}

def OpenMP_DynGroupprivateClause : OpenMP_DynGroupprivateClauseSkip<>;

#endif // OPENMP_CLAUSES
+38 −0
Original line number Diff line number Diff line
@@ -337,4 +337,42 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
  let assemblyFormat = "`(` $value `)`";
}

//===----------------------------------------------------------------------===//
// access_group_modifier enum.
//===----------------------------------------------------------------------===//

def AccessGroupCGroup : I32EnumAttrCase<"cgroup", 0>;

def AccessGroupModifier : OpenMP_I32EnumAttr<
    "AccessGroupModifier",
    "access group modifier", [
      AccessGroupCGroup
    ]>;

def AccessGroupModifierAttr : OpenMP_EnumAttr<AccessGroupModifier,
                                            "access_group_modifier"> {
  let assemblyFormat = "`(` $value `)`";
}

//===----------------------------------------------------------------------===//
// fallback_modifier enum.
//===----------------------------------------------------------------------===//

def FallbackAbort : I32EnumAttrCase<"abort", 0>;
def FallbackNull : I32EnumAttrCase<"null", 1>;
def FallbackDefaultMem : I32EnumAttrCase<"default_mem", 2>;

def FallbackModifier : OpenMP_I32EnumAttr<
    "FallbackModifier",
    "fallback modifier", [
      FallbackAbort,
      FallbackNull,
      FallbackDefaultMem
    ]>;

def FallbackModifierAttr : OpenMP_EnumAttr<FallbackModifier,
                                            "fallback_modifier"> {
  let assemblyFormat = "`(` $value `)`";
}

#endif // OPENMP_ENUMS
+5 −3
Original line number Diff line number Diff line
@@ -240,8 +240,9 @@ def TerminatorOp : OpenMP_Op<"terminator", [Terminator, Pure]> {
def TeamsOp : OpenMP_Op<"teams", traits = [
    AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
  ], clauses = [
    OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
    OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
    OpenMP_AllocateClause, OpenMP_DynGroupprivateClause, OpenMP_IfClause,
    OpenMP_NumTeamsClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
    OpenMP_ThreadLimitClause
  ], singleRegion = true> {
  let summary = "teams construct";
  let description = [{
@@ -1579,7 +1580,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
  ], clauses = [
    // TODO: Complete clause list (defaultmap, uses_allocators).
    OpenMP_AllocateClause, OpenMP_BareClause, OpenMP_DependClause,
    OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
    OpenMP_DeviceClause, OpenMP_DynGroupprivateClause,
    OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
    OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
    OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
    OpenMP_PrivateClause, OpenMP_ThreadLimitClause
Loading