Unverified Commit 533e2d71 authored by natsukium's avatar natsukium
Browse files

python312Packages.jax-cuda12-plugin: init at 0.4.38

parent ad4d314c
Loading
Loading
Loading
Loading
+117 −0
Original line number Diff line number Diff line
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  autoAddDriverRunpath,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
  jax-cuda12-pjrt,
}:
let
  inherit (cudaPackages) cudaVersion;
  inherit (jaxlib) version;

  getSrcFromPypi =
    {
      platform,
      dist,
      hash,
    }:
    fetchPypi {
      inherit
        version
        platform
        dist
        hash
        ;
      pname = "jax_cuda12_plugin";
      format = "wheel";
      python = dist;
      abi = dist;
    };

  # upstream does not distribute jax-cuda12-plugin 0.4.38 binaries for aarch64-linux
  srcs = {
    "3.10-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp310";
      hash = "sha256-nULpmc1k3VZ8FJ7Wj3k5K6iGRDZCGLtjbNzvoBl8kv4=";
    };
    "3.10-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp310";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp311";
      hash = "sha256-cEZUOG8OYAoCgdquqViCqmekfttoOTthsbFzx+jKdKg=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp311";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp312";
      hash = "sha256-Ufas/3Ew63LrsCU039NYGg9eoGlx3lLX68Ia1Nh/5x4=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp312";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp313";
      hash = "sha256-CSKKTCtEO3aozZqOwikGAInEzINuBiSWh1ptb9xm0x8=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp313";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
  };
in
buildPythonPackage {
  pname = "jax-cuda12-plugin";
  inherit version;
  pyproject = false;

  src = (
    srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}"
      or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}")
  );

  nativeBuildInputs = [
    autoAddDriverRunpath
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  meta = {
    description = "JAX Plugin for CUDA12";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken =
      !(lib.versionAtLeast cudaVersion "12.1")
      || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1")
      || true;
  };
}
+2 −0
Original line number Diff line number Diff line
@@ -6562,6 +6562,8 @@ self: super: with self; {
  jax-cuda12-pjrt = callPackage ../development/python-modules/jax-cuda12-pjrt { };
  jax-cuda12-plugin = callPackage ../development/python-modules/jax-cuda12-plugin { };
  jax-jumpy = callPackage ../development/python-modules/jax-jumpy { };
  jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { };