Loading pkgs/development/python-modules/jax/default.nix +6 −16 Original line number Diff line number Diff line Loading @@ -8,6 +8,7 @@ , jaxlib-bin , lapack , matplotlib , ml-dtypes , numpy , opt-einsum , pytestCheckHook Loading @@ -27,7 +28,7 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.5"; version = "0.4.12"; format = "setuptools"; disabled = pythonOlder "3.7"; Loading @@ -37,7 +38,7 @@ buildPythonPackage rec { repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are Loading @@ -46,6 +47,7 @@ buildPythonPackage rec { propagatedBuildInputs = [ absl-py etils ml-dtypes numpy opt-einsum scipy Loading Loading @@ -96,24 +98,12 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; # See https://github.com/google/jax/issues/11722. This is a temporary fix in # order to unblock etils, and upgrading jax/jaxlib to the latest version. See # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. disabledTestPaths = [ "tests/api_test.py" "tests/core_test.py" "tests/lax_numpy_indexing_test.py" "tests/lax_numpy_test.py" "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; pythonImportsCheck = [ "jax" ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading Loading
pkgs/development/python-modules/jax/default.nix +6 −16 Original line number Diff line number Diff line Loading @@ -8,6 +8,7 @@ , jaxlib-bin , lapack , matplotlib , ml-dtypes , numpy , opt-einsum , pytestCheckHook Loading @@ -27,7 +28,7 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.5"; version = "0.4.12"; format = "setuptools"; disabled = pythonOlder "3.7"; Loading @@ -37,7 +38,7 @@ buildPythonPackage rec { repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are Loading @@ -46,6 +47,7 @@ buildPythonPackage rec { propagatedBuildInputs = [ absl-py etils ml-dtypes numpy opt-einsum scipy Loading Loading @@ -96,24 +98,12 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; # See https://github.com/google/jax/issues/11722. This is a temporary fix in # order to unblock etils, and upgrading jax/jaxlib to the latest version. See # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. disabledTestPaths = [ "tests/api_test.py" "tests/core_test.py" "tests/lax_numpy_indexing_test.py" "tests/lax_numpy_test.py" "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; pythonImportsCheck = [ "jax" ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading