Loading pkgs/development/python-modules/distrax/default.nix +14 −0 Original line number Diff line number Diff line Loading @@ -2,6 +2,7 @@ lib, buildPythonPackage, fetchFromGitHub, fetchpatch, chex, jaxlib, numpy, Loading @@ -23,6 +24,15 @@ buildPythonPackage rec { hash = "sha256-A1aCL/I89Blg9sNmIWQru4QJteUTN6+bhgrEJPmCrM0="; }; patches = [ # TODO: remove at the next release (already on master) (fetchpatch { name = "fix-jax-0.6.0-compat"; url = "https://github.com/google-deepmind/distrax/commit/c02708ac46518fac00ab2945311e0f2ee32c672c.patch"; hash = "sha256-hFNXKoA1b5I6dzhwTRXp/SnkHv89GI6tYwlnBBHwG78="; }) ]; dependencies = [ chex jaxlib Loading Loading @@ -71,6 +81,10 @@ buildPythonPackage rec { ]; disabledTestPaths = [ # Since jax 0.6.0: # TypeError: <lambda>() got an unexpected keyword argument 'accuracy' "distrax/_src/bijectors/lambda_bijector_test.py" # TypeErrors "distrax/_src/bijectors/tfp_compatible_bijector_test.py" "distrax/_src/distributions/distribution_from_tfp_test.py" Loading pkgs/development/python-modules/dm-haiku/default.nix +11 −0 Original line number Diff line number Diff line Loading @@ -58,6 +58,17 @@ let }) ]; # AttributeError: jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead, and # see https://docs.jax.dev/en/latest/jax.extend.html for details. # Alrady on master: https://github.com/google-deepmind/dm-haiku/commit/cfe8480d253a93100bf5e2d24c40435a95399c96 # TODO: remove at the next release postPatch = '' substituteInPlace haiku/_src/jaxpr_info.py \ --replace-fail "jax.core.JaxprEqn" "jax.extend.core.JaxprEqn" \ --replace-fail "jax.core.Var" "jax.extend.core.Var" \ --replace-fail "jax.core.Jaxpr" "jax.extend.core.Jaxpr" ''; build-system = [ setuptools ]; dependencies = [ Loading pkgs/development/python-modules/gymnasium/default.nix +9 −0 Original line number Diff line number Diff line Loading @@ -90,6 +90,15 @@ buildPythonPackage rec { ]; disabledTests = [ # Fails since jax 0.6.0 # Fixed on master https://github.com/Farama-Foundation/Gymnasium/commit/94019feee1a0f945b9569cddf62780f4e1a224a5 # TODO: un-skip at the next release "test_all_env_api" "test_env_determinism_rollout" "test_jax_to_numpy_wrapper" "test_pickle_env" "test_roundtripping" # Succeeds for most environments but `test_render_modes[Reacher-v4]` fails because it requires # OpenGL access which is not possible inside the sandbox. "test_render_mode" Loading pkgs/development/python-modules/jax-cuda12-pjrt/default.nix +22 −21 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ lib, stdenv, buildPythonPackage, fetchurl, fetchPypi, addDriverRunpath, autoPatchelfHook, pypaInstallHook, Loading Loading @@ -31,30 +31,31 @@ let ] ); # Find new releases at https://storage.googleapis.com/jax-releases # When upgrading, you can get these hashes from jaxlib/prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. # upstream does not distribute jax-cuda12-pjrt binaries for aarch64-linux srcs = { "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl"; hash = "sha256-xTeDBlaLoMgbIwp3ndMZTJ3RAzmrY2CugJKBCNN+f3U="; }; # "aarch64-linux" = fetchurl { # url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl"; # hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; # }; }; in buildPythonPackage { buildPythonPackage rec { pname = "jax-cuda12-pjrt"; inherit version; pyproject = false; src = srcs.${stdenv.hostPlatform.system} or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}"); src = fetchPypi { pname = "jax_cuda12_pjrt"; inherit version; format = "wheel"; python = "py3"; dist = "py3"; platform = { x86_64-linux = "manylinux2014_x86_64"; aarch64-linux = "manylinux2014_aarch64"; } .${stdenv.hostPlatform.system}; hash = { x86_64-linux = "sha256-aDcb2cE1JEuJZjA5viCCVWmKdb7JhU1BnqPD+VfKRkY= "; aarch64-linux = "sha256-m/67BqOWFMtomfdzDqhWHxEVasgcuz7GiEpir7OxX/M="; } .${stdenv.hostPlatform.system}; }; nativeBuildInputs = [ autoPatchelfHook Loading Loading @@ -97,7 +98,7 @@ buildPythonPackage { sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; platforms = lib.attrNames srcs; platforms = lib.platforms.linux; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = Loading pkgs/development/python-modules/jax-cuda12-plugin/default.nix +8 −8 Original line number Diff line number Diff line Loading @@ -40,42 +40,42 @@ let "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; hash = "sha256-uiVVln+bbDgci075+wPQW8Vewl7P7lz+RcWs4099QVI="; hash = "sha256-pwDhcYI84lUQIALkDJR4j6ho8hYle30/BWjQn+dcEHs="; }; "3.10-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp310"; hash = "sha256-YXGu0vSzvdX8E3gt4QcsamNPzhNzG3XQywpquPTm5lA="; hash = "sha256-UwrYUcpGKZHOgtsmrUfwKwjOvkg8nI0MADfp4np7Up8="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; hash = "sha256-qqcEpe9UdZXQItscHkh4oGdxFkEqk2DBFdZ/9LZOFZY="; hash = "sha256-DZ7O3mbEAlhwKkImHoaM21ahA1UafDyISzX1Mcms1I4="; }; "3.11-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp311"; hash = "sha256-KY0tdo8QKbdKCx0BJw5Uk0nSw33Adlh5ZULNqWfre9M="; hash = "sha256-fNG0iKVKMInolYjMr2dwiZUsglKefQQD4LBQGZ5SVBg="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; hash = "sha256-IDDPEgjOTqcO5WysYd3SOfl5hpX8Obt3OcUKJdbp2kQ="; hash = "sha256-5w608IRpbD474StekJ7xIFyfVu/j3OzyYhvZtatZVNU="; }; "3.12-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp312"; hash = "sha256-wlF6fCGG+HCIlGluJs+W69YLeHnOyjmLLEarso0slsg="; hash = "sha256-oqOvX5iIDYb40kartGpVLlou9J12e/xKdMjDV3UgB8Y="; }; "3.13-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp313"; hash = "sha256-GGJZWyttgVZ50R4OiJ5SMYXuVKRtRuAiaJ9w/EVU3ZE="; hash = "sha256-6W891KlCUWroeMn2l+au/teOFI8JAYynPuKLI0JqfYo="; }; "3.13-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp313"; hash = "sha256-If7BtWyYeD6gVpt0elZ1Hx+f8hh7SKzBHHANO/xeGjE="; hash = "sha256-o0LyznxLH1nUA/Zlo1qGuGUCU7sl3jRkf7IlxFzrCgQ="; }; }; in Loading Loading
pkgs/development/python-modules/distrax/default.nix +14 −0 Original line number Diff line number Diff line Loading @@ -2,6 +2,7 @@ lib, buildPythonPackage, fetchFromGitHub, fetchpatch, chex, jaxlib, numpy, Loading @@ -23,6 +24,15 @@ buildPythonPackage rec { hash = "sha256-A1aCL/I89Blg9sNmIWQru4QJteUTN6+bhgrEJPmCrM0="; }; patches = [ # TODO: remove at the next release (already on master) (fetchpatch { name = "fix-jax-0.6.0-compat"; url = "https://github.com/google-deepmind/distrax/commit/c02708ac46518fac00ab2945311e0f2ee32c672c.patch"; hash = "sha256-hFNXKoA1b5I6dzhwTRXp/SnkHv89GI6tYwlnBBHwG78="; }) ]; dependencies = [ chex jaxlib Loading Loading @@ -71,6 +81,10 @@ buildPythonPackage rec { ]; disabledTestPaths = [ # Since jax 0.6.0: # TypeError: <lambda>() got an unexpected keyword argument 'accuracy' "distrax/_src/bijectors/lambda_bijector_test.py" # TypeErrors "distrax/_src/bijectors/tfp_compatible_bijector_test.py" "distrax/_src/distributions/distribution_from_tfp_test.py" Loading
pkgs/development/python-modules/dm-haiku/default.nix +11 −0 Original line number Diff line number Diff line Loading @@ -58,6 +58,17 @@ let }) ]; # AttributeError: jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead, and # see https://docs.jax.dev/en/latest/jax.extend.html for details. # Alrady on master: https://github.com/google-deepmind/dm-haiku/commit/cfe8480d253a93100bf5e2d24c40435a95399c96 # TODO: remove at the next release postPatch = '' substituteInPlace haiku/_src/jaxpr_info.py \ --replace-fail "jax.core.JaxprEqn" "jax.extend.core.JaxprEqn" \ --replace-fail "jax.core.Var" "jax.extend.core.Var" \ --replace-fail "jax.core.Jaxpr" "jax.extend.core.Jaxpr" ''; build-system = [ setuptools ]; dependencies = [ Loading
pkgs/development/python-modules/gymnasium/default.nix +9 −0 Original line number Diff line number Diff line Loading @@ -90,6 +90,15 @@ buildPythonPackage rec { ]; disabledTests = [ # Fails since jax 0.6.0 # Fixed on master https://github.com/Farama-Foundation/Gymnasium/commit/94019feee1a0f945b9569cddf62780f4e1a224a5 # TODO: un-skip at the next release "test_all_env_api" "test_env_determinism_rollout" "test_jax_to_numpy_wrapper" "test_pickle_env" "test_roundtripping" # Succeeds for most environments but `test_render_modes[Reacher-v4]` fails because it requires # OpenGL access which is not possible inside the sandbox. "test_render_mode" Loading
pkgs/development/python-modules/jax-cuda12-pjrt/default.nix +22 −21 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ lib, stdenv, buildPythonPackage, fetchurl, fetchPypi, addDriverRunpath, autoPatchelfHook, pypaInstallHook, Loading Loading @@ -31,30 +31,31 @@ let ] ); # Find new releases at https://storage.googleapis.com/jax-releases # When upgrading, you can get these hashes from jaxlib/prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. # upstream does not distribute jax-cuda12-pjrt binaries for aarch64-linux srcs = { "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl"; hash = "sha256-xTeDBlaLoMgbIwp3ndMZTJ3RAzmrY2CugJKBCNN+f3U="; }; # "aarch64-linux" = fetchurl { # url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl"; # hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; # }; }; in buildPythonPackage { buildPythonPackage rec { pname = "jax-cuda12-pjrt"; inherit version; pyproject = false; src = srcs.${stdenv.hostPlatform.system} or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}"); src = fetchPypi { pname = "jax_cuda12_pjrt"; inherit version; format = "wheel"; python = "py3"; dist = "py3"; platform = { x86_64-linux = "manylinux2014_x86_64"; aarch64-linux = "manylinux2014_aarch64"; } .${stdenv.hostPlatform.system}; hash = { x86_64-linux = "sha256-aDcb2cE1JEuJZjA5viCCVWmKdb7JhU1BnqPD+VfKRkY= "; aarch64-linux = "sha256-m/67BqOWFMtomfdzDqhWHxEVasgcuz7GiEpir7OxX/M="; } .${stdenv.hostPlatform.system}; }; nativeBuildInputs = [ autoPatchelfHook Loading Loading @@ -97,7 +98,7 @@ buildPythonPackage { sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; platforms = lib.attrNames srcs; platforms = lib.platforms.linux; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = Loading
pkgs/development/python-modules/jax-cuda12-plugin/default.nix +8 −8 Original line number Diff line number Diff line Loading @@ -40,42 +40,42 @@ let "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; hash = "sha256-uiVVln+bbDgci075+wPQW8Vewl7P7lz+RcWs4099QVI="; hash = "sha256-pwDhcYI84lUQIALkDJR4j6ho8hYle30/BWjQn+dcEHs="; }; "3.10-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp310"; hash = "sha256-YXGu0vSzvdX8E3gt4QcsamNPzhNzG3XQywpquPTm5lA="; hash = "sha256-UwrYUcpGKZHOgtsmrUfwKwjOvkg8nI0MADfp4np7Up8="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; hash = "sha256-qqcEpe9UdZXQItscHkh4oGdxFkEqk2DBFdZ/9LZOFZY="; hash = "sha256-DZ7O3mbEAlhwKkImHoaM21ahA1UafDyISzX1Mcms1I4="; }; "3.11-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp311"; hash = "sha256-KY0tdo8QKbdKCx0BJw5Uk0nSw33Adlh5ZULNqWfre9M="; hash = "sha256-fNG0iKVKMInolYjMr2dwiZUsglKefQQD4LBQGZ5SVBg="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; hash = "sha256-IDDPEgjOTqcO5WysYd3SOfl5hpX8Obt3OcUKJdbp2kQ="; hash = "sha256-5w608IRpbD474StekJ7xIFyfVu/j3OzyYhvZtatZVNU="; }; "3.12-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp312"; hash = "sha256-wlF6fCGG+HCIlGluJs+W69YLeHnOyjmLLEarso0slsg="; hash = "sha256-oqOvX5iIDYb40kartGpVLlou9J12e/xKdMjDV3UgB8Y="; }; "3.13-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp313"; hash = "sha256-GGJZWyttgVZ50R4OiJ5SMYXuVKRtRuAiaJ9w/EVU3ZE="; hash = "sha256-6W891KlCUWroeMn2l+au/teOFI8JAYynPuKLI0JqfYo="; }; "3.13-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp313"; hash = "sha256-If7BtWyYeD6gVpt0elZ1Hx+f8hh7SKzBHHANO/xeGjE="; hash = "sha256-o0LyznxLH1nUA/Zlo1qGuGUCU7sl3jRkf7IlxFzrCgQ="; }; }; in Loading