Loading pkgs/development/python-modules/numpyro/default.nix +5 −0 Original line number Diff line number Diff line Loading @@ -40,6 +40,11 @@ buildPythonPackage (finalAttrs: { hash = "sha256-sNqllL9nBwXp0kn+HAjvIaHf7LR0UKh9q7DZ20yCr5A="; }; patches = [ # Remove usage of xla_pmap_p which was removed in jax 0.10.0 ./fix-jax-0.10.0-compat.patch ]; build-system = [ setuptools ]; dependencies = [ Loading pkgs/development/python-modules/numpyro/fix-jax-0.10.0-compat.patch 0 → 100644 +20 −0 Original line number Diff line number Diff line diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index 1234567..abcdefg 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -4,7 +4,7 @@ import jax from jax.api_util import debug_info, flatten_fun, shaped_abstractify from jax.extend.core import Literal -from jax.extend.core.primitives import call_p, closed_call_p, jit_p, xla_pmap_p +from jax.extend.core.primitives import call_p, closed_call_p, jit_p import jax.extend.linear_util as lu from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic @@ -114,7 +114,6 @@ def track_deps_call_rule(eqn, provenance_inputs): track_deps_rules[call_p] = track_deps_call_rule -track_deps_rules[xla_pmap_p] = track_deps_call_rule def track_deps_closed_call_rule(eqn, provenance_inputs): Loading
pkgs/development/python-modules/numpyro/default.nix +5 −0 Original line number Diff line number Diff line Loading @@ -40,6 +40,11 @@ buildPythonPackage (finalAttrs: { hash = "sha256-sNqllL9nBwXp0kn+HAjvIaHf7LR0UKh9q7DZ20yCr5A="; }; patches = [ # Remove usage of xla_pmap_p which was removed in jax 0.10.0 ./fix-jax-0.10.0-compat.patch ]; build-system = [ setuptools ]; dependencies = [ Loading
pkgs/development/python-modules/numpyro/fix-jax-0.10.0-compat.patch 0 → 100644 +20 −0 Original line number Diff line number Diff line diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index 1234567..abcdefg 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -4,7 +4,7 @@ import jax from jax.api_util import debug_info, flatten_fun, shaped_abstractify from jax.extend.core import Literal -from jax.extend.core.primitives import call_p, closed_call_p, jit_p, xla_pmap_p +from jax.extend.core.primitives import call_p, closed_call_p, jit_p import jax.extend.linear_util as lu from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic @@ -114,7 +114,6 @@ def track_deps_call_rule(eqn, provenance_inputs): track_deps_rules[call_p] = track_deps_call_rule -track_deps_rules[xla_pmap_p] = track_deps_call_rule def track_deps_closed_call_rule(eqn, provenance_inputs):