Unverified Commit fad06a41 authored by Sergio Afonso's avatar Sergio Afonso Committed by GitHub
Browse files

[MLIR][OpenMP][OMPIRBuilder] Improve shared memory checks (#161864)

This patch refines checks to decide whether to use device shared memory
or regular stack allocations. In particular, it adds support for
parallel regions residing on standalone target device functions.

The changes are:
- Shared memory is introduced for `omp.target` implicit allocations,
such as those related to privatization and mapping, as long as they are
shared across threads in a nested parallel region.
- Standalone target device functions are interpreted as being part of a
Generic kernel, since the fact that they are present in the module after
filtering means they must be reachable from a target region.
- Prevent allocations whose only shared uses inside of an `omp.parallel`
region are as part of a `private` clause from being moved to device
shared memory.
parent d6061d29
Loading
Loading
Loading
Loading
+12 −13
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@
!===----------------------------------------------------------------------===!

! This tests check that target code nested inside a target data region which
! has only use_device_ptr mapping corectly generates code on the device pass.
! has only use_device_ptr mapping correctly generates code on the device pass.

!REQUIRES: amdgpu-registered-target
!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-version=50 -fopenmp-is-target-device %s -o - | FileCheck %s
@@ -25,22 +25,21 @@ end program

! CHECK:         define weak_odr protected amdgpu_kernel void @__omp_offloading{{.*}}main_
! CHECK-NEXT:       entry:
! CHECK-NEXT:         %[[VAL_3:.*]] = alloca ptr, align 8, addrspace(5)
! CHECK-NEXT:         %[[ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[VAL_3]] to ptr
! CHECK-NEXT:         store ptr %[[VAL_4:.*]], ptr %[[ASCAST]], align 8
! CHECK-NEXT:         %[[VAL_5:.*]] = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @__omp_offloading_{{.*}}_kernel_environment to ptr), ptr %[[VAL_6:.*]])
! CHECK-NEXT:         %[[VAL_7:.*]] = icmp eq i32 %[[VAL_5]], -1
! CHECK-NEXT:         br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]]
! CHECK:            user_code.entry:                                  ; preds = %[[VAL_10:.*]]
! CHECK-NEXT:         %[[VAL_11:.*]] = load ptr, ptr %[[ASCAST]], align 8
! CHECK-NEXT:         %[[VAL_0:.*]] = call i32 @__kmpc_target_init(ptr addrspacecast (ptr addrspace(1) @__omp_offloading_{{.*}}_kernel_environment to ptr), ptr %[[VAL_6:.*]])
! CHECK-NEXT:         %[[VAL_1:.*]] = icmp eq i32 %[[VAL_0]], -1
! CHECK-NEXT:         br i1 %[[VAL_1]], label %[[USER_ENTRY:.*]], label %[[EXIT:.*]]
! CHECK:            [[USER_ENTRY]]:                                  ; preds = %entry
! CHECK-NEXT:         %[[VAL_2:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8) 
! CHECK-NEXT:         store ptr %[[VAL_3:.*]], ptr %[[VAL_2]], align 8
! CHECK-NEXT:         %[[VAL_4:.*]] = load ptr, ptr %[[VAL_2]], align 8
! CHECK-NEXT:         br label %[[AFTER_ALLOC:.*]]

! CHECK:            [[AFTER_ALLOC]]:
! CHECK-NEXT:         br label %[[VAL_12:.*]]
! CHECK-NEXT:         br label %[[VAL_5:.*]]

! CHECK:            [[VAL_12]]:
! CHECK:            [[VAL_5]]:
! CHECK-NEXT:         br label %[[TARGET_REG_ENTRY:.*]]

! CHECK:            [[TARGET_REG_ENTRY]]:                                       ; preds = %[[VAL_12]]
! CHECK-NEXT:         call void @{{.*}}foo{{.*}}(ptr %[[VAL_11]])
! CHECK:            [[TARGET_REG_ENTRY]]:                                       ; preds = %[[VAL_5]]
! CHECK-NEXT:         call void @{{.*}}foo{{.*}}(ptr %[[VAL_4]])
! CHECK-NEXT:         br label
+6 −8
Original line number Diff line number Diff line
@@ -14,16 +14,14 @@
! target code in the same function.

! CHECK: define weak_odr protected amdgpu_kernel void @{{.*}}(ptr %[[ARG1:.*]], ptr %[[ARG2:.*]], ptr %{{.*}}) #{{[0-9]+}} {
! CHECK:  %[[ALLOCA_X:.*]] = alloca ptr, align 8, addrspace(5)
! CHECK:  %[[ASCAST_X:.*]] = addrspacecast ptr addrspace(5) %[[ALLOCA_X]] to ptr
! CHECK:  store ptr %[[ARG1]], ptr %[[ASCAST_X]], align 8
! CHECK:  %[[ALLOC_N:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
! CHECK:  store ptr %[[ARG2]], ptr %[[ALLOC_N]], align 8

! CHECK:  %[[ALLOCA_N:.*]] = alloca ptr, align 8, addrspace(5)
! CHECK:  %[[ASCAST_N:.*]] = addrspacecast ptr addrspace(5) %[[ALLOCA_N]] to ptr
! CHECK:  store ptr %[[ARG2]], ptr %[[ASCAST_N]], align 8
! CHECK:  %[[ALLOC_X:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8)
! CHECK:  store ptr %[[ARG1]], ptr %[[ALLOC_X]], align 8

! CHECK:  %[[LOAD_X:.*]] = load ptr, ptr %[[ASCAST_X]], align 8
! CHECK:  call void @bar_(ptr %[[LOAD_X]], ptr %[[ASCAST_N]])
! CHECK:  %[[LOAD_X:.*]] = load ptr, ptr %[[ALLOC_X]], align 8
! CHECK:  call void @bar_(ptr %[[LOAD_X]], ptr %[[ALLOC_N]])

module test
  implicit none
+1 −1
Original line number Diff line number Diff line
@@ -3609,7 +3609,7 @@ public:

  using TargetGenArgAccessorsCallbackTy = function_ref<InsertPointOrErrorTy(
      Argument &Arg, Value *Input, Value *&RetVal, InsertPointTy AllocaIP,
      InsertPointTy CodeGenIP)>;
      InsertPointTy CodeGenIP, ArrayRef<InsertPointTy> DeallocIPs)>;

  /// Generator for '#omp target'
  ///
+3 −2
Original line number Diff line number Diff line
@@ -8903,8 +8903,9 @@ static Expected<Function *> createOutlinedFunction(
    Argument &Arg = std::get<1>(InArg);
    Value *InputCopy = nullptr;

    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
    llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = ArgAccessorFuncCB(
        Arg, Input, InputCopy, AllocaIP, Builder.saveIP(),
        OpenMPIRBuilder::InsertPointTy(ExitBB, ExitBB->begin()));
    if (!AfterIP)
      return AfterIP.takeError();
    Builder.restoreIP(*AfterIP);
+22 −16
Original line number Diff line number Diff line
@@ -6452,7 +6452,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
  auto SimpleArgAccessorCB =
      [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
          llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
        IRBuilderBase::InsertPointGuard guard(Builder);
        Builder.SetCurrentDebugLocation(llvm::DebugLoc());
        if (!OMPBuilder.Config.isTargetDevice()) {
@@ -6618,7 +6619,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
  auto SimpleArgAccessorCB =
      [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
          llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
        IRBuilderBase::InsertPointGuard guard(Builder);
        Builder.SetCurrentDebugLocation(llvm::DebugLoc());
        if (!OMPBuilder.Config.isTargetDevice()) {
@@ -6820,9 +6822,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
    return Builder.saveIP();
  };

  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
                                 OpenMPIRBuilder::InsertPointTy,
                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
  auto SimpleArgAccessorCB =
      [&](Argument &, Value *, Value *&, OpenMPIRBuilder::InsertPointTy,
          OpenMPIRBuilder::InsertPointTy CodeGenIP,
          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy>) {
        Builder.restoreIP(CodeGenIP);
        return Builder.saveIP();
      };
@@ -6920,9 +6923,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
  Function *OutlinedFn = nullptr;
  SmallVector<Value *> CapturedArgs;

  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
                                 OpenMPIRBuilder::InsertPointTy,
                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
  auto SimpleArgAccessorCB =
      [&](Argument &, Value *, Value *&, OpenMPIRBuilder::InsertPointTy,
          OpenMPIRBuilder::InsertPointTy CodeGenIP,
          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy>) {
        Builder.restoreIP(CodeGenIP);
        return Builder.saveIP();
      };
@@ -7019,7 +7023,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
  auto SimpleArgAccessorCB =
      [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
          llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
        IRBuilderBase::InsertPointGuard guard(Builder);
        Builder.SetCurrentDebugLocation(llvm::DebugLoc());
        if (!OMPBuilder.Config.isTargetDevice()) {
@@ -7202,7 +7207,8 @@ TEST_F(OpenMPIRBuilderTest, DebugRecordLoc) {
  auto SimpleArgAccessorCB =
      [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
          llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP,
          llvm::ArrayRef<llvm::OpenMPIRBuilder::InsertPointTy> DeallocIPs) {
        IRBuilderBase::InsertPointGuard guard(Builder);
        Builder.SetCurrentDebugLocation(llvm::DebugLoc());
        if (!OMPBuilder.Config.isTargetDevice()) {
Loading