Unverified commit 2d79f0cc, authored by Samuel Ainsworth, committed by GitHub
Browse files

Merge pull request #221390 from NixOS/uri/jax

python3Packages.jaxlib-build: share fetch derivation between different build derivations
parents 73869ed2 c734173b
Loading
Loading
Loading
Loading
+74 −59
Original line number Diff line number Diff line
@@ -86,7 +86,50 @@ let
    ];
  };

  bazel-build = buildBazelPackage {
  # Base list of names used to build the TF_SYSTEM_LIBS attribute of bazel-build.
  # It is joined with "," as-is in fetchAttrs; buildAttrs joins the same list after
  # appending "nsync" on non-darwin systems.
  #
  # Copy-pasted from the TF derivation.
  # Most of these are not really used in jaxlib compilation, but it's simpler to keep the
  # list 'as is' so that it stays compatible with the TF derivation.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    "boringssl"
    # Disabled: not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    "com_google_protobuf"
    # Disabled: fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];

  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    bazel = bazel_5;
@@ -169,61 +212,10 @@ let
      CFG
    '';

    # Comma-separated string of library names, built by joining the list below with ",".
    # "nsync" is appended only when not building on darwin.
    #
    # Copy-pasted from the TF derivation.
    # Most of these are not really used in jaxlib compilation, but it's simpler to keep the
    # list 'as is' so that it stays compatible with the TF derivation.
    TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
      "absl_py"
      "astor_archive"
      "astunparse_archive"
      "boringssl"
      # Disabled: not packaged in nixpkgs
      # "com_github_googleapis_googleapis"
      # "com_github_googlecloudplatform_google_cloud_cpp"
      "com_github_grpc_grpc"
      "com_google_protobuf"
      # Disabled: fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
      # "com_googlesource_code_re2"
      "curl"
      "cython"
      "dill_archive"
      "double_conversion"
      "flatbuffers"
      "functools32_archive"
      "gast_archive"
      "gif"
      "hwloc"
      "icu"
      "jsoncpp_git"
      "libjpeg_turbo"
      "lmdb"
      "nasm"
      "opt_einsum_archive"
      "org_sqlite"
      "pasta"
      "png"
      "pybind11"
      "six_archive"
      "snappy"
      "tblib_archive"
      "termcolor_archive"
      "typing_extensions_archive"
      "wrapt"
      "zlib"
    ] ++ lib.optionals (!stdenv.isDarwin) [
      "nsync" # fails to build on darwin
    ]);

    # Make sure Bazel knows about our configuration flags during fetching so that the
    # relevant dependencies can be downloaded.
    bazelFlags = [
      "-c opt"
    ] ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
      "--config=avx_posix"
    ] ++ lib.optionals cudaSupport [
      "--config=cuda"
    ] ++ lib.optionals mklSupport [
      "--config=mkl_open_source_only"
    ] ++ lib.optionals stdenv.cc.isClang [
      # bazel depends on the compiler frontend automatically selecting these flags based on file
      # extension but our clang doesn't.
@@ -231,21 +223,44 @@ let
      "--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
    ];

    # We intentionally overfetch so we can share the fetch derivation across all the different configurations
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # We have to add @mkl_dnn_v1 to the fetch targets explicitly: it's not needed on darwin,
      # so it must be forced for the fetch to be shared across configurations.
      bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
      bazelFlags = bazelFlags ++ [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        # Ideally we'd add this unconditionally too, but it doesn't work on darwin.
        # We make this conditional on `cudaSupport` rather than on the system so that the
        # hashes for both the cuda and the non-cuda deps can be computed on linux, since a
        # lot of contributors don't have access to darwin machines.
        "--config=cuda"
      ] ++ [
        "--config=mkl_open_source_only"
      ];

      sha256 =
        if cudaSupport then
          "sha256-n8wo+hD9ZYO1SsJKgyJzUmjRlsz45WT6tt5ZLleGvGY="
        else {
          x86_64-linux = "sha256-A0A18kxgGNGHNQ67ZPUzh3Yq2LEcRV7CqR9EfP80NQk=";
          aarch64-linux = "sha256-mU2jzuDu89jVmaG/M5bA3jSd7n7lDi+h8sdhs1z8p1A=";
          x86_64-darwin = "sha256-9nNTpetvjyipD/l8vKlregl1j/OnZKAcOCoZQeRBvts=";
          aarch64-darwin = "sha256-FqYwI1YC5eqSv+DYj09DC5IaBfFDUCO97y+TFhGiWAA=";
        }.${stdenv.system} or (throw "unsupported system ${stdenv.system}");
          "sha256-4yu4y4SwSQoeaOz9yojhvCRGSC6jp61ycVDIKyIK/l8="
        else
          "sha256-CyRfPfJc600M7VzR3/SQX/EAyeaXRJwDQWot5h2XnFU=";
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!stdenv.isDarwin) [
        "nsync" # fails to build on darwin
      ]);

      bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        "--config=cuda"
      ] ++ lib.optionals mklSupport [
        "--config=mkl_open_source_only"
      ];
      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
      # 1) Fix pybind11 include paths.
      # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on