pkgs/development/python-modules/brax/default.nix  +15 −0

@@ -40,6 +40,7 @@ buildPythonPackage (finalAttrs: {
   pname = "brax";
   version = "0.14.2";
   pyproject = true;
+  __structuredAttrs = true;

   src = fetchFromGitHub {
     owner = "google";
@@ -48,6 +49,20 @@ buildPythonPackage (finalAttrs: {
     hash = "sha256-/oznBa44xKl+9T1YrTVhCbuKZj16V1BTlnmCGRbF45g=";
   };

+  patches = [
+    # AttributeError: jax.device_put_replicated is deprecated; use jax.device_put instead.
+    # See https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for a drop-in replacement.
+    ./dont-use-device_put_replicated-compat.patch
+  ];
+
+  # TypeError: clip() got an unexpected keyword argument 'a_min'
+  postPatch = ''
+    substituteInPlace brax/fluid.py \
+      --replace-fail \
+        "box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), a_min=1e-12)" \
+        "box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), min=1e-12)"
+  '';
+
   build-system = [ hatchling ];
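For context on the postPatch above: recent JAX releases renamed jnp.clip's a_min/a_max keywords to min/max to follow the Python array API standard, so the old spelling raises the TypeError quoted in the comment. A minimal sketch of the behavior (not part of the diff; values are illustrative):

    import jax.numpy as jnp

    x = jnp.array([1e-15, 0.5, 2.0])

    # New spelling, as substituted into brax/fluid.py:
    print(jnp.clip(x, min=1e-12))   # [1.e-12 5.e-01 2.e+00]

    # Old spelling, as shipped in brax 0.14.2:
    # jnp.clip(x, a_min=1e-12)
    # -> TypeError: clip() got an unexpected keyword argument 'a_min'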
pkgs/development/python-modules/brax/dont-use-device_put_replicated-compat.patch  (new file, mode 100644)  +68 −0

diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py
index f5fcb0e..87b198f 100644
--- a/brax/training/agents/apg/train.py
+++ b/brax/training/agents/apg/train.py
@@ -310,7 +310,7 @@ def train(
           specs.Array((env.observation_size,), jnp.dtype(dtype))
       ),
   )
-  training_state = jax.device_put_replicated(
+  training_state = pmap.device_put_replicated(
      training_state, jax.local_devices()[:local_devices_to_use]
   )

diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index 9aec960..6624733 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -753,7 +753,7 @@ def train(
       {},
   )

-  training_state = jax.device_put_replicated(
+  training_state = pmap.device_put_replicated(
      training_state, jax.local_devices()[:local_devices_to_use]
   )
diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
index be716e9..8dcf3bf 100644
--- a/brax/training/agents/sac/train.py
+++ b/brax/training/agents/sac/train.py
@@ -108,7 +108,7 @@ def _init_training_state(
       alpha_params=log_alpha,
       normalizer_params=normalizer_params,
   )
-  return jax.device_put_replicated(
+  return pmap.device_put_replicated(
      training_state, jax.local_devices()[:local_devices_to_use]
   )

diff --git a/brax/training/pmap.py b/brax/training/pmap.py
index 82760fc..af62ef8 100644
--- a/brax/training/pmap.py
+++ b/brax/training/pmap.py
@@ -19,12 +19,23 @@ from typing import Any

 import jax
 import jax.numpy as jnp
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+import numpy as np
+
+
+def device_put_replicated(x, devices):
+  """Drop-in replacement for jax.device_put_replicated supporting pytrees."""
+  mesh = Mesh(np.array(devices), ('x',))
+  sharding = NamedSharding(mesh, P('x'))
+  return jax.tree.map(
+      lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x
+  )


 def bcast_local_devices(value, local_devices_to_use=1):
   """Broadcasts an object to all local devices."""
   devices = jax.local_devices()[:local_devices_to_use]
-  return jax.device_put_replicated(value, devices)
+  return device_put_replicated(value, devices)


 def synchronize_hosts():
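A small standalone usage sketch of the device_put_replicated shim introduced in brax/training/pmap.py above. The dict pytree and its shapes are hypothetical, but the call mirrors how the three train.py call sites replicate their training state across local devices:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    def device_put_replicated(x, devices):
        # Same shim as in the patch: stack each pytree leaf once per device
        # and shard the new leading axis across a 1-D device mesh.
        mesh = Mesh(np.array(devices), ('x',))
        sharding = NamedSharding(mesh, P('x'))
        return jax.tree.map(
            lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x
        )

    devices = jax.local_devices()
    state = {'params': jnp.zeros((3,)), 'step': jnp.array(0)}  # hypothetical pytree
    replicated = device_put_replicated(state, devices)

    # Every leaf gains a leading device axis, e.g. params: (len(devices), 3),
    # matching what the removed jax.device_put_replicated returned.
    print(jax.tree.map(lambda y: y.shape, replicated))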