Unverified commit 59f337a6, authored by Samuel Ainsworth, committed by GitHub
Browse files

python3Packages.jax: add missing cuda libraries (#375186)

parents 3db65802 f9c16aa4
Loading
Loading
Loading
Loading
+10 −2
Original line number Diff line number Diff line
@@ -18,10 +18,16 @@ let
  # Search path over the /lib dirs of every CUDA shared library the JAX PJRT
  # plugin dlopens at runtime; consumed by the `patchelf --add-rpath` step in
  # jax-cuda12-plugin's preInstallCheck.
  # NOTE(review): this span is a rendered diff with old and new lines
  # interleaved — libcublas and the libcuda driver link each appear twice.
  # Confirm the deduplicated list against the committed file.
  cudaLibPath = lib.makeLibraryPath (
    with cudaPackages;
    [
      (lib.getLib libcublas) # libcublas.so
      (lib.getLib cuda_cupti) # libcupti.so
      (lib.getLib cuda_cudart) # libcudart.so
      (lib.getLib cudnn) # libcudnn.so
      (lib.getLib libcublas) # libcublas.so
      addDriverRunpath.driverLink # libcuda.so
      (lib.getLib libcufft) # libcufft.so
      (lib.getLib libcusolver) # libcusolver.so
      (lib.getLib libcusparse) # libcusparse.so
      (lib.getLib nccl) # libnccl.so
      (lib.getLib libnvjitlink) # libnvJitLink.so
      (lib.getLib addDriverRunpath.driverLink) # libcuda.so
    ]
  );

@@ -83,6 +89,8 @@ buildPythonPackage {

  pythonImportsCheck = [ "jax_plugins" ];

  inherit cudaLibPath;

  meta = {
    description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
+26 −3
Original line number Diff line number Diff line
@@ -12,8 +12,9 @@
  jax-cuda12-pjrt,
}:
let
  inherit (cudaPackages) cudaVersion;
  inherit (jaxlib) version;
  inherit (cudaPackages) cudaVersion;
  inherit (jax-cuda12-pjrt) cudaLibPath;

  getSrcFromPypi =
    {
@@ -94,12 +95,34 @@ buildPythonPackage {
    wheelUnpackHook
  ];

  # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # * https://github.com/NixOS/nixpkgs/pull/375186
  # for more info.
  # NOTE(review): nvlink from cuda_nvcc is symlinked alongside ptxas —
  # presumably it is also located at runtime via the same cuda/bin path;
  # confirm against the plugin's lookup logic.
  postInstall = ''
    mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
  '';

  # jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen
  # and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
  # patchPhase.
  # NOTE(review): automatic stripping is normally part of fixupPhase, not
  # patchPhase — the phase named in the comment above may be inaccurate; verify.
  # cudaLibPath is inherited from jax-cuda12-pjrt (see the `inherit` in the
  # surrounding let); this hook only runs when doCheck is enabled.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
  '';

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  # no tests
  doCheck = false;
  # FIXME: there are no tests, but we need to run preInstallCheck above
  doCheck = true;

  meta = {
    description = "JAX Plugin for CUDA12";
+0 −4
Original line number Diff line number Diff line
@@ -198,10 +198,6 @@ buildPythonPackage rec {

  meta = {
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    longDescription = ''
      This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
      e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
    '';
    homepage = "https://github.com/google/jax";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
+8 −3
Original line number Diff line number Diff line
@@ -11,13 +11,18 @@ pkgs.writers.writePython3Bin "jax-test-cuda"
  }
  ''
    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental import sparse

    assert jax.devices()[0].platform == "gpu"
    assert jax.devices()[0].platform == "gpu"  # libcuda.so

    rng = random.PRNGKey(0)
    rng = random.key(0)  # libcudart.so, libcudnn.so
    x = random.normal(rng, (100, 100))
    x @ x
    x @ x  # libcublas.so
    jnp.fft.fft(x)  # libcufft.so
    jnp.linalg.inv(x)  # libcusolver.so
    sparse.CSR.fromdense(x) @ x  # libcusparse.so

    print("success!")
  ''