Loading pkgs/development/python-modules/jax/default.nix +77 −43 Original line number Diff line number Diff line { lib, config, stdenv, blas, lapack, buildPythonPackage, callPackage, setuptools, importlib-metadata, fetchFromGitHub, cudaSupport ? config.cudaSupport, # build-system setuptools, # dependencies jaxlib, jaxlib-bin, jaxlib-build, hypothesis, lapack, matplotlib, ml-dtypes, numpy, opt-einsum, scipy, # optional-dependencies jax-cuda12-plugin, # tests cloudpickle, hypothesis, matplotlib, pytestCheckHook, pytest-xdist, pythonOlder, scipy, stdenv, # passthru callPackage, jax, jaxlib-build, jaxlib-bin, }: let Loading @@ -27,38 +40,41 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.28"; version = "0.4.38"; pyproject = true; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/jax-v${version}"; hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; tag = "jax-v${version}"; hash = "sha256-H8I9Mkz6Ia1RxJmnuJOSevLGHN2J8ey59ZTlFx8YfnA="; }; nativeBuildInputs = [ setuptools ]; build-system = [ setuptools ]; # The version is automatically set to ".dev" if this variable is not set. # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3 JAX_RELEASE = "1"; # jaxlib is _not_ included in propagatedBuildInputs because there are # different versions of jaxlib depending on the desired target hardware. The # JAX project ships separate wheels for CPU, GPU, and TPU. propagatedBuildInputs = [ dependencies = [ jaxlib ml-dtypes numpy opt-einsum scipy ] ++ lib.optional (pythonOlder "3.10") importlib-metadata; ] ++ lib.optionals cudaSupport optional-dependencies.cuda; optional-dependencies = rec { cuda = [ jax-cuda12-plugin ]; cuda12 = cuda; cuda12_pip = cuda; cuda12_local = cuda; }; nativeCheckInputs = [ cloudpickle hypothesis jaxlib matplotlib pytestCheckHook pytest-xdist Loading @@ -71,10 +87,16 @@ buildPythonPackage rec { # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. # Not a big deal, this is how the JAX docs suggest running the test suite # anyhow. pytestFlagsArray = [ pytestFlagsArray = [ "--numprocesses=4" "-W ignore::DeprecationWarning" "tests/" ] ++ lib.optionals stdenv.hostPlatform.isDarwin [ # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated! "--deselect tests/shape_poly_test.py::ShapePolyTest" "--deselect tests/tree_util_test.py::TreeTest" ]; # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with Loading Loading @@ -125,9 +147,20 @@ buildPythonPackage rec { # Fails on some hardware due to some numerical error # See https://github.com/google/jax/issues/18535 "testQdwhWithOnRankDeficientInput5" ] ++ lib.optionals stdenv.hostPlatform.isDarwin [ # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated! "testInAxesPyTreePrefixMismatchError" "testInAxesPyTreePrefixMismatchErrorKwargs" "testOutAxesPyTreePrefixMismatchError" "test_tree_map" "test_vjp_rule_inconsistent_pytree_structures_error" "test_vmap_in_axes_tree_prefix_error" "test_vmap_mismatched_axis_sizes_error_message_issue_705" ]; disabledTestPaths = [ disabledTestPaths = [ # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba "tests/linalg_test.py" ] Loading @@ -147,25 +180,26 @@ buildPythonPackage rec { # # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin passthru.tests = { test_cuda_jaxlibSource = callPackage ./test-cuda.nix { jaxlib = jaxlib-build.override { cudaSupport = true; }; }; # jaxlib-build is broken as of 2024-12-20 # test_cuda_jaxlibSource = callPackage ./test-cuda.nix { # jax = jax.override { jaxlib = jaxlib-build; }; # }; test_cuda_jaxlibBin = callPackage ./test-cuda.nix { jaxlib = jaxlib-bin.override { cudaSupport = true; }; jax = jax.override { jaxlib = jaxlib-bin; }; }; }; # updater fails to pick the correct branch passthru.skipBulkUpdate = true; meta = with lib; { meta = { description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; longDescription = '' This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. ''; homepage = "https://github.com/google/jax"; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ samuela ]; }; } pkgs/development/python-modules/jax/test-cuda.nix +1 −3 Original line number Diff line number Diff line { jax, jaxlib, pkgs, }: Loading @@ -8,8 +7,7 @@ pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; ] ++ jax.optional-dependencies.cuda; } '' import jax Loading Loading
pkgs/development/python-modules/jax/default.nix +77 −43 Original line number Diff line number Diff line { lib, config, stdenv, blas, lapack, buildPythonPackage, callPackage, setuptools, importlib-metadata, fetchFromGitHub, cudaSupport ? config.cudaSupport, # build-system setuptools, # dependencies jaxlib, jaxlib-bin, jaxlib-build, hypothesis, lapack, matplotlib, ml-dtypes, numpy, opt-einsum, scipy, # optional-dependencies jax-cuda12-plugin, # tests cloudpickle, hypothesis, matplotlib, pytestCheckHook, pytest-xdist, pythonOlder, scipy, stdenv, # passthru callPackage, jax, jaxlib-build, jaxlib-bin, }: let Loading @@ -27,38 +40,41 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.28"; version = "0.4.38"; pyproject = true; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/jax-v${version}"; hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; tag = "jax-v${version}"; hash = "sha256-H8I9Mkz6Ia1RxJmnuJOSevLGHN2J8ey59ZTlFx8YfnA="; }; nativeBuildInputs = [ setuptools ]; build-system = [ setuptools ]; # The version is automatically set to ".dev" if this variable is not set. # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3 JAX_RELEASE = "1"; # jaxlib is _not_ included in propagatedBuildInputs because there are # different versions of jaxlib depending on the desired target hardware. The # JAX project ships separate wheels for CPU, GPU, and TPU. propagatedBuildInputs = [ dependencies = [ jaxlib ml-dtypes numpy opt-einsum scipy ] ++ lib.optional (pythonOlder "3.10") importlib-metadata; ] ++ lib.optionals cudaSupport optional-dependencies.cuda; optional-dependencies = rec { cuda = [ jax-cuda12-plugin ]; cuda12 = cuda; cuda12_pip = cuda; cuda12_local = cuda; }; nativeCheckInputs = [ cloudpickle hypothesis jaxlib matplotlib pytestCheckHook pytest-xdist Loading @@ -71,10 +87,16 @@ buildPythonPackage rec { # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. # Not a big deal, this is how the JAX docs suggest running the test suite # anyhow. pytestFlagsArray = [ pytestFlagsArray = [ "--numprocesses=4" "-W ignore::DeprecationWarning" "tests/" ] ++ lib.optionals stdenv.hostPlatform.isDarwin [ # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated! "--deselect tests/shape_poly_test.py::ShapePolyTest" "--deselect tests/tree_util_test.py::TreeTest" ]; # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with Loading Loading @@ -125,9 +147,20 @@ buildPythonPackage rec { # Fails on some hardware due to some numerical error # See https://github.com/google/jax/issues/18535 "testQdwhWithOnRankDeficientInput5" ] ++ lib.optionals stdenv.hostPlatform.isDarwin [ # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated! "testInAxesPyTreePrefixMismatchError" "testInAxesPyTreePrefixMismatchErrorKwargs" "testOutAxesPyTreePrefixMismatchError" "test_tree_map" "test_vjp_rule_inconsistent_pytree_structures_error" "test_vmap_in_axes_tree_prefix_error" "test_vmap_mismatched_axis_sizes_error_message_issue_705" ]; disabledTestPaths = [ disabledTestPaths = [ # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba "tests/linalg_test.py" ] Loading @@ -147,25 +180,26 @@ buildPythonPackage rec { # # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin passthru.tests = { test_cuda_jaxlibSource = callPackage ./test-cuda.nix { jaxlib = jaxlib-build.override { cudaSupport = true; }; }; # jaxlib-build is broken as of 2024-12-20 # test_cuda_jaxlibSource = callPackage ./test-cuda.nix { # jax = jax.override { jaxlib = jaxlib-build; }; # }; test_cuda_jaxlibBin = callPackage ./test-cuda.nix { jaxlib = jaxlib-bin.override { cudaSupport = true; }; jax = jax.override { jaxlib = jaxlib-bin; }; }; }; # updater fails to pick the correct branch passthru.skipBulkUpdate = true; meta = with lib; { meta = { description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; longDescription = '' This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. ''; homepage = "https://github.com/google/jax"; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ samuela ]; }; }
pkgs/development/python-modules/jax/test-cuda.nix +1 −3 Original line number Diff line number Diff line { jax, jaxlib, pkgs, }: Loading @@ -8,8 +7,7 @@ pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; ] ++ jax.optional-dependencies.cuda; } '' import jax Loading