Unverified Commit e250e010 authored by Samuel Ainsworth's avatar Samuel Ainsworth Committed by GitHub
Browse files

python3Packages.jax[lib]: 0.6.2 -> 0.7.1 (#427588)

parents a251b8a9 4cfbd1dc
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -75,10 +75,12 @@ buildPythonPackage rec {
    "test_nuts__without_device"
    "test_nuts__without_jit"
    "test_smc_waste_free__with_jit"
  ]
  ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [

    # Numerical test (AssertionError)
    # First report, when the failure was only happening on aarch64-linux:
    # https://github.com/blackjax-devs/blackjax/issues/668
    # Second report, when the test started happening on x86_64-linux too after Jax was updated to 0.7.0
    # https://github.com/blackjax-devs/blackjax/issues/795
    "test_chees_adaptation"
  ];

+0 −1
Original line number Diff line number Diff line
@@ -18,7 +18,6 @@
  cloudpickle,
  dm-tree,
  pytestCheckHook,
  pythonOlder,
}:

buildPythonPackage rec {
+17 −2
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  fetchpatch2,
  chex,
  jaxlib,
  numpy,
@@ -26,13 +26,28 @@ buildPythonPackage rec {

  patches = [
    # TODO: remove at the next release (already on master)
    (fetchpatch {
    (fetchpatch2 {
      name = "fix-jax-0.6.0-compat";
      url = "https://github.com/google-deepmind/distrax/commit/c02708ac46518fac00ab2945311e0f2ee32c672c.patch";
      hash = "sha256-hFNXKoA1b5I6dzhwTRXp/SnkHv89GI6tYwlnBBHwG78=";
    })
    # https://github.com/google-deepmind/distrax/pull/289
    (fetchpatch2 {
      name = "fix-jax-0.7.0-compat";
      url = "https://github.com/google-deepmind/distrax/commit/7fc5bd7efff4a7144d175199159f115c3e68a3cf.patch";
      hash = "sha256-TiD72YIb6ajpaCO1yOGl/+JCuaikQ879Zcpaf2wzMq4=";
    })
  ];

  # TODO: remove at the next release (already on master)
  # https://github.com/google-deepmind/distrax/pull/293
  postPatch = ''
    substituteInPlace distrax/_src/utils/transformations.py \
      --replace-fail \
        "jax.experimental.pjit.pjit_p" \
        "jex.core.primitives.jit_p"
  '';

  dependencies = [
    chex
    jaxlib
+21 −1
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch2,

  # build-system
  hatchling,
@@ -32,12 +33,31 @@ buildPythonPackage rec {
    hash = "sha256-zXgAuFGWKHShKodi9swnWIry4VU9s4pBhBRoK5KzaL0=";
  };

  patches = [
    # The following two patches have been merged upstream and should be removed when updating to the next release
    # They fix the incompatibilities with jax>=0.7.0

    # https://github.com/patrick-kidger/equinox/pull/1086
    (fetchpatch2 {
      name = "remove-deprecated-batching-NotMapped";
      url = "https://github.com/patrick-kidger/equinox/commit/6a6a441ced2fe64191a087752f1c2e71a6ce39f1.patch";
      hash = "sha256-tzHFjMI3gAIh5MPkdbmzsky/oFjDEbOIkPGQMQ+gcQQ=";
    })

    # https://github.com/patrick-kidger/equinox/pull/1082
    (fetchpatch2 {
      name = "allow-creating-weak-references-to-flatten";
      url = "https://github.com/patrick-kidger/equinox/commit/62b3c94ad56bdb63524702b320e977d2d93dbe72.patch";
      hash = "sha256-c1FKCnC3/okuP2VJV4h7sPRYQeYJZSdzEG5ETL2M35k=";
    })
  ];

  # Relax speed constraints on tests that can fail on busy builders
  postPatch = ''
    substituteInPlace tests/test_while_loop.py \
      --replace-fail "speed < 0.1" "speed < 0.5" \
      --replace-fail "speed < 0.5" "speed < 1" \
      --replace-fail "speed < 1" "speed < 4" \
      --replace-fail "speed < 1" "speed < 20"
  '';

  build-system = [ hatchling ];
+3 −3
Original line number Diff line number Diff line
@@ -44,14 +44,14 @@ buildPythonPackage rec {
    dist = "py3";
    platform =
      {
        x86_64-linux = "manylinux2014_x86_64";
        x86_64-linux = "manylinux_2_27_x86_64";
        aarch64-linux = "manylinux2014_aarch64";
      }
      .${stdenv.hostPlatform.system};
    hash =
      {
        x86_64-linux = "sha256-jNnq15SOosd4pQj+9dEVnot6v0/MxwN8P+Hb/NlQEtw=";
        aarch64-linux = "sha256-IvrwINLo98oeKRVjMkH333Z4tzxwePXwsvETJIM3994=";
        x86_64-linux = "sha256-quPn80WASloSKaxxSMvRNEUMgWGYu7/4WnYiGC7H9ko=";
        aarch64-linux = "sha256-f8QFBJYN5W4jL5MmMKJ1fs8/hzZlsTmDF9Jfa3RF1WA=";
      }
      .${stdenv.hostPlatform.system};
  };
Loading