Commit f696ae5b authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

python312Packages.waymax: adapt to jax 0.6.0 new API changes

parent b4ec483f
Loading
Loading
Loading
Loading
+16 −1
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@
  tqdm,
}:

buildPythonPackage rec {
buildPythonPackage {
  pname = "waymax";
  version = "0-unstable-2025-03-25";
  pyproject = true;
@@ -31,6 +31,21 @@ buildPythonPackage rec {
    hash = "sha256-B1Rp5MATbEelp6G6K2wwV83QpINhOHgvAxb3mBN52Eg=";
  };

  # AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  # https://github.com/waymo-research/waymax/pull/77
  postPatch = ''
    substituteInPlace \
      waymax/agents/expert.py \
      waymax/agents/waypoint_following_agent.py \
      waymax/agents/waypoint_following_agent_test.py \
      waymax/dynamics/abstract_dynamics_test.py \
      waymax/dynamics/state_dynamics_test.py \
      waymax/env/base_environment_test.py \
      waymax/env/rollout_test.py \
      waymax/env/wrappers/brax_wrapper_test.py \
      --replace-fail "jax.tree_map" "jax.tree_util.tree_map"
  '';

  build-system = [ setuptools ];

  dependencies = [