Commit 5f5f0d62 authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

python3Packages.numpyro: fix jax 0.10.0 compatibility

parent bf948e86
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -40,6 +40,11 @@ buildPythonPackage (finalAttrs: {
    hash = "sha256-sNqllL9nBwXp0kn+HAjvIaHf7LR0UKh9q7DZ20yCr5A=";
  };

  patches = [
    # Remove usage of xla_pmap_p which was removed in jax 0.10.0
    ./fix-jax-0.10.0-compat.patch
  ];

  build-system = [ setuptools ];

  dependencies = [
+20 −0
Original line number Diff line number Diff line
diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py
index 1234567..abcdefg 100644
--- a/numpyro/ops/provenance.py
+++ b/numpyro/ops/provenance.py
@@ -4,7 +4,7 @@
 import jax
 from jax.api_util import debug_info, flatten_fun, shaped_abstractify
 from jax.extend.core import Literal
-from jax.extend.core.primitives import call_p, closed_call_p, jit_p, xla_pmap_p
+from jax.extend.core.primitives import call_p, closed_call_p, jit_p
 import jax.extend.linear_util as lu
 from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic

@@ -114,7 +114,6 @@ def track_deps_call_rule(eqn, provenance_inputs):

 track_deps_rules[call_p] = track_deps_call_rule
-track_deps_rules[xla_pmap_p] = track_deps_call_rule


 def track_deps_closed_call_rule(eqn, provenance_inputs):