Loading nixos/doc/manual/release-notes/rl-2505.section.md +5 −0 Original line number Diff line number Diff line Loading @@ -258,6 +258,11 @@ - `siduck76-st` has been renamed to `st-snazzy`, like the project's [flake](https://github.com/siduck/st/blob/main/flake.nix). - `python3Packages.jax` now directly depends on `python3Packages.jaxlib`. As a result, packages that depend on jax no longer need to include jaxlib to their dependencies. There is also a breaking change in the handling of CUDA. Instead of using a CUDA compatible jaxlib as before, you can use plugins like `python3Packages.jax-cuda12-plugin`. <!-- To avoid merge conflicts, consider adding your item at an arbitrary place in the list instead. --> Loading pkgs/development/python-modules/jax-cuda12-pjrt/default.nix 0 → 100644 +111 −0 Original line number Diff line number Diff line { lib, stdenv, buildPythonPackage, fetchurl, autoAddDriverRunpath, autoPatchelfHook, pypaInstallHook, wheelUnpackHook, cudaPackages, python, jaxlib, }: let inherit (jaxlib) version; inherit (cudaPackages) cudaVersion; cudaLibPath = lib.makeLibraryPath ( with cudaPackages; [ (lib.getLib cuda_cudart) # libcudart.so (lib.getLib cuda_cupti) # libcupti.so (lib.getLib cudnn) # libcudnn.so (lib.getLib libcufft) # libcufft.so (lib.getLib libcusolver) # libcusolver.so (lib.getLib libcusparse) # libcusparse.so ] ); # Find new releases at https://storage.googleapis.com/jax-releases # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. # upstream does not distribute jax-cuda12-pjrt 0.4.38 binaries for aarch64-linux srcs = { "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl"; hash = "sha256-g75MWfvPMAd6YAhdmOfVncc4sckeDWKOSsF3n94VrCs="; }; "aarch64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; }; in buildPythonPackage { pname = "jax-cuda12-pjrt"; inherit version; pyproject = false; src = srcs.${stdenv.hostPlatform.system} or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}"); nativeBuildInputs = [ autoAddDriverRunpath autoPatchelfHook pypaInstallHook wheelUnpackHook ]; # The following attributes (buildInputs, postInstall and preInstallCheck) are copied from jaxlib-0.4.28 # but it does not recognize GPUs as of 2024-12-29 # Dynamic link dependencies buildInputs = [ (lib.getLib stdenv.cc.cc) ]; # jaxlib looks for ptxas at runtime, eg when running `jax.random.PRNGKey(0)`. # Linking into $out is the least bad solution. See # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211 # for more info. postInstall = '' mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas ''; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or # autoPatchelfHook. That means we need to sneak them into rpath. This step # must be done after autoPatchelfHook and the automatic stripping of # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the # patchPhase. preInstallCheck = '' shopt -s globstar for file in $out/**/*.so; do echo $file patchelf --add-rpath "${cudaLibPath}" "$file" done ''; # no tests doCheck = false; pythonImportsCheck = [ "jax_plugins.xla_cuda12" ]; meta = { description = "JAX XLA PJRT Plugin for NVIDIA GPUs"; homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; platforms = lib.attrNames srcs; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = !(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1") || true; }; } pkgs/development/python-modules/jax-cuda12-plugin/default.nix 0 → 100644 +117 −0 Original line number Diff line number Diff line { lib, stdenv, buildPythonPackage, fetchPypi, autoAddDriverRunpath, autoPatchelfHook, pypaInstallHook, wheelUnpackHook, cudaPackages, python, jaxlib, jax-cuda12-pjrt, }: let inherit (cudaPackages) cudaVersion; inherit (jaxlib) version; getSrcFromPypi = { platform, dist, hash, }: fetchPypi { inherit version platform dist hash ; pname = "jax_cuda12_plugin"; format = "wheel"; python = dist; abi = dist; }; # upstream does not distribute jax-cuda12-plugin 0.4.38 binaries for aarch64-linux srcs = { "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; hash = "sha256-nULpmc1k3VZ8FJ7Wj3k5K6iGRDZCGLtjbNzvoBl8kv4="; }; "3.10-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp310"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; hash = "sha256-cEZUOG8OYAoCgdquqViCqmekfttoOTthsbFzx+jKdKg="; }; "3.11-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp311"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; hash = "sha256-Ufas/3Ew63LrsCU039NYGg9eoGlx3lLX68Ia1Nh/5x4="; }; "3.12-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp312"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.13-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp313"; hash = "sha256-CSKKTCtEO3aozZqOwikGAInEzINuBiSWh1ptb9xm0x8="; }; "3.13-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp313"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; }; in buildPythonPackage { pname = "jax-cuda12-plugin"; inherit version; pyproject = false; src = ( srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}" or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}") ); nativeBuildInputs = [ autoAddDriverRunpath autoPatchelfHook pypaInstallHook wheelUnpackHook ]; dependencies = [ jax-cuda12-pjrt ]; pythonImportsCheck = [ "jax_cuda12_plugin" ]; meta = { description = "JAX Plugin for CUDA12"; homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; platforms = lib.platforms.linux; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = !(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1") || true; }; } pkgs/development/python-modules/jax/default.nix +77 −43 Original line number Diff line number Diff line { lib, config, stdenv, blas, lapack, buildPythonPackage, callPackage, setuptools, importlib-metadata, fetchFromGitHub, cudaSupport ? config.cudaSupport, # build-system setuptools, # dependencies jaxlib, jaxlib-bin, jaxlib-build, hypothesis, lapack, matplotlib, ml-dtypes, numpy, opt-einsum, scipy, # optional-dependencies jax-cuda12-plugin, # tests cloudpickle, hypothesis, matplotlib, pytestCheckHook, pytest-xdist, pythonOlder, scipy, stdenv, # passthru callPackage, jax, jaxlib-build, jaxlib-bin, }: let Loading @@ -27,38 +40,41 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.28"; version = "0.4.38"; pyproject = true; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/jax-v${version}"; hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; tag = "jax-v${version}"; hash = "sha256-H8I9Mkz6Ia1RxJmnuJOSevLGHN2J8ey59ZTlFx8YfnA="; }; nativeBuildInputs = [ setuptools ]; build-system = [ setuptools ]; # The version is automatically set to ".dev" if this variable is not set. # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3 JAX_RELEASE = "1"; # jaxlib is _not_ included in propagatedBuildInputs because there are # different versions of jaxlib depending on the desired target hardware. The # JAX project ships separate wheels for CPU, GPU, and TPU. propagatedBuildInputs = [ dependencies = [ jaxlib ml-dtypes numpy opt-einsum scipy ] ++ lib.optional (pythonOlder "3.10") importlib-metadata; ] ++ lib.optionals cudaSupport optional-dependencies.cuda; optional-dependencies = rec { cuda = [ jax-cuda12-plugin ]; cuda12 = cuda; cuda12_pip = cuda; cuda12_local = cuda; }; nativeCheckInputs = [ cloudpickle hypothesis jaxlib matplotlib pytestCheckHook pytest-xdist Loading @@ -71,10 +87,16 @@ buildPythonPackage rec { # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. # Not a big deal, this is how the JAX docs suggest running the test suite # anyhow. pytestFlagsArray = [ pytestFlagsArray = [ "--numprocesses=4" "-W ignore::DeprecationWarning" "tests/" ] ++ lib.optionals stdenv.hostPlatform.isDarwin [ # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated! "--deselect tests/shape_poly_test.py::ShapePolyTest" "--deselect tests/tree_util_test.py::TreeTest" ]; # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with Loading Loading @@ -125,9 +147,20 @@ buildPythonPackage rec { # Fails on some hardware due to some numerical error # See https://github.com/google/jax/issues/18535 "testQdwhWithOnRankDeficientInput5" ] ++ lib.optionals stdenv.hostPlatform.isDarwin [ # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated! "testInAxesPyTreePrefixMismatchError" "testInAxesPyTreePrefixMismatchErrorKwargs" "testOutAxesPyTreePrefixMismatchError" "test_tree_map" "test_vjp_rule_inconsistent_pytree_structures_error" "test_vmap_in_axes_tree_prefix_error" "test_vmap_mismatched_axis_sizes_error_message_issue_705" ]; disabledTestPaths = [ disabledTestPaths = [ # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba "tests/linalg_test.py" ] Loading @@ -147,25 +180,26 @@ buildPythonPackage rec { # # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin passthru.tests = { test_cuda_jaxlibSource = callPackage ./test-cuda.nix { jaxlib = jaxlib-build.override { cudaSupport = true; }; }; # jaxlib-build is broken as of 2024-12-20 # test_cuda_jaxlibSource = callPackage ./test-cuda.nix { # jax = jax.override { jaxlib = jaxlib-build; }; # }; test_cuda_jaxlibBin = callPackage ./test-cuda.nix { jaxlib = jaxlib-bin.override { cudaSupport = true; }; jax = jax.override { jaxlib = jaxlib-bin; }; }; }; # updater fails to pick the correct branch passthru.skipBulkUpdate = true; meta = with lib; { meta = { description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; longDescription = '' This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. ''; homepage = "https://github.com/google/jax"; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ samuela ]; }; } pkgs/development/python-modules/jax/test-cuda.nix +1 −3 Original line number Diff line number Diff line { jax, jaxlib, pkgs, }: Loading @@ -8,8 +7,7 @@ pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; ] ++ jax.optional-dependencies.cuda; } '' import jax Loading Loading
nixos/doc/manual/release-notes/rl-2505.section.md +5 −0 Original line number Diff line number Diff line Loading @@ -258,6 +258,11 @@ - `siduck76-st` has been renamed to `st-snazzy`, like the project's [flake](https://github.com/siduck/st/blob/main/flake.nix). - `python3Packages.jax` now directly depends on `python3Packages.jaxlib`. As a result, packages that depend on jax no longer need to include jaxlib to their dependencies. There is also a breaking change in the handling of CUDA. Instead of using a CUDA compatible jaxlib as before, you can use plugins like `python3Packages.jax-cuda12-plugin`. <!-- To avoid merge conflicts, consider adding your item at an arbitrary place in the list instead. --> Loading
pkgs/development/python-modules/jax-cuda12-pjrt/default.nix 0 → 100644 +111 −0 Original line number Diff line number Diff line { lib, stdenv, buildPythonPackage, fetchurl, autoAddDriverRunpath, autoPatchelfHook, pypaInstallHook, wheelUnpackHook, cudaPackages, python, jaxlib, }: let inherit (jaxlib) version; inherit (cudaPackages) cudaVersion; cudaLibPath = lib.makeLibraryPath ( with cudaPackages; [ (lib.getLib cuda_cudart) # libcudart.so (lib.getLib cuda_cupti) # libcupti.so (lib.getLib cudnn) # libcudnn.so (lib.getLib libcufft) # libcufft.so (lib.getLib libcusolver) # libcusolver.so (lib.getLib libcusparse) # libcusparse.so ] ); # Find new releases at https://storage.googleapis.com/jax-releases # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. # upstream does not distribute jax-cuda12-pjrt 0.4.38 binaries for aarch64-linux srcs = { "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl"; hash = "sha256-g75MWfvPMAd6YAhdmOfVncc4sckeDWKOSsF3n94VrCs="; }; "aarch64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; }; in buildPythonPackage { pname = "jax-cuda12-pjrt"; inherit version; pyproject = false; src = srcs.${stdenv.hostPlatform.system} or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}"); nativeBuildInputs = [ autoAddDriverRunpath autoPatchelfHook pypaInstallHook wheelUnpackHook ]; # The following attributes (buildInputs, postInstall and preInstallCheck) are copied from jaxlib-0.4.28 # but it does not recognize GPUs as of 2024-12-29 # Dynamic link dependencies buildInputs = [ (lib.getLib stdenv.cc.cc) ]; # jaxlib looks for ptxas at runtime, eg when running `jax.random.PRNGKey(0)`. # Linking into $out is the least bad solution. See # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211 # for more info. postInstall = '' mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas ''; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or # autoPatchelfHook. That means we need to sneak them into rpath. This step # must be done after autoPatchelfHook and the automatic stripping of # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the # patchPhase. preInstallCheck = '' shopt -s globstar for file in $out/**/*.so; do echo $file patchelf --add-rpath "${cudaLibPath}" "$file" done ''; # no tests doCheck = false; pythonImportsCheck = [ "jax_plugins.xla_cuda12" ]; meta = { description = "JAX XLA PJRT Plugin for NVIDIA GPUs"; homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; platforms = lib.attrNames srcs; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = !(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1") || true; }; }
pkgs/development/python-modules/jax-cuda12-plugin/default.nix 0 → 100644 +117 −0 Original line number Diff line number Diff line { lib, stdenv, buildPythonPackage, fetchPypi, autoAddDriverRunpath, autoPatchelfHook, pypaInstallHook, wheelUnpackHook, cudaPackages, python, jaxlib, jax-cuda12-pjrt, }: let inherit (cudaPackages) cudaVersion; inherit (jaxlib) version; getSrcFromPypi = { platform, dist, hash, }: fetchPypi { inherit version platform dist hash ; pname = "jax_cuda12_plugin"; format = "wheel"; python = dist; abi = dist; }; # upstream does not distribute jax-cuda12-plugin 0.4.38 binaries for aarch64-linux srcs = { "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; hash = "sha256-nULpmc1k3VZ8FJ7Wj3k5K6iGRDZCGLtjbNzvoBl8kv4="; }; "3.10-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp310"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; hash = "sha256-cEZUOG8OYAoCgdquqViCqmekfttoOTthsbFzx+jKdKg="; }; "3.11-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp311"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; hash = "sha256-Ufas/3Ew63LrsCU039NYGg9eoGlx3lLX68Ia1Nh/5x4="; }; "3.12-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp312"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.13-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp313"; hash = "sha256-CSKKTCtEO3aozZqOwikGAInEzINuBiSWh1ptb9xm0x8="; }; "3.13-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp313"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; }; in buildPythonPackage { pname = "jax-cuda12-plugin"; inherit version; pyproject = false; src = ( srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}" or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}") ); nativeBuildInputs = [ autoAddDriverRunpath autoPatchelfHook pypaInstallHook wheelUnpackHook ]; dependencies = [ jax-cuda12-pjrt ]; pythonImportsCheck = [ "jax_cuda12_plugin" ]; meta = { description = "JAX Plugin for CUDA12"; homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; platforms = lib.platforms.linux; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = !(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1") || true; }; }
pkgs/development/python-modules/jax/default.nix +77 −43 Original line number Diff line number Diff line { lib, config, stdenv, blas, lapack, buildPythonPackage, callPackage, setuptools, importlib-metadata, fetchFromGitHub, cudaSupport ? config.cudaSupport, # build-system setuptools, # dependencies jaxlib, jaxlib-bin, jaxlib-build, hypothesis, lapack, matplotlib, ml-dtypes, numpy, opt-einsum, scipy, # optional-dependencies jax-cuda12-plugin, # tests cloudpickle, hypothesis, matplotlib, pytestCheckHook, pytest-xdist, pythonOlder, scipy, stdenv, # passthru callPackage, jax, jaxlib-build, jaxlib-bin, }: let Loading @@ -27,38 +40,41 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.28"; version = "0.4.38"; pyproject = true; disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/jax-v${version}"; hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; tag = "jax-v${version}"; hash = "sha256-H8I9Mkz6Ia1RxJmnuJOSevLGHN2J8ey59ZTlFx8YfnA="; }; nativeBuildInputs = [ setuptools ]; build-system = [ setuptools ]; # The version is automatically set to ".dev" if this variable is not set. # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3 JAX_RELEASE = "1"; # jaxlib is _not_ included in propagatedBuildInputs because there are # different versions of jaxlib depending on the desired target hardware. The # JAX project ships separate wheels for CPU, GPU, and TPU. propagatedBuildInputs = [ dependencies = [ jaxlib ml-dtypes numpy opt-einsum scipy ] ++ lib.optional (pythonOlder "3.10") importlib-metadata; ] ++ lib.optionals cudaSupport optional-dependencies.cuda; optional-dependencies = rec { cuda = [ jax-cuda12-plugin ]; cuda12 = cuda; cuda12_pip = cuda; cuda12_local = cuda; }; nativeCheckInputs = [ cloudpickle hypothesis jaxlib matplotlib pytestCheckHook pytest-xdist Loading @@ -71,10 +87,16 @@ buildPythonPackage rec { # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. # Not a big deal, this is how the JAX docs suggest running the test suite # anyhow. pytestFlagsArray = [ pytestFlagsArray = [ "--numprocesses=4" "-W ignore::DeprecationWarning" "tests/" ] ++ lib.optionals stdenv.hostPlatform.isDarwin [ # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated! "--deselect tests/shape_poly_test.py::ShapePolyTest" "--deselect tests/tree_util_test.py::TreeTest" ]; # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with Loading Loading @@ -125,9 +147,20 @@ buildPythonPackage rec { # Fails on some hardware due to some numerical error # See https://github.com/google/jax/issues/18535 "testQdwhWithOnRankDeficientInput5" ] ++ lib.optionals stdenv.hostPlatform.isDarwin [ # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated! "testInAxesPyTreePrefixMismatchError" "testInAxesPyTreePrefixMismatchErrorKwargs" "testOutAxesPyTreePrefixMismatchError" "test_tree_map" "test_vjp_rule_inconsistent_pytree_structures_error" "test_vmap_in_axes_tree_prefix_error" "test_vmap_mismatched_axis_sizes_error_message_issue_705" ]; disabledTestPaths = [ disabledTestPaths = [ # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba "tests/linalg_test.py" ] Loading @@ -147,25 +180,26 @@ buildPythonPackage rec { # # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin passthru.tests = { test_cuda_jaxlibSource = callPackage ./test-cuda.nix { jaxlib = jaxlib-build.override { cudaSupport = true; }; }; # jaxlib-build is broken as of 2024-12-20 # test_cuda_jaxlibSource = callPackage ./test-cuda.nix { # jax = jax.override { jaxlib = jaxlib-build; }; # }; test_cuda_jaxlibBin = callPackage ./test-cuda.nix { jaxlib = jaxlib-bin.override { cudaSupport = true; }; jax = jax.override { jaxlib = jaxlib-bin; }; }; }; # updater fails to pick the correct branch passthru.skipBulkUpdate = true; meta = with lib; { meta = { description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; longDescription = '' This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. ''; homepage = "https://github.com/google/jax"; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ samuela ]; }; }
pkgs/development/python-modules/jax/test-cuda.nix +1 −3 Original line number Diff line number Diff line { jax, jaxlib, pkgs, }: Loading @@ -8,8 +7,7 @@ pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; ] ++ jax.optional-dependencies.cuda; } '' import jax Loading