Unverified Commit 47f07cae authored by Connor Baker's avatar Connor Baker Committed by GitHub
Browse files

Merge pull request #266081 from ConnorBaker/fix/torch-jetson

python3Packages.torch: patch cpp_extension.py for Jetson support
parents 417c2051 2a425031
Loading
Loading
Loading
Loading
+14 −1
Original line number Diff line number Diff line
@@ -48,7 +48,10 @@

let
  inherit (lib) attrsets lists strings trivial;
  inherit (cudaPackages) cudaFlags cudnn nccl;
  inherit (cudaPackages) cudaFlags cudnn;

  # Some packages are not available on all platforms
  nccl = cudaPackages.nccl or null;

  setBool = v: if v then "1" else "0";

@@ -178,6 +181,13 @@ in buildPythonPackage rec {
        'message(FATAL_ERROR "Found NCCL header version and library version' \
        'message(WARNING "Found NCCL header version and library version'
  ''
  # TODO(@connorbaker): Remove this patch after 2.1.0 lands.
  + lib.optionalString cudaSupport ''
    substituteInPlace torch/utils/cpp_extension.py \
      --replace \
        "'8.6', '8.9'" \
        "'8.6', '8.7', '8.9'"
  ''
  # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
  # This lib overrided aligned_alloc hence the error message. Tltr: his function is linkable but not in header.
  + lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.targetPlatform.darwinSdkVersion "11.0") ''
@@ -253,6 +263,7 @@ in buildPythonPackage rec {
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  USE_NCCL = setBool (nccl != null);
  USE_SYSTEM_NCCL = setBool useSystemNccl;                  # don't build pytorch's third_party NCCL
  USE_STATIC_NCCL = setBool useSystemNccl;

@@ -316,6 +327,8 @@ in buildPythonPackage rec {
      libcusolver.lib
      libcusparse.dev
      libcusparse.lib
    ] ++ lists.optionals (nccl != null) [
      # Some platforms do not support NCCL (i.e., Jetson)
      nccl.dev # Provides nccl.h AND a static copy of NCCL!
    ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
      cuda_nvprof.dev # <cuda_profiler_api.h>