Unverified Commit 3a993d32 authored by Samuel Ainsworth's avatar Samuel Ainsworth Committed by GitHub
Browse files

Merge pull request #291705 from GaetanLepage/jax

python311Packages.{jax,jaxlib,jaxlib-bin}: 0.4.24 -> 0.4.28
parents 8fe1aa68 32d1bb1e
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@

buildPythonPackage rec {
  pname = "blackjax";
  version = "1.2.0";
  version = "1.2.1";
  pyproject = true;

  disabled = pythonOlder "3.9";
@@ -25,7 +25,7 @@ buildPythonPackage rec {
    owner = "blackjax-devs";
    repo = "blackjax";
    rev = "refs/tags/${version}";
    hash = "sha256-vXyxK3xALKG61YGK7fmoqQNGfOiagHFrvnU02WKZThw=";
    hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak=";
  };

  build-system = [
@@ -56,6 +56,10 @@ buildPythonPackage rec {
  disabledTests = [
    # too slow
    "test_adaptive_tempered_smc"
  ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
    # Numerical test (AssertionError)
    # https://github.com/blackjax-devs/blackjax/issues/668
    "test_chees_adaptation"
  ];

  pythonImportsCheck = [
+15 −2
Original line number Diff line number Diff line
@@ -48,8 +48,21 @@ buildPythonPackage rec {
  pythonImportsCheck = [ "equinox" ];

  disabledTests = [
    # Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
    "test_tracetime"
    # For simplicity, JAX has removed its internal frames from the traceback of the following exception.
    # https://github.com/patrick-kidger/equinox/issues/716
    "test_abstract"
    "test_complicated"
    "test_grad"
    "test_jvp"
    "test_mlp"
    "test_num_traces"
    "test_pytree_in"
    "test_simple"
    "test_vmap"

    # AssertionError: assert 'foo:\n   pri...pe=float32)\n' == 'foo:\n   pri...pe=float32)\n'
    # Also reported in patrick-kidger/equinox#716
    "test_backward_nan"
  ];

  meta = with lib; {
+4 −4
Original line number Diff line number Diff line
@@ -25,7 +25,7 @@

buildPythonPackage rec {
  pname = "flax";
  version = "0.8.2";
  version = "0.8.3";
  pyproject = true;

  disabled = pythonOlder "3.9";
@@ -34,16 +34,16 @@ buildPythonPackage rec {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
    hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU=";
  };

  nativeBuildInputs = [
  build-system = [
    jaxlib
    pythonRelaxDepsHook
    setuptools-scm
  ];

  propagatedBuildInputs = [
  dependencies = [
    jax
    msgpack
    numpy
+10 −2
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ let
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.25";
  version = "0.4.28";
  pyproject = true;

  disabled = pythonOlder "3.9";
@@ -39,7 +39,7 @@ buildPythonPackage rec {
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/jax-v${version}";
    hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok=";
    hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
  };

  nativeBuildInputs = [
@@ -81,6 +81,14 @@ buildPythonPackage rec {
    "tests/"
  ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
  # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
  # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
  # NOTE: this doesn't seem to be an issue on linux
  preCheck = lib.optionalString stdenv.isDarwin ''
    export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
  '';

  disabledTests = [
    # Exceeds tolerance when the machine is busy
    "test_custom_linear_solve_aux"
+22 −38
Original line number Diff line number Diff line
@@ -20,17 +20,17 @@
, stdenv
  # Options:
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaPackages
}:

let
  inherit (cudaPackagesGoogle) cudaVersion;
  inherit (cudaPackages) cudaVersion;

  version = "0.4.24";
  version = "0.4.28";

  inherit (python) pythonVersion;

  cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [
  cudaLibPath = lib.makeLibraryPath (with cudaPackages; [
    cuda_cudart.lib # libcudart.so
    cuda_cupti.lib # libcupti.so
    cudnn.lib # libcudnn.so
@@ -56,65 +56,65 @@ let
      "3.9-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp39";
        hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE=";
        hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw=";
      };
      "3.9-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp39";
        hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU=";
        hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw=";
      };
      "3.9-x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        dist = "cp39";
        hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik=";
        hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c=";
      };

      "3.10-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp310";
        hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY=";
        hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps=";
      };
      "3.10-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp310";
        hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw=";
        hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk=";
      };
      "3.10-x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        dist = "cp310";
        hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ=";
        hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY=";
      };

      "3.11-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp311";
        hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8=";
        hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU=";
      };
      "3.11-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp311";
        hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE=";
        hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck=";
      };
      "3.11-x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        dist = "cp311";
        hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ=";
        hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU=";
      };

      "3.12-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        dist = "cp312";
        hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo=";
        hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40=";
      };
      "3.12-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp312";
        hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0=";
        hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10=";
      };
      "3.12-x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        dist = "cp312";
        hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE=";
        hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A=";
      };
    };

@@ -130,35 +130,19 @@ let
  gpuSrcs = {
    "cuda12.2-3.9" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
      hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM=";
      hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw=";
    };
    "cuda12.2-3.10" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
      hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE=";
      hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ=";
    };
    "cuda12.2-3.11" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
      hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ=";
      hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o=";
    };
    "cuda12.2-3.12" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
      hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q=";
    };
    "cuda11.8-3.9" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
      hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU=";
    };
    "cuda11.8-3.10" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
      hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk=";
    };
    "cuda11.8-3.11" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
      hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw=";
    };
    "cuda11.8-3.12" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl";
      hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00=";
      hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU=";
    };
  };

@@ -213,7 +197,7 @@ buildPythonPackage {
  # for more info.
  postInstall = lib.optional cudaSupport ''
    mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
    ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
  '';

  inherit (jaxlib-build) pythonImportsCheck;
@@ -227,7 +211,7 @@ buildPythonPackage {
    platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
    broken =
      !(cudaSupport -> lib.versionAtLeast cudaVersion "11.1")
      || !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2")
      || !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2")
      || !(cudaSupport -> stdenv.isLinux)
      || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"))
      # Fails at pythonImportsCheckPhase:
Loading