Unverified Commit 6232bc9c authored by Gaetan Lepage's avatar Gaetan Lepage Committed by Nick Cao
Browse files

python3Packages.jaxlib: 0.4.4 -> 0.4.14

parent 06ef57da
Loading
Loading
Loading
Loading
+40 −38
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@

  # Build-time dependencies:
, addOpenGLRunpath
, bazel_5
, bazel_6
, binutils
, buildBazelPackage
, buildPythonPackage
@@ -21,11 +21,13 @@
, setuptools
, symlinkJoin
, wheel
, build
, which

  # Python dependencies:
, absl-py
, flatbuffers
, ml-dtypes
, numpy
, scipy
, six
@@ -35,7 +37,6 @@
, giflib
, grpc
, libjpeg_turbo
, protobuf
, python
, snappy
, zlib
@@ -53,7 +54,7 @@ let
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  pname = "jaxlib";
  version = "0.4.4";
  version = "0.4.14";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -99,7 +100,9 @@ let
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    "com_google_protobuf"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
@@ -120,7 +123,9 @@ let
    "org_sqlite"
    "pasta"
    "png"
    "pybind11"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
@@ -138,14 +143,15 @@ let
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    bazel = bazel_5;
    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_6;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
      hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
    };

    nativeBuildInputs = [
@@ -154,6 +160,7 @@ let
      git
      setuptools
      wheel
      build
      which
    ] ++ lib.optionals stdenv.isDarwin [
      cctools
@@ -169,7 +176,7 @@ let
      numpy
      openssl
      pkgs.flatbuffers
      protobuf
      pkgs.protobuf
      pybind11
      scipy
      six
@@ -188,7 +195,8 @@ let
      rm -f .bazelversion
    '';

    bazelTargets = [ "//build:build_wheel" ];
    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];

    removeRulesCC = false;

@@ -207,7 +215,11 @@ let
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
      build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
      build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
    '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
      build --config=avx_posix
    '' + lib.optionalString mklSupport ''
      build --config=mkl_open_source_only
    '' + lib.optionalString cudaSupport ''
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
@@ -234,7 +246,7 @@ let
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # we have to force @mkl_dnn_v1 since it's not needed on darwin
      bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
      bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
      bazelFlags = bazelFlags ++ [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
@@ -249,9 +261,9 @@ let

      sha256 =
        if cudaSupport then
          "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
          "sha256-8QaXoZq6oITRsYn4RdLUXcKQv3PJ4Q3ItX9PkBwxGBI="
        else
          "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
          "sha256-M/h5EZmyiV4QvzgKRjdz7V1LHENUJlc/ig1QAItnWVQ=";
    };

    buildAttrs = {
@@ -261,25 +273,13 @@ let
        "nsync" # fails to build on darwin
      ]);

      bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        "--config=cuda"
      ] ++ lib.optionals mklSupport [
        "--config=mkl_open_source_only"
      ];
      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
      # 1) Fix pybind11 include paths.
      # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
      # 3) Patch python path in the compiler driver.
      preBuild = ''
        for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
          sed -i 's@include/pybind11@pybind11@g' $src
        done
      '' + lib.optionalString cudaSupport ''
      # 2) Patch python path in the compiler driver.
      preBuild = lib.optionalString cudaSupport ''
        export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
        patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '' + lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by bintools hook
        # https://github.com/NixOS/nixpkgs/pull/41914
@@ -289,16 +289,12 @@ let
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '' + (if stdenv.cc.isGNU then ''
        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
      '' else if stdenv.cc.isClang then ''
        sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
      '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

      installPhase = ''
        ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
      '';
    };

    inherit meta;
@@ -345,13 +341,19 @@ buildPythonPackage {
    grpc
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [ "jaxlib" ];
  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
+0 −1
Original line number Diff line number Diff line
@@ -5310,7 +5310,6 @@ self: super: with self; {
    # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
    inherit (pkgs.config) cudaSupport;
    IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
    protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
  };

  jaxlib = self.jaxlib-build;