Unverified Commit 8faef6de authored by Samuel Ainsworth's avatar Samuel Ainsworth Committed by GitHub
Browse files

Merge pull request #219778 from samuela/samuela/jax

Update JAX and fix aarch64-darwin build
parents 3df03f23 3fa9f1f0
Loading
Loading
Loading
Loading
+17 −6
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@
, etils
, fetchFromGitHub
, jaxlib
, jaxlib-bin
, lapack
, matplotlib
, numpy
@@ -13,15 +14,20 @@
, pytest-xdist
, pythonOlder
, scipy
, stdenv
, typing-extensions
}:

let
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
  # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
  # fine. jaxlib is only used in the checkPhase, so switching backends does not
  # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
  jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.1";
  version = "0.4.5";
  format = "setuptools";

  disabled = pythonOlder "3.7";
@@ -29,14 +35,14 @@ buildPythonPackage rec {
  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    rev = "refs/tags/jaxlib-v${version}";
    hash = "sha256-ajLI0iD0YZRK3/uKSbhlIZGc98MdW174vA34vhoy7Iw=";
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
  };

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
  # CPU wheel is packaged.
  # JAX project ships separate wheels for CPU, GPU, and TPU.
  propagatedBuildInputs = [
    absl-py
    etils
@@ -47,7 +53,7 @@ buildPythonPackage rec {
  ] ++ etils.optional-dependencies.epath;

  nativeCheckInputs = [
    jaxlib
    jaxlib'
    matplotlib
    pytestCheckHook
    pytest-xdist
@@ -83,6 +89,11 @@ buildPythonPackage rec {
    "test_custom_linear_solve_cholesky"
    "test_custom_root_with_aux"
    "testEigvalsGrad_shape"
  ] ++ lib.optionals (stdenv.isAarch64 && stdenv.isDarwin) [
    # See https://github.com/google/jax/issues/14793.
    "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
    "testQdwhWithRandomMatrix3"
    "testScanGrad_jit_scan"
  ];

  # See https://github.com/google/jax/issues/11722. This is a temporary fix in
+12 −6
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";

let
  version = "0.3.22";
  version = "0.4.4";

  pythonVersion = python.pythonVersion;

@@ -50,21 +50,21 @@ let
  cpuSrcs = {
    "x86_64-linux" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
      hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4=";
      hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
    };
    "aarch64-darwin" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
      hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0=";
      hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
    };
    "x86_64-darwin" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
      hash = "sha256-bOoQI+T+YsTUNA+cDu6wwYTcq9fyyzCpK9qrdCrNVoA=";
      hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
    };
  };

  gpuSrc = fetchurl {
    url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
    hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs=";
    hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
  };
in
buildPythonPackage rec {
@@ -77,7 +77,13 @@ buildPythonPackage rec {
  # python version.
  disabled = !(pythonVersion == "3.10");

  src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc;
  # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
  src =
    if !cudaSupport then
      (
        cpuSrcs."${stdenv.hostPlatform.system}"
          or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
      ) else gpuSrc;

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them.
+6 −5
Original line number Diff line number Diff line
@@ -52,7 +52,7 @@ let
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  pname = "jaxlib";
  version = "0.3.22";
  version = "0.4.4";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -137,8 +137,9 @@ let
    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      rev = "${pname}-v${version}";
      hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
    };

    nativeBuildInputs = [
@@ -242,9 +243,9 @@ let

      sha256 =
        if cudaSupport then
          "sha256-4yu4y4SwSQoeaOz9yojhvCRGSC6jp61ycVDIKyIK/l8="
          "sha256-cgsiloW77p4+TKRrYequZ/UwKwfO2jsHKtZ+aA30H7E="
        else
          "sha256-CyRfPfJc600M7VzR3/SQX/EAyeaXRJwDQWot5h2XnFU=";
          "sha256-D7WYG3YUaWq+4APYx8WpA191VVtoHG0fth3uEHXOeos=";
    };

    buildAttrs = {