Unverified Commit 1802b4b4 authored by natsukium's avatar natsukium
Browse files
parent 533e2d71
Loading
Loading
Loading
Loading
+77 −43
Original line number Diff line number Diff line
{
  lib,
  config,
  stdenv,
  blas,
  lapack,
  buildPythonPackage,
  callPackage,
  setuptools,
  importlib-metadata,
  fetchFromGitHub,
  cudaSupport ? config.cudaSupport,

  # build-system
  setuptools,

  # dependencies
  jaxlib,
  jaxlib-bin,
  jaxlib-build,
  hypothesis,
  lapack,
  matplotlib,
  ml-dtypes,
  numpy,
  opt-einsum,
  scipy,

  # optional-dependencies
  jax-cuda12-plugin,

  # tests
  cloudpickle,
  hypothesis,
  matplotlib,
  pytestCheckHook,
  pytest-xdist,
  pythonOlder,
  scipy,
  stdenv,

  # passthru
  callPackage,
  jax,
  jaxlib-build,
  jaxlib-bin,
}:

let
@@ -27,38 +40,41 @@ let
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.28";
  version = "0.4.38";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/jax-v${version}";
    hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
    tag = "jax-v${version}";
    hash = "sha256-H8I9Mkz6Ia1RxJmnuJOSevLGHN2J8ey59ZTlFx8YfnA=";
  };

  nativeBuildInputs = [ setuptools ];
  build-system = [ setuptools ];

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAX_RELEASE = "1";

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU.
  propagatedBuildInputs = [
  dependencies = [
    jaxlib
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;
  ] ++ lib.optionals cudaSupport optional-dependencies.cuda;

  optional-dependencies = rec {
    cuda = [ jax-cuda12-plugin ];
    cuda12 = cuda;
    cuda12_pip = cuda;
    cuda12_local = cuda;
  };

  nativeCheckInputs = [
    cloudpickle
    hypothesis
    jaxlib
    matplotlib
    pytestCheckHook
    pytest-xdist
@@ -71,10 +87,16 @@ buildPythonPackage rec {
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  pytestFlagsArray = [
  pytestFlagsArray =
    [
      "--numprocesses=4"
      "-W ignore::DeprecationWarning"
      "tests/"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
      "--deselect tests/shape_poly_test.py::ShapePolyTest"
      "--deselect tests/tree_util_test.py::TreeTest"
    ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
@@ -125,9 +147,20 @@ buildPythonPackage rec {
      # Fails on some hardware due to some numerical error
      # See https://github.com/google/jax/issues/18535
      "testQdwhWithOnRankDeficientInput5"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
      "testInAxesPyTreePrefixMismatchError"
      "testInAxesPyTreePrefixMismatchErrorKwargs"
      "testOutAxesPyTreePrefixMismatchError"
      "test_tree_map"
      "test_vjp_rule_inconsistent_pytree_structures_error"
      "test_vmap_in_axes_tree_prefix_error"
      "test_vmap_mismatched_axis_sizes_error_message_issue_705"
    ];

  disabledTestPaths = [
  disabledTestPaths =
    [
      # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
      "tests/linalg_test.py"
    ]
@@ -147,25 +180,26 @@ buildPythonPackage rec {
  #
  #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
  passthru.tests = {
    test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-build.override { cudaSupport = true; };
    };
    # jaxlib-build is broken as of 2024-12-20
    # test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
    #   jax = jax.override { jaxlib = jaxlib-build; };
    # };
    test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-bin.override { cudaSupport = true; };
      jax = jax.override { jaxlib = jaxlib-bin; };
    };
  };

  # updater fails to pick the correct branch
  passthru.skipBulkUpdate = true;

  meta = with lib; {
  meta = {
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    longDescription = ''
      This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
      e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
    '';
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
  };
}
+1 −3
Original line number Diff line number Diff line
{
  jax,
  jaxlib,
  pkgs,
}:

@@ -8,8 +7,7 @@ pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
      jaxlib
    ];
    ] ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax