Unverified Commit 068b4774 authored by Samuel Ainsworth's avatar Samuel Ainsworth Committed by GitHub
Browse files

Merge pull request #196977 from samuela/samuela/jax

JAX upgrades
parents 9ad41618 38cd4258
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ let
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.3.16";
  version = "0.3.23";
  format = "setuptools";

  disabled = pythonOlder "3.7";
@@ -30,7 +30,7 @@ buildPythonPackage rec {
    owner = "google";
    repo = pname;
    rev = "jax-v${version}";
    hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I=";
    hash = "sha256-ruXOwpBwpi1G8jgH9nhbWbs14JupwWkjh+Wzrj8HVU4=";
  };

  # jaxlib is _not_ included in propagatedBuildInputs because there are
@@ -92,9 +92,8 @@ buildPythonPackage rec {
    "tests/sparse_test.py"
  ];

  pythonImportsCheck = [
    "jax"
  ];
  # As of 0.3.22, `import jax` does not work without jaxlib being installed.
  pythonImportsCheck = [ ];

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";
+23 −44
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.

# For future reference, the easiest way to test the GPU backend is to run
#   NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
#   NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }"
#   export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
#   python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
#   python -c "from jax import random; random.PRNGKey(0)"
@@ -35,46 +35,32 @@ let
  inherit (cudaPackages) cudatoolkit cudnn;
in

# There are no jaxlib wheels targeting cudnn <8.0.5, and although there are
# wheels for cudatoolkit <11.1, we don't support them.
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";

let
  version = "0.3.0";
  version = "0.3.22";

  pythonVersion = python.pythonVersion;

  # Find new releases at https://storage.googleapis.com/jax-releases. When
  # upgrading, you can get these hashes from prefetch.sh.
  # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
  # When upgrading, you can get these hashes from prefetch.sh. See
  # https://github.com/google/jax/issues/12879 as to why this specific URL is
  # the correct index.
  cpuSrcs = {
    "3.9" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
      hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q=";
    "x86_64-linux" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
      hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4=";
    };
    "3.10" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl";
      hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ=";
    "aarch64-darwin" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
      hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0=";
    };
  };

  gpuSrcs = {
    "3.9-805" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
      hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw=";
    };
    "3.9-82" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
      hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4=";
    };
    "3.10-805" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl";
      hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U=";
    };
    "3.10-82" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl";
      hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go=";
    };
  gpuSrc = fetchurl {
    url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
    hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs=";
  };
in
buildPythonPackage rec {
@@ -82,23 +68,16 @@ buildPythonPackage rec {
  inherit version;
  format = "wheel";

  # At the time of writing (2022-03-03), there are releases for <=3.10.
  # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs
  # python3 version, and 3.10.
  disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10");
  # At the time of writing (2022-10-19), there are releases for <=3.10.
  # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
  # python version.
  disabled = !(pythonVersion == "3.10");

  src =
    if !cudaSupport then cpuSrcs."${pythonVersion}" else
    let
      # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and
      # 8.2. Try to use 8.2 whenever possible.
      cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805";
    in
    gpuSrcs."${pythonVersion}-${cudnnVersion}";
  src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc;

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them.
  nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
  nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
  # Dynamic link dependencies
  buildInputs = [ stdenv.cc.cc ];

@@ -142,6 +121,6 @@ buildPythonPackage rec {
    sourceProvenance = with sourceTypes; [ binaryNativeCode ];
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
    platforms = [ "x86_64-linux" ];
    platforms = [ "aarch64-darwin" "x86_64-linux" ];
  };
}
+8 −6
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ let
  inherit (cudaPackages) cudatoolkit cudnn nccl;

  pname = "jaxlib";
  version = "0.3.15";
  version = "0.3.22";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -96,7 +96,7 @@ let
      owner = "google";
      repo = "jax";
      rev = "${pname}-v${version}";
      sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s=";
      hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
    };

    nativeBuildInputs = [
@@ -235,11 +235,11 @@ let
    fetchAttrs = {
      sha256 =
        if cudaSupport then
          "sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk="
          "sha256-Z9GDWGv+1YFyJjudyshZfeRJsKShoA1kIbNR3h3GxPQ="
        else if stdenv.isDarwin then
          "sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34="
          "sha256-i3wiJHD4+pgTvDMhnYiQo9pdxxKItgYnc4/4wGt2NXM="
        else
          "sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A=";
          "sha256-liRxmjwm0OmVMfgoGXx+nGBdW2fzzP/d4zmK6A59HAM=";
    };

    buildAttrs = {
@@ -293,7 +293,9 @@ buildPythonPackage {
  inherit meta pname version;
  format = "wheel";

  src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl";
  src =
    let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
    in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";

  # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
  # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
+0 −6
Original line number Diff line number Diff line
@@ -4780,9 +4780,6 @@ in {

  jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix {
    cudaSupport = pkgs.config.cudaSupport or false;
    # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
    # pin to `cudaPackages_11_6` instead.
    cudaPackages = pkgs.cudaPackages_11_6;
  };

  jaxlib-build = callPackage ../development/python-modules/jaxlib rec {
@@ -4792,9 +4789,6 @@ in {
    };
    # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
    cudaSupport = pkgs.config.cudaSupport or false;
    # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we
    # pin to `cudaPackages_11_6` instead.
    cudaPackages = pkgs.cudaPackages_11_6;
    IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
    protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
  };