Unverified Commit 929a328d authored by Samuel Ainsworth's avatar Samuel Ainsworth Committed by GitHub
Browse files

Merge pull request #225661 from SomeoneSerge/jax-libstdcxx

python3Packages.jax: fix libstdc++ mismatch when built with CUDA
parents 0fac1123 8fd02ce2
Loading
Loading
Loading
Loading
+8 −2
Original line number Diff line number Diff line
@@ -18,8 +18,14 @@ final: prev: let
  # E.g. for cudaPackages_11_8 we use gcc11 with gcc12's libstdc++
  # Cf. https://github.com/NixOS/nixpkgs/pull/218265 for context
  backendStdenv = final.callPackage ./stdenv.nix {
    nixpkgsStdenv = prev.pkgs.stdenv;
    nvccCompatibleStdenv = prev.pkgs.buildPackages."${finalVersion.gcc}Stdenv";
    # We use buildPackages (= pkgsBuildHost) because we look for a gcc that
    # runs on our build platform, and that produces executables for the host
    # platform (= platform on which we deploy and run the downstream packages).
    # The target platform of buildPackages.gcc is our host platform, so its
    # .lib output should be the libstdc++ we want to be writing in the runpaths
    # Cf. https://github.com/NixOS/nixpkgs/pull/225661#discussion_r1164564576
    nixpkgsCompatibleLibstdcxx = final.pkgs.buildPackages.gcc.cc.lib;
    nvccCompatibleCC = final.pkgs.buildPackages."${finalVersion.gcc}".cc;
  };

  ### Add classic cudatoolkit package
+28 −12
Original line number Diff line number Diff line
{ nixpkgsStdenv
, nvccCompatibleStdenv
{ lib
, nixpkgsCompatibleLibstdcxx
, nvccCompatibleCC
, overrideCC
, stdenv
, wrapCCWith
}:

overrideCC nixpkgsStdenv (wrapCCWith {
  cc = nvccCompatibleStdenv.cc.cc;
let
  cc = wrapCCWith
    {
      cc = nvccCompatibleCC;

      # This option is for clang's libcxx, but we (ab)use it for gcc's libstdc++.
      # Note that libstdc++ maintains forward-compatibility: if we load a newer
@@ -13,5 +17,17 @@ overrideCC nixpkgsStdenv (wrapCCWith {
      # older libstdc++. This, in practice, means that we should use libstdc++ from
      # the same stdenv that the rest of nixpkgs uses.
      # We currently do not try to support anything other than gcc and linux.
  libcxx = nixpkgsStdenv.cc.cc.lib;
})
      libcxx = nixpkgsCompatibleLibstdcxx;
    };
  cudaStdenv = overrideCC stdenv cc;
  passthruExtra = {
    inherit nixpkgsCompatibleLibstdcxx;
    # cc already exposed
  };
  assertCondition = true;
in
lib.extendDerivation
  assertCondition
  passthruExtra
  cudaStdenv
+3 −2
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@
}:

let
  inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl;
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  pname = "jaxlib";
  version = "0.3.22";
@@ -81,7 +81,7 @@ let
  cudatoolkit_cc_joined = symlinkJoin {
    name = "${cudatoolkit.cc.name}-merged";
    paths = [
      cudatoolkit.cc
      backendStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };
@@ -271,6 +271,7 @@ let
          sed -i 's@include/pybind11@pybind11@g' $src
        done
      '' + lib.optionalString cudaSupport ''
        export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '' + lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by bintools hook