Unverified Commit f0124cff authored by Connor Baker's avatar Connor Baker Committed by GitHub
Browse files

Merge pull request #272082 from ConnorBaker/fix/torch-optional-cuda-deps

python3Packages.torch: enable cuDNN & NCCL only if available
parents 92df577d 5bf016e1
Loading
Loading
Loading
Loading
+9 −10
Original line number Diff line number Diff line
@@ -56,10 +56,7 @@

let
  inherit (lib) attrsets lists strings trivial;
  inherit (cudaPackages) cudaFlags cudnn;

  # Some packages are not available on all platforms
  nccl = cudaPackages.nccl or null;
  inherit (cudaPackages) cudaFlags cudnn nccl;

  setBool = v: if v then "1" else "0";

@@ -212,10 +209,11 @@ in buildPythonPackage rec {
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure = lib.optionalString cudaSupport ''
    export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
    export CUDNN_LIB_DIR=${cudnn.lib}/lib
    export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
    export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
  '' + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
    export CUDNN_LIB_DIR=${cudnn.lib}/lib
  '' + lib.optionalString rocmSupport ''
    export ROCM_PATH=${rocmtoolkit_joined}
    export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
@@ -273,7 +271,7 @@ in buildPythonPackage rec {
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  USE_NCCL = setBool (nccl != null);
  USE_NCCL = setBool (cudaPackages ? nccl);
  USE_SYSTEM_NCCL = setBool useSystemNccl;                  # don't build pytorch's third_party NCCL
  USE_STATIC_NCCL = setBool useSystemNccl;

@@ -348,8 +346,6 @@ in buildPythonPackage rec {
      cuda_nvrtc.lib
      cuda_nvtx.dev
      cuda_nvtx.lib # -llibNVToolsExt
      cudnn.dev
      cudnn.lib
      libcublas.dev
      libcublas.lib
      libcufft.dev
@@ -360,7 +356,10 @@ in buildPythonPackage rec {
      libcusolver.lib
      libcusparse.dev
      libcusparse.lib
    ] ++ lists.optionals (nccl != null) [
    ] ++ lists.optionals (cudaPackages ? cudnn) [
      cudnn.dev
      cudnn.lib
    ] ++ lists.optionals (useSystemNccl && cudaPackages ? nccl) [
      # Some platforms do not support NCCL (i.e., Jetson)
      nccl.dev # Provides nccl.h AND a static copy of NCCL!
    ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [