Unverified Commit 3c78204c authored by Gaétan Lepage's avatar Gaétan Lepage Committed by GitHub
Browse files

python3Packages.torch*: move env variables into env for structuredAttrs (#491399)

parents 7257c3f5 fd14d209
Loading
Loading
Loading
Loading
+78 −77
Original line number Diff line number Diff line
@@ -277,7 +277,7 @@ let
  # From here on, `stdenv` shall be `stdenv'`.
  stdenv = stdenv';
in
buildPythonPackage.override { inherit stdenv; } rec {
buildPythonPackage.override { inherit stdenv; } (finalAttrs: {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.9.1";
@@ -293,11 +293,13 @@ buildPythonPackage.override { inherit stdenv; } rec {

  src = callPackage ./src.nix {
    inherit
      version
      fetchFromGitHub
      fetchFromGitLab
      runCommand
      ;
    inherit (finalAttrs)
      version
      ;
  };

  patches = [
@@ -420,6 +422,7 @@ buildPythonPackage.override { inherit stdenv; } rec {
  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  env = {
    BUILD_NAMEDTENSOR = setBool true;
    BUILD_DOCS = setBool buildDocs;

@@ -450,55 +453,21 @@ buildPythonPackage.override { inherit stdenv; } rec {
    # https://pytorch.org/docs/stable/distributed.html#torch.distributed.is_available
    USE_DISTRIBUTED = setBool true;

  cmakeFlags = [
    (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
    # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
    (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
  ]
  ++ lib.optionals cudaSupport [
    # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
    # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
    (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
  ];

  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';

    # Override the (weirdly) wrong version set by default. See
    # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
    # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
    PYTORCH_BUILD_VERSION = finalAttrs.version;
    PYTORCH_BUILD_NUMBER = 0;

    # In-tree builds of NCCL are not supported.
    # Use NCCL when cudaSupport is enabled and nccl is available.
    USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;
    USE_SYSTEM_NCCL = finalAttrs.env.USE_NCCL;
    USE_STATIC_NCCL = finalAttrs.env.USE_NCCL;

    # Set the correct Python library path, broken since
    # https://github.com/pytorch/pytorch/commit/3d617333e
    PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";

  env = {
    # disable warnings as errors as they break the build on every compiler
    # bump, among other things.
    # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
@@ -524,6 +493,38 @@ buildPythonPackage.override { inherit stdenv; } rec {
    USE_FBGEMM_GENAI = setBool false;
  };

  cmakeFlags = [
    (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
    # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
    (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
  ]
  ++ lib.optionals cudaSupport [
    # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
    # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
    (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
  ];

  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';

  nativeBuildInputs = [
    cmake
    which
@@ -650,7 +651,7 @@ buildPythonPackage.override { inherit stdenv; } rec {
        # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

        # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
        (optionalString (majorMinor version == "1.3") "tensorboard")
        (optionalString (majorMinor finalAttrs.version == "1.3") "tensorboard")
      ])
      "runHook postCheck"
    ];
@@ -692,7 +693,7 @@ buildPythonPackage.override { inherit stdenv; } rec {
      --replace-fail "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

    substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
      --replace-fail "/build/${src.name}/torch/include" "$dev/include"
      --replace-fail "/build/${finalAttrs.src.name}/torch/include" "$dev/include"
  '';

  postFixup = ''
@@ -751,7 +752,7 @@ buildPythonPackage.override { inherit stdenv; } rec {
  };

  meta = {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${finalAttrs.version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
@@ -767,4 +768,4 @@ buildPythonPackage.override { inherit stdenv; } rec {
      lib.platforms.linux ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
}
})