Commit 8262cd8a authored by Christian Sigg's avatar Christian Sigg
Browse files

[mlir] Set CUDA/ROCm context before creating resources.

The current context is thread-local state, and in preparation for GPU async execution (on multiple threads) we need to set the context before calling APIs that create resources.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D94495
parent 5f9707b7
Loading
Loading
Loading
Loading
+24 −4
Original line number Diff line number Diff line
@@ -32,17 +32,33 @@
    llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";        \
  }(expr)

// Static initialization of CUDA context for device ordinal 0.
static auto InitializeCtx = [] {
// Static reference to CUDA primary context for device ordinal 0.
static CUcontext Context = [] {
  CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
  CUdevice device;
  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
  CUcontext context;
  CUDA_REPORT_IF_ERROR(cuCtxCreate(&context, /*flags=*/0, device));
  return 0;
  CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device));
  return context;
}();

// Makes the primary `Context` current on the calling thread for the lifetime
// of the instance and restores the previously current context on destruction.
// The current context is per-thread driver state, so the wrappers that create
// resources instantiate one of these before calling into the driver.
class ScopedContext {
public:
  ScopedContext() {
    // Remember whatever was current so the destructor can put it back.
    CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous));
    CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
  }

  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); }

  // Not copyable (and therefore not movable): a duplicate would restore
  // `previous` a second time, clobbering whatever context had been made
  // current in between.
  ScopedContext(const ScopedContext &) = delete;
  ScopedContext &operator=(const ScopedContext &) = delete;

private:
  // Context that was current before construction. Starts out null so the
  // destructor never hands an indeterminate handle to the driver when
  // cuCtxGetCurrent fails (the report macro logs the error and continues).
  CUcontext previous = nullptr;
};

extern "C" CUmodule mgpuModuleLoad(void *data) {
  ScopedContext scopedContext;
  CUmodule module = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
  return module;
@@ -66,12 +82,14 @@ extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
                                 intptr_t blockX, intptr_t blockY,
                                 intptr_t blockZ, int32_t smem, CUstream stream,
                                 void **params, void **extra) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
                                      blockY, blockZ, smem, stream, params,
                                      extra));
}

extern "C" CUstream mgpuStreamCreate() {
  ScopedContext scopedContext;
  CUstream stream = nullptr;
  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
  return stream;
@@ -90,6 +108,7 @@ extern "C" void mgpuStreamWaitEvent(CUstream stream, CUevent event) {
}

extern "C" CUevent mgpuEventCreate() {
  ScopedContext scopedContext;
  CUevent event = nullptr;
  CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
  return event;
@@ -108,6 +127,7 @@ extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
}

extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
  ScopedContext scopedContext;
  CUdeviceptr ptr;
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
  return reinterpret_cast<void *>(ptr);
+24 −4
Original line number Diff line number Diff line
@@ -31,17 +31,33 @@
    llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n";        \
  }(expr)

// Static initialization of HIP context for device ordinal 0.
static auto InitializeCtx = [] {
// Static reference to HIP primary context for device ordinal 0.
static hipCtx_t Context = [] {
  HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
  hipDevice_t device;
  HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0));
  hipCtx_t context;
  HIP_REPORT_IF_ERROR(hipCtxCreate(&context, /*flags=*/0, device));
  return 0;
  HIP_REPORT_IF_ERROR(hipDevicePrimaryCtxRetain(&context, device));
  return context;
}();

// Makes the primary `Context` current on the calling thread for the lifetime
// of the instance and restores the previously current context on destruction.
// The current context is per-thread runtime state, so the wrappers that create
// resources instantiate one of these before calling into HIP.
class ScopedContext {
public:
  ScopedContext() {
    // Remember whatever was current so the destructor can put it back.
    HIP_REPORT_IF_ERROR(hipCtxGetCurrent(&previous));
    HIP_REPORT_IF_ERROR(hipCtxSetCurrent(Context));
  }

  ~ScopedContext() { HIP_REPORT_IF_ERROR(hipCtxSetCurrent(previous)); }

  // Not copyable (and therefore not movable): a duplicate would restore
  // `previous` a second time, clobbering whatever context had been made
  // current in between.
  ScopedContext(const ScopedContext &) = delete;
  ScopedContext &operator=(const ScopedContext &) = delete;

private:
  // Context that was current before construction. Starts out null so the
  // destructor never hands an indeterminate handle to the runtime when
  // hipCtxGetCurrent fails (the report macro logs the error and continues).
  hipCtx_t previous = nullptr;
};

extern "C" hipModule_t mgpuModuleLoad(void *data) {
  ScopedContext scopedContext;
  hipModule_t module = nullptr;
  HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
  return module;
@@ -67,12 +83,14 @@ extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
                                 intptr_t blockZ, int32_t smem,
                                 hipStream_t stream, void **params,
                                 void **extra) {
  ScopedContext scopedContext;
  HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
                                            blockX, blockY, blockZ, smem,
                                            stream, params, extra));
}

extern "C" hipStream_t mgpuStreamCreate() {
  ScopedContext scopedContext;
  hipStream_t stream = nullptr;
  HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
  return stream;
@@ -91,6 +109,7 @@ extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
}

extern "C" hipEvent_t mgpuEventCreate() {
  ScopedContext scopedContext;
  hipEvent_t event = nullptr;
  HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
  return event;
@@ -109,6 +128,7 @@ extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
}

extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
  ScopedContext scopedContext;
  void *ptr;
  HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
  return ptr;