Commit 14e819b8 authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

python3Packages.skrl: fix jax 0.10.0 compatibility

parent 05f96ae1
Loading
Loading
Loading
Loading
+17 −0
Original line number Diff line number Diff line
@@ -37,6 +37,23 @@ buildPythonPackage (finalAttrs: {
    hash = "sha256-5lkoYAmMIWqK3+E3WxXMWS9zal2DhZkfl30EkrHKpdI=";
  };

  # Fix Jax 0.10.0 compatibility
  # TypeError: clip() got an unexpected keyword argument 'a_min'
  postPatch = ''
    substituteInPlace skrl/models/jax/gaussian.py \
      --replace-fail \
        "jnp.clip(log_std, a_min=log_std_min, a_max=log_std_max)" \
        "jnp.clip(log_std, min=log_std_min, max=log_std_max)" \
      --replace-fail \
        "jnp.clip(actions, a_min=clip_actions_min, a_max=clip_actions_max)" \
        "jnp.clip(actions, min=clip_actions_min, max=clip_actions_max)"

    substituteInPlace skrl/models/jax/deterministic.py \
      --replace-fail \
        "jnp.clip(actions, a_min=self._d_clip_actions_min, a_max=self._d_clip_actions_max)" \
        "jnp.clip(actions, min=self._d_clip_actions_min, max=self._d_clip_actions_max)"
  '';

  build-system = [ setuptools ];

  dependencies = [