Commit 7a10de4a authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

python3Packages.jax: 0.4.5 -> 0.4.12

parent fe108f01
Loading
Loading
Loading
Loading
+6 −16
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@
, jaxlib-bin
, lapack
, matplotlib
, ml-dtypes
, numpy
, opt-einsum
, pytestCheckHook
@@ -27,7 +28,7 @@ let
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.5";
  version = "0.4.12";
  format = "setuptools";

  disabled = pythonOlder "3.7";
@@ -37,7 +38,7 @@ buildPythonPackage rec {
    repo = pname;
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
    hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y=";
  };

  # jaxlib is _not_ included in propagatedBuildInputs because there are
@@ -46,6 +47,7 @@ buildPythonPackage rec {
  propagatedBuildInputs = [
    absl-py
    etils
    ml-dtypes
    numpy
    opt-einsum
    scipy
@@ -96,24 +98,12 @@ buildPythonPackage rec {
    "testScanGrad_jit_scan"
  ];

  # See https://github.com/google/jax/issues/11722. This is a temporary fix in
  # order to unblock etils, and upgrading jax/jaxlib to the latest version. See
  # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
  disabledTestPaths = [
    "tests/api_test.py"
    "tests/core_test.py"
    "tests/lax_numpy_indexing_test.py"
    "tests/lax_numpy_test.py"
    "tests/nn_test.py"
    "tests/random_test.py"
    "tests/sparse_test.py"
  ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
  disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
    # RuntimeWarning: invalid value encountered in cast
    "tests/lax_test.py"
  ];

  # As of 0.3.22, `import jax` does not work without jaxlib being installed.
  pythonImportsCheck = [ ];
  pythonImportsCheck = [ "jax" ];

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";