Commit 33e97611 authored by Samuel Ainsworth's avatar Samuel Ainsworth
Browse files

python3Packages.flax: fix build

parent 46388074
Loading
Loading
Loading
Loading
+13 −3
Original line number Diff line number Diff line
@@ -10,7 +10,9 @@
, optax
, pytest-xdist
, pytestCheckHook
, pythonRelaxDepsHook
, tensorflow
, tensorstore
, fetchpatch
, rich
}:
@@ -26,7 +28,7 @@ buildPythonPackage rec {
    hash = "sha256-Vv68BK83gTIKj0r9x+twdhqmRYziD0vxQCdHkYSeTak=";
  };

  buildInputs = [ jaxlib ];
  nativeBuildInputs = [ jaxlib pythonRelaxDepsHook ];

  propagatedBuildInputs = [
    jax
@@ -35,8 +37,12 @@ buildPythonPackage rec {
    numpy
    optax
    rich
    tensorstore
  ];

  # See https://github.com/google/flax/pull/2882.
  pythonRemoveDeps = [ "orbax" ];

  pythonImportsCheck = [
    "flax"
  ];
@@ -64,6 +70,12 @@ buildPythonPackage rec {
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"

    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"

    # Requires orbax which is not packaged as of 2023-07-27.
    "tests/checkpoints_test.py"
  ];

  disabledTests = [
@@ -88,7 +100,5 @@ buildPythonPackage rec {
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    # Requires orbax which is not available
    broken = true;
  };
}