Unverified Commit 7604c9e1 authored by Stefan Frijters's avatar Stefan Frijters
Browse files

python3Packages.torch*: move env variables into env for structuredAttrs

parent b2bb8b19
Loading
Loading
Loading
Loading
+70 −71
Original line number Diff line number Diff line
@@ -420,6 +420,7 @@ buildPythonPackage.override { inherit stdenv; } rec {
  # causes possible redefinition of _FORTIFY_SOURCE
  hardeningDisable = [ "fortify3" ];

  env = rec {
    BUILD_NAMEDTENSOR = setBool true;
    BUILD_DOCS = setBool buildDocs;

@@ -450,38 +451,6 @@ buildPythonPackage.override { inherit stdenv; } rec {
    # https://pytorch.org/docs/stable/distributed.html#torch.distributed.is_available
    USE_DISTRIBUTED = setBool true;

  cmakeFlags = [
    (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
    # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
    (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
  ]
  ++ lib.optionals cudaSupport [
    # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
    # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
    (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
  ];

  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';

    # Override the (weirdly) wrong version set by default. See
    # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
    # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
@@ -497,8 +466,6 @@ buildPythonPackage.override { inherit stdenv; } rec {
    # Set the correct Python library path, broken since
    # https://github.com/pytorch/pytorch/commit/3d617333e
    PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";

  env = {
    # disable warnings as errors as they break the build on every compiler
    # bump, among other things.
    # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
@@ -524,6 +491,38 @@ buildPythonPackage.override { inherit stdenv; } rec {
    USE_FBGEMM_GENAI = setBool false;
  };

  cmakeFlags = [
    (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
    # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
    (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
  ]
  ++ lib.optionals cudaSupport [
    # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
    # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
    (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
  ];

  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';

  nativeBuildInputs = [
    cmake
    which