Unverified Commit 442b7942 authored by Samuel Ainsworth's avatar Samuel Ainsworth Committed by GitHub
Browse files

Merge pull request #220609 from ConnorBaker/fix/torch-cuda-optional

torch: passthru cudaCapabilities only when cudaSupport is true
parents ac718d02 504d7531
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -364,10 +364,14 @@ in buildPythonPackage rec {
  requiredSystemFeatures = [ "big-parallel" ];

  passthru = {
    inherit cudaSupport cudaPackages gpuTargetString;
    cudaCapabilities = supportedCudaCapabilities;
    inherit cudaSupport cudaPackages;
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
  } // lib.optionalAttrs cudaSupport {
    # NOTE: supportedCudaCapabilities isn't computed unless cudaSupport is true, so we can't use
    #   it in the passthru set above because a downstream package might try to access it even
    #   when cudaSupport is false. Better to have it missing than null or an empty list by default.
    cudaCapabilities = supportedCudaCapabilities;
  };

  meta = with lib; {
+2 −3
Original line number Diff line number Diff line
{ buildPythonPackage
, cudaSupport ? torch.cudaSupport or false # by default uses the value from torch
, fetchFromGitHub
, lib
, libjpeg_turbo
@@ -15,7 +14,7 @@
}:

let
  inherit (torch) cudaPackages gpuTargetString;
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
  inherit (cudaPackages) cudatoolkit cudaFlags cudaVersion;

  # NOTE: torchvision doesn't use cudnn; torch does!
@@ -68,7 +67,7 @@ buildPythonPackage {
  + lib.optionalString cudaSupport ''
    export CC=${cudatoolkit.cc}/bin/cc
    export CXX=${cudatoolkit.cc}/bin/c++
    export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
    export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
    export FORCE_CUDA=1
  '';