Commit a639b846 authored by Kirill Radzikhovskyy's avatar Kirill Radzikhovskyy
Browse files

python3Packages.mmcv: fix build with CUDA support

parent ed0ddfe1
Loading
Loading
Loading
Loading
+20 −30
Original line number Diff line number Diff line
@@ -21,35 +21,12 @@
  tifffile,
  lmdb,
  mmengine,
  symlinkJoin,
}:

let
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  inherit (cudaPackages) backendStdenv cudaVersion;
  inherit (cudaPackages) backendStdenv;

  cuda-common-redist = with cudaPackages; [
    cuda_cccl # <thrust/*>
    libcublas # cublas_v2.h
    libcusolver # cusolverDn.h
    libcusparse # cusparse.h
  ];

  cuda-native-redist = symlinkJoin {
    name = "cuda-native-redist-${cudaVersion}";
    paths =
      with cudaPackages;
      [
        cuda_cudart # cuda_runtime.h
        cuda_nvcc
      ]
      ++ cuda-common-redist;
  };

  cuda-redist = symlinkJoin {
    name = "cuda-redist-${cudaVersion}";
    paths = cuda-common-redist;
  };
in
buildPythonPackage rec {
  pname = "mmcv";
@@ -65,6 +42,8 @@ buildPythonPackage rec {
    hash = "sha256-NNF9sLJWV1q6uBE73LUW4UWwYm4TBMTBJjJkFArBmsc=";
  };

  env.CUDA_HOME = lib.optionalString cudaSupport (lib.getDev cudaPackages.cuda_nvcc);

  preConfigure =
    ''
      export MMCV_WITH_OPS=1
@@ -77,7 +56,7 @@ buildPythonPackage rec {
    '';

  postPatch = ''
    substituteInPlace setup.py --replace "cpu_use = 4" "cpu_use = $NIX_BUILD_CORES"
    substituteInPlace setup.py --replace-fail "cpu_use = 4" "cpu_use = $NIX_BUILD_CORES"
  '';

  preCheck = ''
@@ -102,12 +81,23 @@ buildPythonPackage rec {
  nativeBuildInputs = [
    ninja
    which
  ] ++ lib.optionals cudaSupport [ cuda-native-redist ];
  ];

  buildInputs = [
  buildInputs =
    [
      pybind11
      torch
  ] ++ lib.optionals cudaSupport [ cuda-redist ];
    ]
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        cuda_cudart # cuda_runtime.h
        cuda_cccl # <thrust/*>
        libcublas # cublas_v2.h
        libcusolver # cusolverDn.h
        libcusparse # cusparse.h
      ]
    );

  nativeCheckInputs = [
    pytestCheckHook