Unverified Commit 97b8492e authored by Yt's avatar Yt Committed by GitHub
Browse files

python3Packages.jax: 0.9.2 -> 0.10.0 (#511567)

parents dc66b370 3d6866c3
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -40,6 +40,7 @@ buildPythonPackage (finalAttrs: {
  pname = "brax";
  version = "0.14.2";
  pyproject = true;
  __structuredAttrs = true;

  src = fetchFromGitHub {
    owner = "google";
@@ -48,6 +49,20 @@ buildPythonPackage (finalAttrs: {
    hash = "sha256-/oznBa44xKl+9T1YrTVhCbuKZj16V1BTlnmCGRbF45g=";
  };

  patches = [
    # AttributeError: jax.device_put_replicated is deprecated; use jax.device_put instead.
    # See https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for a drop-in replacement.
    ./dont-use-device_put_replicated-compat.patch
  ];

  # TypeError: clip() got an unexpected keyword argument 'a_min'
  postPatch = ''
    substituteInPlace brax/fluid.py \
      --replace-fail \
        "box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), a_min=1e-12)" \
        "box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), min=1e-12)"
  '';

  build-system = [
    hatchling
  ];
+68 −0
Original line number Diff line number Diff line
diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py
index f5fcb0e..87b198f 100644
--- a/brax/training/agents/apg/train.py
+++ b/brax/training/agents/apg/train.py
@@ -310,7 +310,7 @@ def train(
           specs.Array((env.observation_size,), jnp.dtype(dtype))
       ),
   )
-  training_state = jax.device_put_replicated(
+  training_state = pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index 9aec960..6624733 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -753,7 +753,7 @@ def train(
         {},
     )
 
-  training_state = jax.device_put_replicated(
+  training_state = pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
index be716e9..8dcf3bf 100644
--- a/brax/training/agents/sac/train.py
+++ b/brax/training/agents/sac/train.py
@@ -108,7 +108,7 @@ def _init_training_state(
       alpha_params=log_alpha,
       normalizer_params=normalizer_params,
   )
-  return jax.device_put_replicated(
+  return pmap.device_put_replicated(
       training_state, jax.local_devices()[:local_devices_to_use]
   )
 
diff --git a/brax/training/pmap.py b/brax/training/pmap.py
index 82760fc..af62ef8 100644
--- a/brax/training/pmap.py
+++ b/brax/training/pmap.py
@@ -19,12 +19,23 @@ from typing import Any
 
 import jax
 import jax.numpy as jnp
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+import numpy as np
+
+
+def device_put_replicated(x, devices):
+  """Drop-in replacement for jax.device_put_replicated supporting pytrees."""
+  mesh = Mesh(np.array(devices), ('x',))
+  sharding = NamedSharding(mesh, P('x'))
+  return jax.tree.map(
+      lambda y: jax.device_put(jnp.stack([y] * len(devices)), sharding), x
+  )
 
 
 def bcast_local_devices(value, local_devices_to_use=1):
   """Broadcasts an object to all local devices."""
   devices = jax.local_devices()[:local_devices_to_use]
-  return jax.device_put_replicated(value, devices)
+  return device_put_replicated(value, devices)
 
 
 def synchronize_hosts():
+18 −7
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,

  # build-system
  flit-core,
@@ -20,25 +21,35 @@
  pytestCheckHook,
}:

buildPythonPackage rec {
buildPythonPackage (finalAttrs: {
  pname = "chex";
  version = "0.1.91";
  pyproject = true;
  __structuredAttrs = true;

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "chex";
    tag = "v${version}";
    tag = "v${finalAttrs.version}";
    hash = "sha256-lJ9+kvG7dRtfDVgvkcJ9/jtnX0lMfxY4mmZ290y/74U=";
  };

  patches = [
    # jax.device_put_replicated is removed in jax 0.10.0
    # This fix was merged upstream -> remove when updating to the next release
    (fetchpatch {
      url = "https://github.com/google-deepmind/chex/commit/5fbd2c9a9936799daf92354e0307b9e88b9cc163.patch";
      excludes = [
        "chex/_src/variants.py"
      ];
      hash = "sha256-ZTimSq7/yt2UEiWmLcfFBadX8+VcaxuPhkQJEyiEZlE=";
    })
  ];

  build-system = [
    flit-core
  ];

  pythonRelaxDeps = [
    "typing_extensions"
  ];
  dependencies = [
    absl-py
    jax
@@ -69,8 +80,8 @@ buildPythonPackage rec {
  meta = {
    description = "Library of utilities for helping to write reliable JAX code";
    homepage = "https://github.com/deepmind/chex";
    changelog = "https://github.com/google-deepmind/chex/releases/tag/v${version}";
    changelog = "https://github.com/google-deepmind/chex/releases/tag/${finalAttrs.src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}
})
+10 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ buildPythonPackage (finalAttrs: {
  pname = "clu";
  version = "0.0.12";
  pyproject = true;
  __structuredAttrs = true;

  src = fetchFromGitHub {
    owner = "google";
@@ -39,6 +40,15 @@ buildPythonPackage (finalAttrs: {
    hash = "sha256-ntqRz3fCXMf0EDQsddT68Mdi105ECBVQpVsStzk2kvQ=";
  };

  # Fix Jax 0.10.0 compatibility
  # TypeError: clip() got an unexpected keyword argument 'a_min'
  postPatch = ''
    substituteInPlace clu/metrics.py \
      --replace-fail \
        "variance = jnp.clip(variance, a_min=0.0)" \
        "variance = jnp.clip(variance, min=0.0)"
  '';

  build-system = [
    setuptools
  ];
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ buildPythonPackage (finalAttrs: {
  pname = "distrax";
  version = "0.1.8";
  pyproject = true;
  __structuredAttrs = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
Loading