Unverified Commit d463a276 authored by Sergio Afonso's avatar Sergio Afonso Committed by GitHub
Browse files

[OpenMP][OMPIRBuilder] Support parallel in Generic kernels (#150926)

This patch introduces codegen logic to produce the wrapper function
argument of the `__kmpc_parallel_60` DeviceRTL function, which is needed
to handle arguments passed through device shared memory in Generic mode.
parent 82f25499
Loading
Loading
Loading
Loading
+94 −6
Original line number Diff line number Diff line
@@ -1511,6 +1511,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
  return Error::success();
}

/// Build the `<outlined>.wrapper` function used when a parallel region runs in
/// Generic mode.
///
/// The wrapper has the signature `void(i16, i32)`. It spills its second
/// parameter and a constant zero into two i32 stack slots whose addresses are
/// forwarded as the first two arguments of \p OutlinedFn. When \p OutlinedFn
/// takes a third argument, the wrapper additionally calls
/// `__kmpc_get_shared_variables` and loads the first entry of the buffer it
/// reports, forwarding that pointer as the argument structure.
///
/// \param OMPIRBuilder Provides the module, the IR builder and the runtime
///                     function declarations.
/// \param OutlinedFn   Outlined parallel body; must take 2 or 3 arguments.
/// \returns The newly created wrapper function (internal linkage).
static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
                                             Function &OutlinedFn) {
  const size_t OutlinedArgCount = OutlinedFn.arg_size();
  assert((OutlinedArgCount == 2 || OutlinedArgCount == 3) &&
         "expected a 2-3 argument parallel outlined function");
  const bool HasStructArg = OutlinedArgCount == 3;

  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  // Put the builder back where the caller left it once we are done here.
  IRBuilder<>::InsertPointGuard RestoreIP(Builder);

  auto *Int32Ty = Builder.getInt32Ty();
  auto *GenericPtrTy = Builder.getPtrTy(/*AddrSpace=*/0);

  // Emits a stack slot of type SlotTy and yields its address as a pointer in
  // address space 0; the cast, when one is needed, is named "<slot>.ascast".
  auto EmitStackSlot = [&Builder, GenericPtrTy](Type *SlotTy,
                                                const Twine &Name) -> Value * {
    AllocaInst *Slot =
        Builder.CreateAlloca(SlotTy, /*ArraySize=*/nullptr, Name);
    return Builder.CreatePointerBitCastOrAddrSpaceCast(
        Slot, GenericPtrTy, Slot->getName() + ".ascast");
  };

  // void <outlined>.wrapper(i16 noundef zeroext, i32 noundef)
  FunctionType *WrapperTy = FunctionType::get(
      Builder.getVoidTy(), {Builder.getInt16Ty(), Int32Ty}, /*isVarArg=*/false);
  Function *WrapperFn =
      Function::Create(WrapperTy, GlobalValue::InternalLinkage,
                       OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
  WrapperFn->addParamAttr(0, Attribute::NoUndef);
  WrapperFn->addParamAttr(0, Attribute::ZExt);
  WrapperFn->addParamAttr(1, Attribute::NoUndef);

  Builder.SetInsertPoint(
      BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn));

  // All stack slots are emitted up front, ahead of any store or call.
  Value *AddrSlot = EmitStackSlot(Int32Ty, "addr");
  Value *ZeroSlot = EmitStackSlot(Int32Ty, "zero");
  Value *SharedArgsSlot =
      HasStructArg ? EmitStackSlot(GenericPtrTy, "global_args") : nullptr;

  // Fill in the two i32 slots handed to the outlined function by address.
  Builder.CreateStore(WrapperFn->getArg(1), AddrSlot);
  Builder.CreateStore(Builder.getInt32(0), ZeroSlot);

  SmallVector<Value *, 3> CallArgs{AddrSlot, ZeroSlot};
  if (HasStructArg) {
    // Ask the runtime for the shared variables buffer, then forward the
    // pointer stored in its first entry as the argument structure.
    Builder.CreateCall(
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
            llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
        {SharedArgsSlot});
    Value *SharedArgs = Builder.CreateLoad(GenericPtrTy, SharedArgsSlot);
    Value *FirstEntry = Builder.CreateInBoundsGEP(GenericPtrTy, SharedArgs,
                                                  {Builder.getInt64(0)});
    CallArgs.push_back(
        Builder.CreateLoad(GenericPtrTy, FirstEntry, "structArg"));
  }

  // Hand over to the outlined function holding the parallel body.
  Builder.CreateCall(&OutlinedFn, CallArgs);
  Builder.CreateRetVoid();

  return WrapperFn;
}

// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the device.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1520,6 +1600,10 @@ static void targetParallelCallback(
    BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
    Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
    Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  // Add some known attributes.
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1528,17 +1612,12 @@ static void targetParallelCallback(
  OutlinedFn.addParamAttr(1, Attribute::NoUndef);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  assert(CI && "Expected call instruction to outlined function");
  CI->getParent()->setName("omp_parallel");

  Builder.SetInsertPoint(CI);
  Type *PtrTy = OMPIRBuilder->VoidPtr;
  Value *NullPtrValue = Constant::getNullValue(PtrTy);

  // Add alloca for kernel args
  OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1564,6 +1643,15 @@ static void targetParallelCallback(
      IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                  : Builder.getInt32(1);

  // If this is not a Generic kernel, we can skip generating the wrapper.
  std::optional<omp::OMPTgtExecModeFlags> ExecMode =
      getTargetKernelExecMode(*OuterFn);
  Value *WrapperFn;
  if (ExecMode && (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC))
    WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
  else
    WrapperFn = Constant::getNullValue(PtrTy);

  // Build kmpc_parallel_60 call
  Value *Parallel60CallArgs[] = {
      /* identifier*/ Ident,
@@ -1572,7 +1660,7 @@ static void targetParallelCallback(
      /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
      /* Proc bind */ Builder.getInt32(-1),
      /* outlined function */ &OutlinedFn,
      /* wrapper function */ NullPtrValue,
      /* wrapper function */ WrapperFn,
      /* arguments of the outlined funciton*/ Args,
      /* number of arguments */ Builder.getInt64(NumCapturedVars),
      /* strict for number of threads */ Builder.getInt32(0)};
+22 −3
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
// CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
// CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
// CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1, i32 0)
// CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1, i32 0)
// CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
// CHECK:         call void @__kmpc_target_deinit()

@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (
// CHECK-SAME:  ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
// CHECK-SAME:  i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
// CHECK-SAME:  i32 -1,  ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1, i32 0)
// CHECK-SAME:  i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1, i32 0)

// One of the arguments of  kmpc_parallel_60 function is responsible for handling if clause
// of omp parallel construct for target region. If this  argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (
// CHECK-SAME:  ptr addrspace(1) {{.*}} to ptr),
// CHECK-SAME:  i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
// CHECK-SAME:  i32 -1,  ptr {{.*}}, ptr null, ptr {{.*}}, i64 1, i32 0)
// CHECK-SAME:  i32 -1,  ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1, i32 0)

// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])

// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
// CHECK-NOT: define
// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})