Unverified Commit 65b07370 authored by Gaétan Lepage's avatar Gaétan Lepage Committed by GitHub
Browse files

python312Packages.{jax,jaxlib}: 0.4.28 -> 0.4.38; use jaxlib-bin by default (#363130)

parents fab17649 b070bad6
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -258,6 +258,11 @@

- `siduck76-st` has been renamed to `st-snazzy`, like the project's [flake](https://github.com/siduck/st/blob/main/flake.nix).

- `python3Packages.jax` now directly depends on `python3Packages.jaxlib`.
  As a result, packages that depend on jax no longer need to include jaxlib to their dependencies.
  There is also a breaking change in the handling of CUDA. Instead of using a CUDA compatible jaxlib
  as before, you can use plugins like `python3Packages.jax-cuda12-plugin`.


<!-- To avoid merge conflicts, consider adding your item at an arbitrary place in the list instead. -->

+111 −0
Original line number Diff line number Diff line
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchurl,
  autoAddDriverRunpath,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
}:
let
  inherit (jaxlib) version;
  inherit (cudaPackages) cudaVersion;

  cudaLibPath = lib.makeLibraryPath (
    with cudaPackages;
    [
      (lib.getLib cuda_cudart) # libcudart.so
      (lib.getLib cuda_cupti) # libcupti.so
      (lib.getLib cudnn) # libcudnn.so
      (lib.getLib libcufft) # libcufft.so
      (lib.getLib libcusolver) # libcusolver.so
      (lib.getLib libcusparse) # libcusparse.so
    ]
  );

  # Find new releases at https://storage.googleapis.com/jax-releases
  # When upgrading, you can get these hashes from prefetch.sh. See
  # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.

  # upstream does not distribute jax-cuda12-pjrt 0.4.38 binaries for aarch64-linux
  srcs = {
    "x86_64-linux" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl";
      hash = "sha256-g75MWfvPMAd6YAhdmOfVncc4sckeDWKOSsF3n94VrCs=";
    };
    "aarch64-linux" = fetchurl {
      url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
  };
in
buildPythonPackage {
  pname = "jax-cuda12-pjrt";
  inherit version;
  pyproject = false;

  src =
    srcs.${stdenv.hostPlatform.system}
      or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}");

  nativeBuildInputs = [
    autoAddDriverRunpath
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # The following attributes (buildInputs, postInstall and preInstallCheck) are copied from jaxlib-0.4.28
  # but it does not recognize GPUs as of 2024-12-29
  # Dynamic link dependencies
  buildInputs = [ (lib.getLib stdenv.cc.cc) ];

  # jaxlib looks for ptxas at runtime, eg when running `jax.random.PRNGKey(0)`.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # for more info.
  postInstall = ''
    mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
  '';

  # jaxlib contains shared libraries that open other shared libraries via dlopen
  # and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
  # patchPhase.
  preInstallCheck = ''
    shopt -s globstar

    for file in $out/**/*.so; do
      echo $file
      patchelf --add-rpath "${cudaLibPath}" "$file"
    done
  '';

  # no tests
  doCheck = false;

  pythonImportsCheck = [ "jax_plugins.xla_cuda12" ];

  meta = {
    description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.attrNames srcs;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken =
      !(lib.versionAtLeast cudaVersion "12.1")
      || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1")
      || true;
  };
}
+117 −0
Original line number Diff line number Diff line
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  autoAddDriverRunpath,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
  jax-cuda12-pjrt,
}:
let
  inherit (cudaPackages) cudaVersion;
  inherit (jaxlib) version;

  getSrcFromPypi =
    {
      platform,
      dist,
      hash,
    }:
    fetchPypi {
      inherit
        version
        platform
        dist
        hash
        ;
      pname = "jax_cuda12_plugin";
      format = "wheel";
      python = dist;
      abi = dist;
    };

  # upstream does not distribute jax-cuda12-plugin 0.4.38 binaries for aarch64-linux
  srcs = {
    "3.10-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp310";
      hash = "sha256-nULpmc1k3VZ8FJ7Wj3k5K6iGRDZCGLtjbNzvoBl8kv4=";
    };
    "3.10-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp310";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp311";
      hash = "sha256-cEZUOG8OYAoCgdquqViCqmekfttoOTthsbFzx+jKdKg=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp311";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp312";
      hash = "sha256-Ufas/3Ew63LrsCU039NYGg9eoGlx3lLX68Ia1Nh/5x4=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp312";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux2014_x86_64";
      dist = "cp313";
      hash = "sha256-CSKKTCtEO3aozZqOwikGAInEzINuBiSWh1ptb9xm0x8=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux2014_aarch64";
      dist = "cp313";
      hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
    };
  };
in
buildPythonPackage {
  pname = "jax-cuda12-plugin";
  inherit version;
  pyproject = false;

  src = (
    srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}"
      or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}")
  );

  nativeBuildInputs = [
    autoAddDriverRunpath
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  meta = {
    description = "JAX Plugin for CUDA12";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken =
      !(lib.versionAtLeast cudaVersion "12.1")
      || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1")
      || true;
  };
}
+77 −43
Original line number Diff line number Diff line
{
  lib,
  config,
  stdenv,
  blas,
  lapack,
  buildPythonPackage,
  callPackage,
  setuptools,
  importlib-metadata,
  fetchFromGitHub,
  cudaSupport ? config.cudaSupport,

  # build-system
  setuptools,

  # dependencies
  jaxlib,
  jaxlib-bin,
  jaxlib-build,
  hypothesis,
  lapack,
  matplotlib,
  ml-dtypes,
  numpy,
  opt-einsum,
  scipy,

  # optional-dependencies
  jax-cuda12-plugin,

  # tests
  cloudpickle,
  hypothesis,
  matplotlib,
  pytestCheckHook,
  pytest-xdist,
  pythonOlder,
  scipy,
  stdenv,

  # passthru
  callPackage,
  jax,
  jaxlib-build,
  jaxlib-bin,
}:

let
@@ -27,38 +40,41 @@ let
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.28";
  version = "0.4.38";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/jax-v${version}";
    hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
    tag = "jax-v${version}";
    hash = "sha256-H8I9Mkz6Ia1RxJmnuJOSevLGHN2J8ey59ZTlFx8YfnA=";
  };

  nativeBuildInputs = [ setuptools ];
  build-system = [ setuptools ];

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAX_RELEASE = "1";

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU.
  propagatedBuildInputs = [
  dependencies = [
    jaxlib
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;
  ] ++ lib.optionals cudaSupport optional-dependencies.cuda;

  optional-dependencies = rec {
    cuda = [ jax-cuda12-plugin ];
    cuda12 = cuda;
    cuda12_pip = cuda;
    cuda12_local = cuda;
  };

  nativeCheckInputs = [
    cloudpickle
    hypothesis
    jaxlib
    matplotlib
    pytestCheckHook
    pytest-xdist
@@ -71,10 +87,16 @@ buildPythonPackage rec {
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [
  pytestFlagsArray =
    [
      "--numprocesses=4"
      "-W ignore::DeprecationWarning"
      "tests/"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
      "--deselect tests/shape_poly_test.py::ShapePolyTest"
      "--deselect tests/tree_util_test.py::TreeTest"
    ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
@@ -125,9 +147,20 @@ buildPythonPackage rec {
      # Fails on some hardware due to some numerical error
      # See https://github.com/google/jax/issues/18535
      "testQdwhWithOnRankDeficientInput5"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
      "testInAxesPyTreePrefixMismatchError"
      "testInAxesPyTreePrefixMismatchErrorKwargs"
      "testOutAxesPyTreePrefixMismatchError"
      "test_tree_map"
      "test_vjp_rule_inconsistent_pytree_structures_error"
      "test_vmap_in_axes_tree_prefix_error"
      "test_vmap_mismatched_axis_sizes_error_message_issue_705"
    ];

  disabledTestPaths = [
  disabledTestPaths =
    [
      # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
      "tests/linalg_test.py"
    ]
@@ -147,25 +180,26 @@ buildPythonPackage rec {
  #
  #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
  passthru.tests = {
    test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-build.override { cudaSupport = true; };
    };
    # jaxlib-build is broken as of 2024-12-20
    # test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
    #   jax = jax.override { jaxlib = jaxlib-build; };
    # };
    test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-bin.override { cudaSupport = true; };
      jax = jax.override { jaxlib = jaxlib-bin; };
    };
  };

  # updater fails to pick the correct branch
  passthru.skipBulkUpdate = true;

  meta = with lib; {
  meta = {
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    longDescription = ''
      This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
      e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
    '';
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
  };
}
+1 −3
Original line number Diff line number Diff line
{
  jax,
  jaxlib,
  pkgs,
}:

@@ -8,8 +7,7 @@ pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
      jaxlib
    ];
    ] ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax
Loading