Loading pkgs/development/python-modules/jax/default.nix +15 −25 Original line number Diff line number Diff line { lib , absl-py , blas , buildPythonPackage , etils , setuptools , importlib-metadata , fetchFromGitHub , jaxlib , jaxlib-bin , lapack , matplotlib , ml-dtypes , numpy , opt-einsum , pytestCheckHook Loading @@ -15,7 +16,6 @@ , pythonOlder , scipy , stdenv , typing-extensions }: let Loading @@ -27,30 +27,32 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.5"; format = "setuptools"; version = "0.4.14"; format = "pyproject"; disabled = pythonOlder "3.7"; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg="; }; nativeBuildInputs = [ setuptools ]; # jaxlib is _not_ included in propagatedBuildInputs because there are # different versions of jaxlib depending on the desired target hardware. The # JAX project ships separate wheels for CPU, GPU, and TPU. propagatedBuildInputs = [ absl-py etils ml-dtypes numpy opt-einsum scipy typing-extensions ] ++ etils.optional-dependencies.epath; ] ++ lib.optional (pythonOlder "3.10") importlib-metadata; nativeCheckInputs = [ jaxlib' Loading Loading @@ -96,24 +98,12 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; # See https://github.com/google/jax/issues/11722. This is a temporary fix in # order to unblock etils, and upgrading jax/jaxlib to the latest version. See # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. disabledTestPaths = [ "tests/api_test.py" "tests/core_test.py" "tests/lax_numpy_indexing_test.py" "tests/lax_numpy_test.py" "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; pythonImportsCheck = [ "jax" ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading Loading
pkgs/development/python-modules/jax/default.nix +15 −25 Original line number Diff line number Diff line { lib , absl-py , blas , buildPythonPackage , etils , setuptools , importlib-metadata , fetchFromGitHub , jaxlib , jaxlib-bin , lapack , matplotlib , ml-dtypes , numpy , opt-einsum , pytestCheckHook Loading @@ -15,7 +16,6 @@ , pythonOlder , scipy , stdenv , typing-extensions }: let Loading @@ -27,30 +27,32 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.5"; format = "setuptools"; version = "0.4.14"; format = "pyproject"; disabled = pythonOlder "3.7"; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg="; }; nativeBuildInputs = [ setuptools ]; # jaxlib is _not_ included in propagatedBuildInputs because there are # different versions of jaxlib depending on the desired target hardware. The # JAX project ships separate wheels for CPU, GPU, and TPU. propagatedBuildInputs = [ absl-py etils ml-dtypes numpy opt-einsum scipy typing-extensions ] ++ etils.optional-dependencies.epath; ] ++ lib.optional (pythonOlder "3.10") importlib-metadata; nativeCheckInputs = [ jaxlib' Loading Loading @@ -96,24 +98,12 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; # See https://github.com/google/jax/issues/11722. This is a temporary fix in # order to unblock etils, and upgrading jax/jaxlib to the latest version. See # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. disabledTestPaths = [ "tests/api_test.py" "tests/core_test.py" "tests/lax_numpy_indexing_test.py" "tests/lax_numpy_test.py" "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; pythonImportsCheck = [ "jax" ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading