Loading pkgs/development/python-modules/jax-cuda12-pjrt/default.nix +10 −2 Original line number Diff line number Diff line Loading @@ -18,10 +18,16 @@ let cudaLibPath = lib.makeLibraryPath ( with cudaPackages; [ (lib.getLib libcublas) # libcublas.so (lib.getLib cuda_cupti) # libcupti.so (lib.getLib cuda_cudart) # libcudart.so (lib.getLib cudnn) # libcudnn.so (lib.getLib libcublas) # libcublas.so addDriverRunpath.driverLink # libcuda.so (lib.getLib libcufft) # libcufft.so (lib.getLib libcusolver) # libcusolver.so (lib.getLib libcusparse) # libcusparse.so (lib.getLib nccl) # libnccl.so (lib.getLib libnvjitlink) # libnvJitLink.so (lib.getLib addDriverRunpath.driverLink) # libcuda.so ] ); Loading Loading @@ -83,6 +89,8 @@ buildPythonPackage { pythonImportsCheck = [ "jax_plugins" ]; inherit cudaLibPath; meta = { description = "JAX XLA PJRT Plugin for NVIDIA GPUs"; homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; Loading pkgs/development/python-modules/jax-cuda12-plugin/default.nix +26 −3 Original line number Diff line number Diff line Loading @@ -12,8 +12,9 @@ jax-cuda12-pjrt, }: let inherit (cudaPackages) cudaVersion; inherit (jaxlib) version; inherit (cudaPackages) cudaVersion; inherit (jax-cuda12-pjrt) cudaLibPath; getSrcFromPypi = { Loading Loading @@ -94,12 +95,34 @@ buildPythonPackage { wheelUnpackHook ]; # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel. # Linking into $out is the least bad solution. See # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211 # * https://github.com/NixOS/nixpkgs/pull/375186 # for more info. postInstall = '' mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin ''; # jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or # autoPatchelfHook. That means we need to sneak them into rpath. This step # must be done after autoPatchelfHook and the automatic stripping of # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the # patchPhase. preInstallCheck = '' patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so ''; dependencies = [ jax-cuda12-pjrt ]; pythonImportsCheck = [ "jax_cuda12_plugin" ]; # no tests doCheck = false; # FIXME: there are no tests, but we need to run preInstallCheck above doCheck = true; meta = { description = "JAX Plugin for CUDA12"; Loading pkgs/development/python-modules/jax/default.nix +0 −4 Original line number Diff line number Diff line Loading @@ -198,10 +198,6 @@ buildPythonPackage rec { meta = { description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; longDescription = '' This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. ''; homepage = "https://github.com/google/jax"; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ samuela ]; Loading pkgs/development/python-modules/jax/test-cuda.nix +8 −3 Original line number Diff line number Diff line Loading @@ -11,13 +11,18 @@ pkgs.writers.writePython3Bin "jax-test-cuda" } '' import jax import jax.numpy as jnp from jax import random from jax.experimental import sparse assert jax.devices()[0].platform == "gpu" assert jax.devices()[0].platform == "gpu" # libcuda.so rng = random.PRNGKey(0) rng = random.key(0) # libcudart.so, libcudnn.so x = random.normal(rng, (100, 100)) x @ x x @ x # libcublas.so jnp.fft.fft(x) # libcufft.so jnp.linalg.inv(x) # libcusolver.so sparse.CSR.fromdense(x) @ x # libcusparse.so print("success!") '' Loading
pkgs/development/python-modules/jax-cuda12-pjrt/default.nix +10 −2 Original line number Diff line number Diff line Loading @@ -18,10 +18,16 @@ let cudaLibPath = lib.makeLibraryPath ( with cudaPackages; [ (lib.getLib libcublas) # libcublas.so (lib.getLib cuda_cupti) # libcupti.so (lib.getLib cuda_cudart) # libcudart.so (lib.getLib cudnn) # libcudnn.so (lib.getLib libcublas) # libcublas.so addDriverRunpath.driverLink # libcuda.so (lib.getLib libcufft) # libcufft.so (lib.getLib libcusolver) # libcusolver.so (lib.getLib libcusparse) # libcusparse.so (lib.getLib nccl) # libnccl.so (lib.getLib libnvjitlink) # libnvJitLink.so (lib.getLib addDriverRunpath.driverLink) # libcuda.so ] ); Loading Loading @@ -83,6 +89,8 @@ buildPythonPackage { pythonImportsCheck = [ "jax_plugins" ]; inherit cudaLibPath; meta = { description = "JAX XLA PJRT Plugin for NVIDIA GPUs"; homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda"; Loading
pkgs/development/python-modules/jax-cuda12-plugin/default.nix +26 −3 Original line number Diff line number Diff line Loading @@ -12,8 +12,9 @@ jax-cuda12-pjrt, }: let inherit (cudaPackages) cudaVersion; inherit (jaxlib) version; inherit (cudaPackages) cudaVersion; inherit (jax-cuda12-pjrt) cudaLibPath; getSrcFromPypi = { Loading Loading @@ -94,12 +95,34 @@ buildPythonPackage { wheelUnpackHook ]; # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel. # Linking into $out is the least bad solution. See # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211 # * https://github.com/NixOS/nixpkgs/pull/375186 # for more info. postInstall = '' mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin ''; # jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or # autoPatchelfHook. That means we need to sneak them into rpath. This step # must be done after autoPatchelfHook and the automatic stripping of # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the # patchPhase. preInstallCheck = '' patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so ''; dependencies = [ jax-cuda12-pjrt ]; pythonImportsCheck = [ "jax_cuda12_plugin" ]; # no tests doCheck = false; # FIXME: there are no tests, but we need to run preInstallCheck above doCheck = true; meta = { description = "JAX Plugin for CUDA12"; Loading
pkgs/development/python-modules/jax/default.nix +0 −4 Original line number Diff line number Diff line Loading @@ -198,10 +198,6 @@ buildPythonPackage rec { meta = { description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; longDescription = '' This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. ''; homepage = "https://github.com/google/jax"; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ samuela ]; Loading
pkgs/development/python-modules/jax/test-cuda.nix +8 −3 Original line number Diff line number Diff line Loading @@ -11,13 +11,18 @@ pkgs.writers.writePython3Bin "jax-test-cuda" } '' import jax import jax.numpy as jnp from jax import random from jax.experimental import sparse assert jax.devices()[0].platform == "gpu" assert jax.devices()[0].platform == "gpu" # libcuda.so rng = random.PRNGKey(0) rng = random.key(0) # libcudart.so, libcudnn.so x = random.normal(rng, (100, 100)) x @ x x @ x # libcublas.so jnp.fft.fft(x) # libcufft.so jnp.linalg.inv(x) # libcusolver.so sparse.CSR.fromdense(x) @ x # libcusparse.so print("success!") ''