Unverified Commit 94f7e044 authored by Gaétan Lepage's avatar Gaétan Lepage Committed by GitHub
Browse files

python312Packages.jax[lib]: 0.5.3 -> 0.6.0 (#399406)

parents 843c644f 8873b3f1
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  chex,
  jaxlib,
  numpy,
@@ -23,6 +24,15 @@ buildPythonPackage rec {
    hash = "sha256-A1aCL/I89Blg9sNmIWQru4QJteUTN6+bhgrEJPmCrM0=";
  };

  patches = [
    # TODO: remove at the next release (already on master)
    (fetchpatch {
      name = "fix-jax-0.6.0-compat";
      url = "https://github.com/google-deepmind/distrax/commit/c02708ac46518fac00ab2945311e0f2ee32c672c.patch";
      hash = "sha256-hFNXKoA1b5I6dzhwTRXp/SnkHv89GI6tYwlnBBHwG78=";
    })
  ];

  dependencies = [
    chex
    jaxlib
@@ -71,6 +81,10 @@ buildPythonPackage rec {
  ];

  disabledTestPaths = [
    # Since jax 0.6.0:
    # TypeError: <lambda>() got an unexpected keyword argument 'accuracy'
    "distrax/_src/bijectors/lambda_bijector_test.py"

    # TypeErrors
    "distrax/_src/bijectors/tfp_compatible_bijector_test.py"
    "distrax/_src/distributions/distribution_from_tfp_test.py"
+11 −0
Original line number Diff line number Diff line
@@ -58,6 +58,17 @@ let
      })
    ];

    # AttributeError: jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead, and
    # see https://docs.jax.dev/en/latest/jax.extend.html for details.
    # Alrady on master: https://github.com/google-deepmind/dm-haiku/commit/cfe8480d253a93100bf5e2d24c40435a95399c96
    # TODO: remove at the next release
    postPatch = ''
      substituteInPlace haiku/_src/jaxpr_info.py \
        --replace-fail "jax.core.JaxprEqn" "jax.extend.core.JaxprEqn" \
        --replace-fail "jax.core.Var" "jax.extend.core.Var" \
        --replace-fail "jax.core.Jaxpr" "jax.extend.core.Jaxpr"
    '';

    build-system = [ setuptools ];

    dependencies = [
+9 −0
Original line number Diff line number Diff line
@@ -90,6 +90,15 @@ buildPythonPackage rec {
  ];

  disabledTests = [
    # Fails since jax 0.6.0
    # Fixed on master https://github.com/Farama-Foundation/Gymnasium/commit/94019feee1a0f945b9569cddf62780f4e1a224a5
    # TODO: un-skip at the next release
    "test_all_env_api"
    "test_env_determinism_rollout"
    "test_jax_to_numpy_wrapper"
    "test_pickle_env"
    "test_roundtripping"

    # Succeeds for most environments but `test_render_modes[Reacher-v4]` fails because it requires
    # OpenGL access which is not possible inside the sandbox.
    "test_render_mode"
+22 −21
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@
  lib,
  stdenv,
  buildPythonPackage,
  fetchurl,
  fetchPypi,
  addDriverRunpath,
  autoPatchelfHook,
  pypaInstallHook,
@@ -31,30 +31,31 @@ let
    ]
  );

  # Find new releases at https://storage.googleapis.com/jax-releases
  # When upgrading, you can get these hashes from jaxlib/prefetch.sh. See
  # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.

  # upstream does not distribute jax-cuda12-pjrt binaries for aarch64-linux
  srcs = {
    "x86_64-linux" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl";
      hash = "sha256-xTeDBlaLoMgbIwp3ndMZTJ3RAzmrY2CugJKBCNN+f3U=";
    };
    # "aarch64-linux" = fetchurl {
    #   url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl";
    #   hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    # };
  };
in
buildPythonPackage {
buildPythonPackage rec {
  pname = "jax-cuda12-pjrt";
  inherit version;
  pyproject = false;

  src =
    srcs.${stdenv.hostPlatform.system}
      or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}");
  src = fetchPypi {
    pname = "jax_cuda12_pjrt";
    inherit version;
    format = "wheel";
    python = "py3";
    dist = "py3";
    platform =
      {
        x86_64-linux = "manylinux2014_x86_64";
        aarch64-linux = "manylinux2014_aarch64";
      }
      .${stdenv.hostPlatform.system};
    hash =
      {
        x86_64-linux = "sha256-aDcb2cE1JEuJZjA5viCCVWmKdb7JhU1BnqPD+VfKRkY= ";
        aarch64-linux = "sha256-m/67BqOWFMtomfdzDqhWHxEVasgcuz7GiEpir7OxX/M=";
      }
      .${stdenv.hostPlatform.system};
  };

  nativeBuildInputs = [
    autoPatchelfHook
@@ -97,7 +98,7 @@ buildPythonPackage {
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.attrNames srcs;
    platforms = lib.platforms.linux;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken =
+8 −8
Original line number Diff line number Diff line
@@ -40,42 +40,42 @@ let
    "3.10-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp310";
      hash = "sha256-uiVVln+bbDgci075+wPQW8Vewl7P7lz+RcWs4099QVI=";
      hash = "sha256-pwDhcYI84lUQIALkDJR4j6ho8hYle30/BWjQn+dcEHs=";
    };
    "3.10-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp310";
      hash = "sha256-YXGu0vSzvdX8E3gt4QcsamNPzhNzG3XQywpquPTm5lA=";
      hash = "sha256-UwrYUcpGKZHOgtsmrUfwKwjOvkg8nI0MADfp4np7Up8=";
    };
    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp311";
      hash = "sha256-qqcEpe9UdZXQItscHkh4oGdxFkEqk2DBFdZ/9LZOFZY=";
      hash = "sha256-DZ7O3mbEAlhwKkImHoaM21ahA1UafDyISzX1Mcms1I4=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp311";
      hash = "sha256-KY0tdo8QKbdKCx0BJw5Uk0nSw33Adlh5ZULNqWfre9M=";
      hash = "sha256-fNG0iKVKMInolYjMr2dwiZUsglKefQQD4LBQGZ5SVBg=";
    };
    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp312";
      hash = "sha256-IDDPEgjOTqcO5WysYd3SOfl5hpX8Obt3OcUKJdbp2kQ=";
      hash = "sha256-5w608IRpbD474StekJ7xIFyfVu/j3OzyYhvZtatZVNU=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp312";
      hash = "sha256-wlF6fCGG+HCIlGluJs+W69YLeHnOyjmLLEarso0slsg=";
      hash = "sha256-oqOvX5iIDYb40kartGpVLlou9J12e/xKdMjDV3UgB8Y=";
    };
    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp313";
      hash = "sha256-GGJZWyttgVZ50R4OiJ5SMYXuVKRtRuAiaJ9w/EVU3ZE=";
      hash = "sha256-6W891KlCUWroeMn2l+au/teOFI8JAYynPuKLI0JqfYo=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp313";
      hash = "sha256-If7BtWyYeD6gVpt0elZ1Hx+f8hh7SKzBHHANO/xeGjE=";
      hash = "sha256-o0LyznxLH1nUA/Zlo1qGuGUCU7sl3jRkf7IlxFzrCgQ=";
    };
  };
in
Loading