Loading pkgs/development/python-modules/blackjax/default.nix +4 −2 Original line number Diff line number Diff line Loading @@ -75,10 +75,12 @@ buildPythonPackage rec { "test_nuts__without_device" "test_nuts__without_jit" "test_smc_waste_free__with_jit" ] ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [ # Numerical test (AssertionError) # First report, when the failure was only happening on aarch64-linux: # https://github.com/blackjax-devs/blackjax/issues/668 # Second report, when the test started happening on x86_64-linux too after Jax was updated to 0.7.0 # https://github.com/blackjax-devs/blackjax/issues/795 "test_chees_adaptation" ]; Loading pkgs/development/python-modules/chex/default.nix +0 −1 Original line number Diff line number Diff line Loading @@ -18,7 +18,6 @@ cloudpickle, dm-tree, pytestCheckHook, pythonOlder, }: buildPythonPackage rec { Loading pkgs/development/python-modules/distrax/default.nix +17 −2 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ lib, buildPythonPackage, fetchFromGitHub, fetchpatch, fetchpatch2, chex, jaxlib, numpy, Loading @@ -26,13 +26,28 @@ buildPythonPackage rec { patches = [ # TODO: remove at the next release (already on master) (fetchpatch { (fetchpatch2 { name = "fix-jax-0.6.0-compat"; url = "https://github.com/google-deepmind/distrax/commit/c02708ac46518fac00ab2945311e0f2ee32c672c.patch"; hash = "sha256-hFNXKoA1b5I6dzhwTRXp/SnkHv89GI6tYwlnBBHwG78="; }) # https://github.com/google-deepmind/distrax/pull/289 (fetchpatch2 { name = "fix-jax-0.7.0-compat"; url = "https://github.com/google-deepmind/distrax/commit/7fc5bd7efff4a7144d175199159f115c3e68a3cf.patch"; hash = "sha256-TiD72YIb6ajpaCO1yOGl/+JCuaikQ879Zcpaf2wzMq4="; }) ]; # TODO: remove at the next release (already on master) # https://github.com/google-deepmind/distrax/pull/293 postPatch = '' substituteInPlace distrax/_src/utils/transformations.py \ --replace-fail \ "jax.experimental.pjit.pjit_p" \ "jex.core.primitives.jit_p" ''; dependencies = [ chex jaxlib Loading pkgs/development/python-modules/equinox/default.nix +21 −1 Original line number Diff line number Diff line Loading @@ -3,6 +3,7 @@ stdenv, buildPythonPackage, fetchFromGitHub, fetchpatch2, # build-system hatchling, Loading Loading @@ -32,12 +33,31 @@ buildPythonPackage rec { hash = "sha256-zXgAuFGWKHShKodi9swnWIry4VU9s4pBhBRoK5KzaL0="; }; patches = [ # The following two patches have been merged upstream and should be removed when updating to the next release # They fix the incompatibilities with jax>=0.7.0 # https://github.com/patrick-kidger/equinox/pull/1086 (fetchpatch2 { name = "remove-deprecated-batching-NotMapped"; url = "https://github.com/patrick-kidger/equinox/commit/6a6a441ced2fe64191a087752f1c2e71a6ce39f1.patch"; hash = "sha256-tzHFjMI3gAIh5MPkdbmzsky/oFjDEbOIkPGQMQ+gcQQ="; }) # https://github.com/patrick-kidger/equinox/pull/1082 (fetchpatch2 { name = "allow-creating-weak-references-to-flatten"; url = "https://github.com/patrick-kidger/equinox/commit/62b3c94ad56bdb63524702b320e977d2d93dbe72.patch"; hash = "sha256-c1FKCnC3/okuP2VJV4h7sPRYQeYJZSdzEG5ETL2M35k="; }) ]; # Relax speed constraints on tests that can fail on busy builders postPatch = '' substituteInPlace tests/test_while_loop.py \ --replace-fail "speed < 0.1" "speed < 0.5" \ --replace-fail "speed < 0.5" "speed < 1" \ --replace-fail "speed < 1" "speed < 4" \ --replace-fail "speed < 1" "speed < 20" ''; build-system = [ hatchling ]; Loading pkgs/development/python-modules/jax-cuda12-pjrt/default.nix +3 −3 Original line number Diff line number Diff line Loading @@ -44,14 +44,14 @@ buildPythonPackage rec { dist = "py3"; platform = { x86_64-linux = "manylinux2014_x86_64"; x86_64-linux = "manylinux_2_27_x86_64"; aarch64-linux = "manylinux2014_aarch64"; } .${stdenv.hostPlatform.system}; hash = { x86_64-linux = "sha256-jNnq15SOosd4pQj+9dEVnot6v0/MxwN8P+Hb/NlQEtw="; aarch64-linux = "sha256-IvrwINLo98oeKRVjMkH333Z4tzxwePXwsvETJIM3994="; x86_64-linux = "sha256-quPn80WASloSKaxxSMvRNEUMgWGYu7/4WnYiGC7H9ko="; aarch64-linux = "sha256-f8QFBJYN5W4jL5MmMKJ1fs8/hzZlsTmDF9Jfa3RF1WA="; } .${stdenv.hostPlatform.system}; }; Loading Loading
pkgs/development/python-modules/blackjax/default.nix +4 −2 Original line number Diff line number Diff line Loading @@ -75,10 +75,12 @@ buildPythonPackage rec { "test_nuts__without_device" "test_nuts__without_jit" "test_smc_waste_free__with_jit" ] ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [ # Numerical test (AssertionError) # First report, when the failure was only happening on aarch64-linux: # https://github.com/blackjax-devs/blackjax/issues/668 # Second report, when the test started happening on x86_64-linux too after Jax was updated to 0.7.0 # https://github.com/blackjax-devs/blackjax/issues/795 "test_chees_adaptation" ]; Loading
pkgs/development/python-modules/chex/default.nix +0 −1 Original line number Diff line number Diff line Loading @@ -18,7 +18,6 @@ cloudpickle, dm-tree, pytestCheckHook, pythonOlder, }: buildPythonPackage rec { Loading
pkgs/development/python-modules/distrax/default.nix +17 −2 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ lib, buildPythonPackage, fetchFromGitHub, fetchpatch, fetchpatch2, chex, jaxlib, numpy, Loading @@ -26,13 +26,28 @@ buildPythonPackage rec { patches = [ # TODO: remove at the next release (already on master) (fetchpatch { (fetchpatch2 { name = "fix-jax-0.6.0-compat"; url = "https://github.com/google-deepmind/distrax/commit/c02708ac46518fac00ab2945311e0f2ee32c672c.patch"; hash = "sha256-hFNXKoA1b5I6dzhwTRXp/SnkHv89GI6tYwlnBBHwG78="; }) # https://github.com/google-deepmind/distrax/pull/289 (fetchpatch2 { name = "fix-jax-0.7.0-compat"; url = "https://github.com/google-deepmind/distrax/commit/7fc5bd7efff4a7144d175199159f115c3e68a3cf.patch"; hash = "sha256-TiD72YIb6ajpaCO1yOGl/+JCuaikQ879Zcpaf2wzMq4="; }) ]; # TODO: remove at the next release (already on master) # https://github.com/google-deepmind/distrax/pull/293 postPatch = '' substituteInPlace distrax/_src/utils/transformations.py \ --replace-fail \ "jax.experimental.pjit.pjit_p" \ "jex.core.primitives.jit_p" ''; dependencies = [ chex jaxlib Loading
pkgs/development/python-modules/equinox/default.nix +21 −1 Original line number Diff line number Diff line Loading @@ -3,6 +3,7 @@ stdenv, buildPythonPackage, fetchFromGitHub, fetchpatch2, # build-system hatchling, Loading Loading @@ -32,12 +33,31 @@ buildPythonPackage rec { hash = "sha256-zXgAuFGWKHShKodi9swnWIry4VU9s4pBhBRoK5KzaL0="; }; patches = [ # The following two patches have been merged upstream and should be removed when updating to the next release # They fix the incompatibilities with jax>=0.7.0 # https://github.com/patrick-kidger/equinox/pull/1086 (fetchpatch2 { name = "remove-deprecated-batching-NotMapped"; url = "https://github.com/patrick-kidger/equinox/commit/6a6a441ced2fe64191a087752f1c2e71a6ce39f1.patch"; hash = "sha256-tzHFjMI3gAIh5MPkdbmzsky/oFjDEbOIkPGQMQ+gcQQ="; }) # https://github.com/patrick-kidger/equinox/pull/1082 (fetchpatch2 { name = "allow-creating-weak-references-to-flatten"; url = "https://github.com/patrick-kidger/equinox/commit/62b3c94ad56bdb63524702b320e977d2d93dbe72.patch"; hash = "sha256-c1FKCnC3/okuP2VJV4h7sPRYQeYJZSdzEG5ETL2M35k="; }) ]; # Relax speed constraints on tests that can fail on busy builders postPatch = '' substituteInPlace tests/test_while_loop.py \ --replace-fail "speed < 0.1" "speed < 0.5" \ --replace-fail "speed < 0.5" "speed < 1" \ --replace-fail "speed < 1" "speed < 4" \ --replace-fail "speed < 1" "speed < 20" ''; build-system = [ hatchling ]; Loading
pkgs/development/python-modules/jax-cuda12-pjrt/default.nix +3 −3 Original line number Diff line number Diff line Loading @@ -44,14 +44,14 @@ buildPythonPackage rec { dist = "py3"; platform = { x86_64-linux = "manylinux2014_x86_64"; x86_64-linux = "manylinux_2_27_x86_64"; aarch64-linux = "manylinux2014_aarch64"; } .${stdenv.hostPlatform.system}; hash = { x86_64-linux = "sha256-jNnq15SOosd4pQj+9dEVnot6v0/MxwN8P+Hb/NlQEtw="; aarch64-linux = "sha256-IvrwINLo98oeKRVjMkH333Z4tzxwePXwsvETJIM3994="; x86_64-linux = "sha256-quPn80WASloSKaxxSMvRNEUMgWGYu7/4WnYiGC7H9ko="; aarch64-linux = "sha256-f8QFBJYN5W4jL5MmMKJ1fs8/hzZlsTmDF9Jfa3RF1WA="; } .${stdenv.hostPlatform.system}; }; Loading