Loading pkgs/development/python-modules/jax-cuda12-plugin/default.nix 0 → 100644 +117 −0 Original line number Diff line number Diff line { lib, stdenv, buildPythonPackage, fetchPypi, autoAddDriverRunpath, autoPatchelfHook, pypaInstallHook, wheelUnpackHook, cudaPackages, python, jaxlib, jax-cuda12-pjrt, }: let inherit (cudaPackages) cudaVersion; inherit (jaxlib) version; getSrcFromPypi = { platform, dist, hash, }: fetchPypi { inherit version platform dist hash ; pname = "jax_cuda12_plugin"; format = "wheel"; python = dist; abi = dist; }; # upstream does not distribute jax-cuda12-plugin 0.4.38 binaries for aarch64-linux srcs = { "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; hash = "sha256-nULpmc1k3VZ8FJ7Wj3k5K6iGRDZCGLtjbNzvoBl8kv4="; }; "3.10-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp310"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; hash = "sha256-cEZUOG8OYAoCgdquqViCqmekfttoOTthsbFzx+jKdKg="; }; "3.11-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp311"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; hash = "sha256-Ufas/3Ew63LrsCU039NYGg9eoGlx3lLX68Ia1Nh/5x4="; }; "3.12-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp312"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.13-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp313"; hash = "sha256-CSKKTCtEO3aozZqOwikGAInEzINuBiSWh1ptb9xm0x8="; }; "3.13-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp313"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; }; in buildPythonPackage { pname = "jax-cuda12-plugin"; inherit version; pyproject = false; src = ( srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}" or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}") ); nativeBuildInputs = [ autoAddDriverRunpath autoPatchelfHook pypaInstallHook wheelUnpackHook ]; dependencies = [ jax-cuda12-pjrt ]; pythonImportsCheck = [ "jax_cuda12_plugin" ]; meta = { description = "JAX Plugin for CUDA12"; homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; platforms = lib.platforms.linux; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = !(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1") || true; }; } pkgs/top-level/python-packages.nix +2 −0 Original line number Diff line number Diff line Loading @@ -6562,6 +6562,8 @@ self: super: with self; { jax-cuda12-pjrt = callPackage ../development/python-modules/jax-cuda12-pjrt { }; jax-cuda12-plugin = callPackage ../development/python-modules/jax-cuda12-plugin { }; jax-jumpy = callPackage ../development/python-modules/jax-jumpy { }; jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { }; Loading Loading
pkgs/development/python-modules/jax-cuda12-plugin/default.nix 0 → 100644 +117 −0 Original line number Diff line number Diff line { lib, stdenv, buildPythonPackage, fetchPypi, autoAddDriverRunpath, autoPatchelfHook, pypaInstallHook, wheelUnpackHook, cudaPackages, python, jaxlib, jax-cuda12-pjrt, }: let inherit (cudaPackages) cudaVersion; inherit (jaxlib) version; getSrcFromPypi = { platform, dist, hash, }: fetchPypi { inherit version platform dist hash ; pname = "jax_cuda12_plugin"; format = "wheel"; python = dist; abi = dist; }; # upstream does not distribute jax-cuda12-plugin 0.4.38 binaries for aarch64-linux srcs = { "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; hash = "sha256-nULpmc1k3VZ8FJ7Wj3k5K6iGRDZCGLtjbNzvoBl8kv4="; }; "3.10-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp310"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; hash = "sha256-cEZUOG8OYAoCgdquqViCqmekfttoOTthsbFzx+jKdKg="; }; "3.11-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp311"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; hash = "sha256-Ufas/3Ew63LrsCU039NYGg9eoGlx3lLX68Ia1Nh/5x4="; }; "3.12-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp312"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; "3.13-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp313"; hash = "sha256-CSKKTCtEO3aozZqOwikGAInEzINuBiSWh1ptb9xm0x8="; }; "3.13-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp313"; hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; }; }; in buildPythonPackage { pname = "jax-cuda12-plugin"; inherit version; pyproject = false; src = ( srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}" or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}") ); nativeBuildInputs = [ autoAddDriverRunpath autoPatchelfHook pypaInstallHook wheelUnpackHook ]; dependencies = [ jax-cuda12-pjrt ]; pythonImportsCheck = [ "jax_cuda12_plugin" ]; meta = { description = "JAX Plugin for CUDA12"; homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; platforms = lib.platforms.linux; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = !(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1") || true; }; }
pkgs/top-level/python-packages.nix +2 −0 Original line number Diff line number Diff line Loading @@ -6562,6 +6562,8 @@ self: super: with self; { jax-cuda12-pjrt = callPackage ../development/python-modules/jax-cuda12-pjrt { }; jax-cuda12-plugin = callPackage ../development/python-modules/jax-cuda12-plugin { }; jax-jumpy = callPackage ../development/python-modules/jax-jumpy { }; jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { }; Loading