Loading pkgs/development/python-modules/brax/default.nix +15 −0 Original line number Diff line number Diff line Loading @@ -40,6 +40,7 @@ buildPythonPackage (finalAttrs: { pname = "brax"; version = "0.14.2"; pyproject = true; __structuredAttrs = true; src = fetchFromGitHub { owner = "google"; Loading @@ -48,6 +49,20 @@ buildPythonPackage (finalAttrs: { hash = "sha256-/oznBa44xKl+9T1YrTVhCbuKZj16V1BTlnmCGRbF45g="; }; patches = [ # AttributeError: jax.device_put_replicated is deprecated; use jax.device_put instead. # See https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for a drop-in replacement. ./dont-use-device_put_replicated-compat.patch ]; # TypeError: clip() got an unexpected keyword argument 'a_min' postPatch = '' substituteInPlace brax/fluid.py \ --replace-fail \ "box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), a_min=1e-12)" \ "box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), min=1e-12)" ''; build-system = [ hatchling ]; Loading pkgs/development/python-modules/brax/dont-use-device_put_replicated-compat.patch 0 → 100644 +68 −0 Original line number Diff line number Diff line diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py index f5fcb0e..87b198f 100644 --- a/brax/training/agents/apg/train.py +++ b/brax/training/agents/apg/train.py @@ -310,7 +310,7 @@ def train( specs.Array((env.observation_size,), jnp.dtype(dtype)) ), ) - training_state = jax.device_put_replicated( + training_state = pmap.device_put_replicated( training_state, jax.local_devices()[:local_devices_to_use] ) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 9aec960..6624733 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -753,7 +753,7 @@ def train( {}, ) - training_state = jax.device_put_replicated( + training_state = pmap.device_put_replicated( training_state, jax.local_devices()[:local_devices_to_use] ) diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index be716e9..8dcf3bf 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -108,7 +108,7 @@ def _init_training_state( alpha_params=log_alpha, normalizer_params=normalizer_params, ) - return jax.device_put_replicated( + return pmap.device_put_replicated( training_state, jax.local_devices()[:local_devices_to_use] ) diff --git a/brax/training/pmap.py b/brax/training/pmap.py index 82760fc..af62ef8 100644 --- a/brax/training/pmap.py +++ b/brax/training/pmap.py @@ -19,12 +19,23 @@ from typing import Any import jax import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import numpy as np + + +def device_put_replicated(x, devices): + """Drop-in replacement for jax.device_put_replicated supporting pytrees.""" + mesh = Mesh(np.array(devices), ('x',)) + sharding = NamedSharding(mesh, P('x')) + return jax.tree.map( + lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x + ) def bcast_local_devices(value, local_devices_to_use=1): """Broadcasts an object to all local devices.""" devices = jax.local_devices()[:local_devices_to_use] - return jax.device_put_replicated(value, devices) + return device_put_replicated(value, devices) def synchronize_hosts(): pkgs/development/python-modules/chex/default.nix +18 −7 Original line number Diff line number Diff line Loading @@ -2,6 +2,7 @@ lib, buildPythonPackage, fetchFromGitHub, fetchpatch, # build-system flit-core, Loading @@ -20,25 +21,35 @@ pytestCheckHook, }: buildPythonPackage rec { buildPythonPackage (finalAttrs: { pname = "chex"; version = "0.1.91"; pyproject = true; __structuredAttrs = true; src = fetchFromGitHub { owner = "deepmind"; repo = "chex"; tag = "v${version}"; tag = "v${finalAttrs.version}"; hash = "sha256-lJ9+kvG7dRtfDVgvkcJ9/jtnX0lMfxY4mmZ290y/74U="; }; patches = [ # jax.device_put_replicated is removed in jax 0.10.0 # This fix was merged upstream -> remove when updating to the next release (fetchpatch { url = "https://github.com/google-deepmind/chex/commit/5fbd2c9a9936799daf92354e0307b9e88b9cc163.patch"; excludes = [ "chex/_src/variants.py" ]; hash = "sha256-ZTimSq7/yt2UEiWmLcfFBadX8+VcaxuPhkQJEyiEZlE="; }) ]; build-system = [ flit-core ]; pythonRelaxDeps = [ "typing_extensions" ]; dependencies = [ absl-py jax Loading Loading @@ -69,8 +80,8 @@ buildPythonPackage rec { meta = { description = "Library of utilities for helping to write reliable JAX code"; homepage = "https://github.com/deepmind/chex"; changelog = "https://github.com/google-deepmind/chex/releases/tag/v${version}"; changelog = "https://github.com/google-deepmind/chex/releases/tag/${finalAttrs.src.tag}"; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ ndl ]; }; } }) pkgs/development/python-modules/clu/default.nix +10 −0 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ buildPythonPackage (finalAttrs: { pname = "clu"; version = "0.0.12"; pyproject = true; __structuredAttrs = true; src = fetchFromGitHub { owner = "google"; Loading @@ -39,6 +40,15 @@ buildPythonPackage (finalAttrs: { hash = "sha256-ntqRz3fCXMf0EDQsddT68Mdi105ECBVQpVsStzk2kvQ="; }; # Fix Jax 0.10.0 compatibility # TypeError: clip() got an unexpected keyword argument 'a_min' postPatch = '' substituteInPlace clu/metrics.py \ --replace-fail \ "variance = jnp.clip(variance, a_min=0.0)" \ "variance = jnp.clip(variance, min=0.0)" ''; build-system = [ setuptools ]; Loading pkgs/development/python-modules/distrax/default.nix +1 −0 Original line number Diff line number Diff line Loading @@ -24,6 +24,7 @@ buildPythonPackage (finalAttrs: { pname = "distrax"; version = "0.1.8"; pyproject = true; __structuredAttrs = true; src = fetchFromGitHub { owner = "google-deepmind"; Loading Loading
pkgs/development/python-modules/brax/default.nix +15 −0 Original line number Diff line number Diff line Loading @@ -40,6 +40,7 @@ buildPythonPackage (finalAttrs: { pname = "brax"; version = "0.14.2"; pyproject = true; __structuredAttrs = true; src = fetchFromGitHub { owner = "google"; Loading @@ -48,6 +49,20 @@ buildPythonPackage (finalAttrs: { hash = "sha256-/oznBa44xKl+9T1YrTVhCbuKZj16V1BTlnmCGRbF45g="; }; patches = [ # AttributeError: jax.device_put_replicated is deprecated; use jax.device_put instead. # See https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for a drop-in replacement. ./dont-use-device_put_replicated-compat.patch ]; # TypeError: clip() got an unexpected keyword argument 'a_min' postPatch = '' substituteInPlace brax/fluid.py \ --replace-fail \ "box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), a_min=1e-12)" \ "box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), min=1e-12)" ''; build-system = [ hatchling ]; Loading
pkgs/development/python-modules/brax/dont-use-device_put_replicated-compat.patch 0 → 100644 +68 −0 Original line number Diff line number Diff line diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py index f5fcb0e..87b198f 100644 --- a/brax/training/agents/apg/train.py +++ b/brax/training/agents/apg/train.py @@ -310,7 +310,7 @@ def train( specs.Array((env.observation_size,), jnp.dtype(dtype)) ), ) - training_state = jax.device_put_replicated( + training_state = pmap.device_put_replicated( training_state, jax.local_devices()[:local_devices_to_use] ) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 9aec960..6624733 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -753,7 +753,7 @@ def train( {}, ) - training_state = jax.device_put_replicated( + training_state = pmap.device_put_replicated( training_state, jax.local_devices()[:local_devices_to_use] ) diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index be716e9..8dcf3bf 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -108,7 +108,7 @@ def _init_training_state( alpha_params=log_alpha, normalizer_params=normalizer_params, ) - return jax.device_put_replicated( + return pmap.device_put_replicated( training_state, jax.local_devices()[:local_devices_to_use] ) diff --git a/brax/training/pmap.py b/brax/training/pmap.py index 82760fc..af62ef8 100644 --- a/brax/training/pmap.py +++ b/brax/training/pmap.py @@ -19,12 +19,23 @@ from typing import Any import jax import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import numpy as np + + +def device_put_replicated(x, devices): + """Drop-in replacement for jax.device_put_replicated supporting pytrees.""" + mesh = Mesh(np.array(devices), ('x',)) + sharding = NamedSharding(mesh, P('x')) + return jax.tree.map( + lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x + ) def bcast_local_devices(value, local_devices_to_use=1): """Broadcasts an object to all local devices.""" devices = jax.local_devices()[:local_devices_to_use] - return jax.device_put_replicated(value, devices) + return device_put_replicated(value, devices) def synchronize_hosts():
pkgs/development/python-modules/chex/default.nix +18 −7 Original line number Diff line number Diff line Loading @@ -2,6 +2,7 @@ lib, buildPythonPackage, fetchFromGitHub, fetchpatch, # build-system flit-core, Loading @@ -20,25 +21,35 @@ pytestCheckHook, }: buildPythonPackage rec { buildPythonPackage (finalAttrs: { pname = "chex"; version = "0.1.91"; pyproject = true; __structuredAttrs = true; src = fetchFromGitHub { owner = "deepmind"; repo = "chex"; tag = "v${version}"; tag = "v${finalAttrs.version}"; hash = "sha256-lJ9+kvG7dRtfDVgvkcJ9/jtnX0lMfxY4mmZ290y/74U="; }; patches = [ # jax.device_put_replicated is removed in jax 0.10.0 # This fix was merged upstream -> remove when updating to the next release (fetchpatch { url = "https://github.com/google-deepmind/chex/commit/5fbd2c9a9936799daf92354e0307b9e88b9cc163.patch"; excludes = [ "chex/_src/variants.py" ]; hash = "sha256-ZTimSq7/yt2UEiWmLcfFBadX8+VcaxuPhkQJEyiEZlE="; }) ]; build-system = [ flit-core ]; pythonRelaxDeps = [ "typing_extensions" ]; dependencies = [ absl-py jax Loading Loading @@ -69,8 +80,8 @@ buildPythonPackage rec { meta = { description = "Library of utilities for helping to write reliable JAX code"; homepage = "https://github.com/deepmind/chex"; changelog = "https://github.com/google-deepmind/chex/releases/tag/v${version}"; changelog = "https://github.com/google-deepmind/chex/releases/tag/${finalAttrs.src.tag}"; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ ndl ]; }; } })
pkgs/development/python-modules/clu/default.nix +10 −0 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ buildPythonPackage (finalAttrs: { pname = "clu"; version = "0.0.12"; pyproject = true; __structuredAttrs = true; src = fetchFromGitHub { owner = "google"; Loading @@ -39,6 +40,15 @@ buildPythonPackage (finalAttrs: { hash = "sha256-ntqRz3fCXMf0EDQsddT68Mdi105ECBVQpVsStzk2kvQ="; }; # Fix Jax 0.10.0 compatibility # TypeError: clip() got an unexpected keyword argument 'a_min' postPatch = '' substituteInPlace clu/metrics.py \ --replace-fail \ "variance = jnp.clip(variance, a_min=0.0)" \ "variance = jnp.clip(variance, min=0.0)" ''; build-system = [ setuptools ]; Loading
pkgs/development/python-modules/distrax/default.nix +1 −0 Original line number Diff line number Diff line Loading @@ -24,6 +24,7 @@ buildPythonPackage (finalAttrs: { pname = "distrax"; version = "0.1.8"; pyproject = true; __structuredAttrs = true; src = fetchFromGitHub { owner = "google-deepmind"; Loading