Commit 65365504 authored by Jon Chesterfield's avatar Jon Chesterfield
Browse files

[libomptarget][cuda] Handle missing _v2 symbols gracefully

[libomptarget][cuda] Handle missing _v2 symbols gracefully

Follow on from D95367. Dlsym the _v2 symbols if present, otherwise use the
unsuffixed version. Builds a hashtable for the check, can revise for zero
heap allocations later if necessary.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D95415
parent 65e2fa50
Loading
Loading
Loading
Loading
+30 −0
Original line number Diff line number Diff line
@@ -15,6 +15,9 @@
#include "Debug.h"
#include "dlwrap.h"

#include <string>
#include <unordered_map>

#include <dlfcn.h>

DLWRAP_INTERNAL(cuInit, 1);
@@ -67,6 +70,21 @@ DLWRAP_FINALIZE();
static bool checkForCUDA() {
  // return true if dlopen succeeded and all functions found

  // Prefer _v2 versions of functions if found in the library
  std::unordered_map<std::string, const char *> TryFirst = {
      {"cuMemAlloc", "cuMemAlloc_v2"},
      {"cuMemFree", "cuMemFree_v2"},
      {"cuMemcpyDtoH", "cuMemcpyDtoH_v2"},
      {"cuMemcpyHtoD", "cuMemcpyHtoD_v2"},
      {"cuStreamDestroy", "cuStreamDestroy_v2"},
      {"cuModuleGetGlobal", "cuModuleGetGlobal_v2"},
      {"cuMemcpyDtoHAsync", "cuMemcpyDtoHAsync_v2"},
      {"cuMemcpyDtoDAsync", "cuMemcpyDtoDAsync_v2"},
      {"cuMemcpyHtoDAsync", "cuMemcpyHtoDAsync_v2"},
      {"cuDevicePrimaryCtxRelease", "cuDevicePrimaryCtxRelease_v2"},
      {"cuDevicePrimaryCtxSetFlags", "cuDevicePrimaryCtxSetFlags_v2"},
  };

  const char *CudaLib = DYNAMIC_CUDA_PATH;
  void *DynlibHandle = dlopen(CudaLib, RTLD_NOW);
  if (!DynlibHandle) {
@@ -77,11 +95,23 @@ static bool checkForCUDA() {
  for (size_t I = 0; I < dlwrap::size(); I++) {
    const char *Sym = dlwrap::symbol(I);

    auto It = TryFirst.find(Sym);
    if (It != TryFirst.end()) {
      const char *First = It->second;
      void *P = dlsym(DynlibHandle, First);
      if (P) {
        DP("Implementing %s with dlsym(%s) -> %p\n", Sym, First, P);
        *dlwrap::pointer(I) = P;
        continue;
      }
    }

    void *P = dlsym(DynlibHandle, Sym);
    if (P == nullptr) {
      DP("Unable to find '%s' in '%s'!\n", Sym, CudaLib);
      return false;
    }
    DP("Implementing %s with dlsym(%s) -> %p\n", Sym, Sym, P);

    *dlwrap::pointer(I) = P;
  }
+0 −12
Original line number Diff line number Diff line
@@ -49,18 +49,6 @@ typedef enum CUctx_flags_enum {
  CU_CTX_SCHED_MASK = 0x07,
} CUctx_flags;

#define cuMemFree cuMemFree_v2
#define cuMemAlloc cuMemAlloc_v2
#define cuMemcpyDtoH cuMemcpyDtoH_v2
#define cuMemcpyHtoD cuMemcpyHtoD_v2
#define cuStreamDestroy cuStreamDestroy_v2
#define cuModuleGetGlobal cuModuleGetGlobal_v2
#define cuMemcpyDtoHAsync cuMemcpyDtoHAsync_v2
#define cuMemcpyDtoDAsync cuMemcpyDtoDAsync_v2
#define cuMemcpyHtoDAsync cuMemcpyHtoDAsync_v2
#define cuDevicePrimaryCtxRelease cuDevicePrimaryCtxRelease_v2
#define cuDevicePrimaryCtxSetFlags cuDevicePrimaryCtxSetFlags_v2

CUresult cuCtxGetDevice(CUdevice *);
CUresult cuDeviceGet(CUdevice *, int);
CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);