Unverified Commit 6cf1820d authored by Someone Serge's avatar Someone Serge
Browse files

python3Packages.numba: unbreak cuda

Dont use the cuda runfile; unbreak CDLL on NixOS; add a passthru
"tester" script for cuda; do not *propagate* the cuda thrash!
parent 0cfd293a
Loading
Loading
Loading
Loading
+29 −15
Original line number Diff line number Diff line
@@ -12,19 +12,23 @@
, importlib-metadata
, substituteAll
, runCommand
, symlinkJoin
, writers
, numba

, config

# CUDA-only dependencies:
, addOpenGLRunpath ? null
, cudaPackages ? {}
, addDriverRunpath
, autoAddDriverRunpath ? cudaPackages.autoAddDriverRunpathHook or cudaPackages.autoAddOpenGLRunpathHook
, cudaPackages

# CUDA flags:
, cudaSupport ? config.cudaSupport
}:

let
  inherit (cudaPackages) cudatoolkit;
  cudatoolkit = cudaPackages.cuda_nvcc;
in buildPythonPackage rec {
  # Using an untagged version, with numpy 1.25 support, when it's released
  # also drop the versioneer patch in postPatch
@@ -52,12 +56,25 @@ in buildPythonPackage rec {
    # relevant strings ourselves, using `sed` commands, in extraPostFetch.
    hash = "sha256-wd4TujPhV2Jy/HUUXLHAlcbVFm4gfQNWxWFXD+jeZC4=";
  };

  postPatch = ''
    substituteInPlace numba/cuda/cudadrv/driver.py \
      --replace-fail \
        "dldir = [" \
        "dldir = [ '${addDriverRunpath.driverLink}/lib', "
  '';

  env.NIX_CFLAGS_COMPILE = lib.optionalString stdenv.isDarwin "-I${lib.getDev libcxx}/include/c++/v1";

  nativeBuildInputs = [
    numpy
  ] ++ lib.optionals cudaSupport [
    addOpenGLRunpath
    autoAddDriverRunpath
    cudaPackages.cuda_nvcc
  ];

  buildInputs = with cudaPackages; [
    cuda_cudart
  ];

  propagatedBuildInputs = [
@@ -66,26 +83,16 @@ in buildPythonPackage rec {
    setuptools
  ] ++ lib.optionals (pythonOlder "3.9") [
    importlib-metadata
  ] ++ lib.optionals cudaSupport [
    cudatoolkit
    cudatoolkit.lib
  ];

  patches = lib.optionals cudaSupport [
    (substituteAll {
      src = ./cuda_path.patch;
      cuda_toolkit_path = cudatoolkit;
      cuda_toolkit_lib_path = cudatoolkit.lib;
      cuda_toolkit_lib_path = lib.getLib cudatoolkit;
    })
  ];

  postFixup = lib.optionalString cudaSupport ''
    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addOpenGLRunpath "$lib"
      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';

  # run a smoke test in a temporary directory so that
  # a) Python picks up the installed library in $out instead of the build files
  # b) we have somewhere to put $HOME so some caching tests work
@@ -104,6 +111,13 @@ in buildPythonPackage rec {
    "numba"
  ];

  passthru.testers.cuda-detect =
    writers.writePython3Bin "numba-cuda-detect"
      { libraries = [ (numba.override { cudaSupport = true; }) ]; }
      ''
        from numba import cuda
        cuda.detect()
      '';
  passthru.tests = {
    # CONTRIBUTOR NOTE: numba also contains CUDA tests, though these cannot be run in
    # this sandbox environment. Consider running similar commands to those below outside the