Unverified Commit b98f6d90 authored by Nick Cao's avatar Nick Cao Committed by GitHub
Browse files

Merge pull request #246712 from NickCao/jax-rework

python3Packages.{jax,jaxlib}: update to 0.4.14
parents 0ece6cc4 f4d63170
Loading
Loading
Loading
Loading
+29 −10
Original line number Diff line number Diff line
@@ -10,9 +10,12 @@ args@{
, bazelFlags ? []
, bazelBuildFlags ? []
, bazelTestFlags ? []
, bazelRunFlags ? []
, runTargetFlags ? []
, bazelFetchFlags ? []
, bazelTargets
, bazelTargets ? []
, bazelTestTargets ? []
, bazelRunTarget ? null
, buildAttrs
, fetchAttrs

@@ -46,17 +49,23 @@ args@{

let
  fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // {
    name = name;
    bazelFlags = bazelFlags;
    bazelBuildFlags = bazelBuildFlags;
    bazelTestFlags = bazelTestFlags;
    bazelFetchFlags = bazelFetchFlags;
    bazelTestTargets = bazelTestTargets;
    dontAddBazelOpts = dontAddBazelOpts;
    inherit
      name
      bazelFlags
      bazelBuildFlags
      bazelTestFlags
      bazelRunFlags
      runTargetFlags
      bazelFetchFlags
      bazelTargets
      bazelTestTargets
      bazelRunTarget
      dontAddBazelOpts
      ;
  };
  fBuildAttrs = fArgs // buildAttrs;
  fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ];
  bazelCmd = { cmd, additionalFlags, targets }:
  bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }:
    lib.optionalString (targets != [ ]) ''
      # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables]
      BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \
@@ -73,7 +82,8 @@ let
        "''${host_linkopts[@]}" \
        $bazelFlags \
        ${lib.strings.concatStringsSep " " additionalFlags} \
        ${lib.strings.concatStringsSep " " targets}
        ${lib.strings.concatStringsSep " " targets} \
        ${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags}
    '';
  # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so:
  # chmod: cannot operate on dangling symlink '$symlink'
@@ -262,6 +272,15 @@ stdenv.mkDerivation (fBuildAttrs // {
        targets = fBuildAttrs.bazelTargets;
      }
    }
    ${
      bazelCmd {
        cmd = "run";
        additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ];
        # Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list.
        targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ];
        targetRunFlags = fBuildAttrs.runTargetFlags;
      }
    }
    runHook postBuild
  '';
})
+15 −25
Original line number Diff line number Diff line
{ lib
, absl-py
, blas
, buildPythonPackage
, etils
, setuptools
, importlib-metadata
, fetchFromGitHub
, jaxlib
, jaxlib-bin
, lapack
, matplotlib
, ml-dtypes
, numpy
, opt-einsum
, pytestCheckHook
@@ -15,7 +16,6 @@
, pythonOlder
, scipy
, stdenv
, typing-extensions
}:

let
@@ -27,30 +27,32 @@ let
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.5";
  format = "setuptools";
  version = "0.4.14";
  format = "pyproject";

  disabled = pythonOlder "3.7";
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
    hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
  };

  nativeBuildInputs = [
    setuptools
  ];

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU.
  propagatedBuildInputs = [
    absl-py
    etils
    ml-dtypes
    numpy
    opt-einsum
    scipy
    typing-extensions
  ] ++ etils.optional-dependencies.epath;
  ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;

  nativeCheckInputs = [
    jaxlib'
@@ -96,24 +98,12 @@ buildPythonPackage rec {
    "testScanGrad_jit_scan"
  ];

  # See https://github.com/google/jax/issues/11722. This is a temporary fix in
  # order to unblock etils, and upgrading jax/jaxlib to the latest version. See
  # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
  disabledTestPaths = [
    "tests/api_test.py"
    "tests/core_test.py"
    "tests/lax_numpy_indexing_test.py"
    "tests/lax_numpy_test.py"
    "tests/nn_test.py"
    "tests/random_test.py"
    "tests/sparse_test.py"
  ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
  disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
    # RuntimeWarning: invalid value encountered in cast
    "tests/lax_test.py"
  ];

  # As of 0.3.22, `import jax` does not work without jaxlib being installed.
  pythonImportsCheck = [ ];
  pythonImportsCheck = [ "jax" ];

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";
+51 −33
Original line number Diff line number Diff line
@@ -18,11 +18,12 @@
, autoPatchelfHook
, buildPythonPackage
, config
, cudnn ? cudaPackages.cudnn
, fetchPypi
, fetchurl
, flatbuffers
, isPy39
, jaxlib-build
, lib
, ml-dtypes
, python
, scipy
, stdenv
@@ -35,46 +36,57 @@ let
  inherit (cudaPackages) cudatoolkit cudnn;
in

assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;

let
  version = "0.4.4";
  version = "0.4.14";

  pythonVersion = python.pythonVersion;
  inherit (python) pythonVersion;

  # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
  # When upgrading, you can get these hashes from prefetch.sh. See
  # https://github.com/google/jax/issues/12879 as to why this specific URL is
  # the correct index.
  cpuSrcs = {
    "x86_64-linux" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
      hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
  # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
  # official instructions recommend installing CPU-only versions via PyPI.
  cpuSrcs =
    let
      getSrcFromPypi = { platform, hash }: fetchPypi {
        inherit version platform hash;
        pname = "jaxlib";
        format = "wheel";
        # See the `disabled` attr comment below.
        dist = "cp310";
        python = "cp310";
        abi = "cp310";
      };
    "aarch64-darwin" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
      hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
    in
    {
      "x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        hash = "sha256-nyylSZfqHeftlvVgJZFCN1ldjluZVJIYu4ZSsVxvXf8=";
      };
    "x86_64-darwin" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
      hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
      "aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        hash = "sha256-La3wYbGCjWTl7krBD6BaBRqyBD8R530Lckbz0AWv0FM=";
      };
      "x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        hash = "sha256-hDg5+qisgtgOrdvbjxsUgI73cW6Aah8NLjhPe4kMAsM=";
      };
    };


  # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
  # When upgrading, you can get these hashes from prefetch.sh. See
  # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
  gpuSrc = fetchurl {
    url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
    hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
    url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
    hash = "sha256-CcQ5kjp4XfUX4/RwFY3T5G3kVKAeyoCTXu1Lo4O16Qo=";
  };

in
buildPythonPackage rec {
buildPythonPackage {
  pname = "jaxlib";
  inherit version;
  format = "wheel";

  # At the time of writing (2022-10-19), there are releases for <=3.10.
  # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
  # python version.
  disabled = !(pythonVersion == "3.10");

  # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
@@ -87,9 +99,10 @@ buildPythonPackage rec {

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them.
  nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
  nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
    ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
  # Dynamic link dependencies
  buildInputs = [ stdenv.cc.cc ];
  buildInputs = [ stdenv.cc.cc.lib ];

  # jaxlib contains shared libraries that open other shared libraries via dlopen
  # and these implicit dependencies are not recognized by ldd or
@@ -113,7 +126,12 @@ buildPythonPackage rec {
    done
  '';

  propagatedBuildInputs = [ absl-py flatbuffers scipy ];
  propagatedBuildInputs = [
    absl-py
    flatbuffers
    ml-dtypes
    scipy
  ];

  # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
@@ -123,7 +141,7 @@ buildPythonPackage rec {
    ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
  '';

  pythonImportsCheck = [ "jaxlib" ];
  inherit (jaxlib-build) pythonImportsCheck;

  meta = with lib; {
    description = "XLA library for JAX";
+44 −41
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@

  # Build-time dependencies:
, addOpenGLRunpath
, bazel_5
, bazel_6
, binutils
, buildBazelPackage
, buildPythonPackage
@@ -21,11 +21,13 @@
, setuptools
, symlinkJoin
, wheel
, build
, which

  # Python dependencies:
, absl-py
, flatbuffers
, ml-dtypes
, numpy
, scipy
, six
@@ -35,7 +37,6 @@
, giflib
, grpc
, libjpeg_turbo
, protobuf
, python
, snappy
, zlib
@@ -53,7 +54,7 @@ let
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  pname = "jaxlib";
  version = "0.4.4";
  version = "0.4.14";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -99,7 +100,9 @@ let
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    "com_github_grpc_grpc"
    "com_google_protobuf"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
@@ -120,7 +123,9 @@ let
    "org_sqlite"
    "pasta"
    "png"
    "pybind11"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
@@ -138,14 +143,15 @@ let
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    bazel = bazel_5;
    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_6;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
      hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
    };

    nativeBuildInputs = [
@@ -154,6 +160,7 @@ let
      git
      setuptools
      wheel
      build
      which
    ] ++ lib.optionals stdenv.isDarwin [
      cctools
@@ -169,7 +176,7 @@ let
      numpy
      openssl
      pkgs.flatbuffers
      protobuf
      pkgs.protobuf
      pybind11
      scipy
      six
@@ -188,7 +195,8 @@ let
      rm -f .bazelversion
    '';

    bazelTargets = [ "//build:build_wheel" ];
    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];

    removeRulesCC = false;

@@ -207,7 +215,11 @@ let
      build --action_env=PYENV_ROOT
      build --python_path="${python}/bin/python"
      build --distinct_host_configuration=false
      build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
      build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
    '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
      build --config=avx_posix
    '' + lib.optionalString mklSupport ''
      build --config=mkl_open_source_only
    '' + lib.optionalString cudaSupport ''
      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
@@ -234,7 +246,7 @@ let
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # we have to force @mkl_dnn_v1 since it's not needed on darwin
      bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
      bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
      bazelFlags = bazelFlags ++ [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
@@ -247,11 +259,12 @@ let
        "--config=mkl_open_source_only"
      ];

      sha256 =
        if cudaSupport then
          "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
        else
          "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
      sha256 = (if cudaSupport then {
        x86_64-linux = "sha256-8QaXoZq6oITRsYn4RdLUXcKQv3PJ4Q3ItX9PkBwxGBI=";
      } else {
        x86_64-linux = "sha256-M/h5EZmyiV4QvzgKRjdz7V1LHENUJlc/ig1QAItnWVQ=";
        aarch64-linux = "sha256-edkYcdlvOLNGRSanch1fGCZwq8SFn3TzcUNt1LhzG/E=";
      }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
    };

    buildAttrs = {
@@ -261,25 +274,13 @@ let
        "nsync" # fails to build on darwin
      ]);

      bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
        "--config=avx_posix"
      ] ++ lib.optionals cudaSupport [
        "--config=cuda"
      ] ++ lib.optionals mklSupport [
        "--config=mkl_open_source_only"
      ];
      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
      # 1) Fix pybind11 include paths.
      # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
      #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
      # 3) Patch python path in the compiler driver.
      preBuild = ''
        for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
          sed -i 's@include/pybind11@pybind11@g' $src
        done
      '' + lib.optionalString cudaSupport ''
      # 2) Patch python path in the compiler driver.
      preBuild = lib.optionalString cudaSupport ''
        export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
        patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
      '' + lib.optionalString stdenv.isDarwin ''
        # Framework search paths aren't added by bintools hook
        # https://github.com/NixOS/nixpkgs/pull/41914
@@ -289,16 +290,12 @@ let
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '' + (if stdenv.cc.isGNU then ''
        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
      '' else if stdenv.cc.isClang then ''
        sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
        sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
      '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

      installPhase = ''
        ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
      '';
    };

    inherit meta;
@@ -345,13 +342,19 @@ buildPythonPackage {
    grpc
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
    snappy
  ];

  pythonImportsCheck = [ "jaxlib" ];
  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without it there are complaints about libcudart.so.11.0 not being found
  # because RPATH path entries added above are stripped.
+15 −7
Original line number Diff line number Diff line
version="$1"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)"
nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"
#!/usr/bin/env bash

prefetch () {
    expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url"
    url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r)
    echo "$url"
    sha256=$(nix-prefetch-url "$url")
    nix hash to-sri --type sha256 "$sha256"
    echo
}

prefetch "x86_64-linux" "false"
prefetch "aarch64-darwin" "false"
prefetch "x86_64-darwin" "false"
prefetch "x86_64-linux" "true"
Loading