Unverified Commit 06ef57da authored by Gaetan Lepage's avatar Gaetan Lepage Committed by Nick Cao
Browse files

python3Packages.jax: 0.4.5 -> 0.4.14

parent 7b16d5d8
Loading
Loading
Loading
Loading
+15 −25
Original line number Diff line number Diff line
{ lib
, absl-py
, blas
, buildPythonPackage
, etils
, setuptools
, importlib-metadata
, fetchFromGitHub
, jaxlib
, jaxlib-bin
, lapack
, matplotlib
, ml-dtypes
, numpy
, opt-einsum
, pytestCheckHook
@@ -15,7 +16,6 @@
, pythonOlder
, scipy
, stdenv
, typing-extensions
}:

let
@@ -27,30 +27,32 @@ let
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.5";
  format = "setuptools";
  version = "0.4.14";
  format = "pyproject";

  disabled = pythonOlder "3.7";
  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
    hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
  };

  nativeBuildInputs = [
    setuptools
  ];

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU.
  propagatedBuildInputs = [
    absl-py
    etils
    ml-dtypes
    numpy
    opt-einsum
    scipy
    typing-extensions
  ] ++ etils.optional-dependencies.epath;
  ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;

  nativeCheckInputs = [
    jaxlib'
@@ -96,24 +98,12 @@ buildPythonPackage rec {
    "testScanGrad_jit_scan"
  ];

  # See https://github.com/google/jax/issues/11722. This is a temporary fix in
  # order to unblock etils, and upgrading jax/jaxlib to the latest version. See
  # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
  disabledTestPaths = [
    "tests/api_test.py"
    "tests/core_test.py"
    "tests/lax_numpy_indexing_test.py"
    "tests/lax_numpy_test.py"
    "tests/nn_test.py"
    "tests/random_test.py"
    "tests/sparse_test.py"
  ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
  disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
    # RuntimeWarning: invalid value encountered in cast
    "tests/lax_test.py"
  ];

  # As of 0.3.22, `import jax` does not work without jaxlib being installed.
  pythonImportsCheck = [ ];
  pythonImportsCheck = [ "jax" ];

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";