Commit 15cd73f5 authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

python311Packages.flax: dependencies and tests check up

parent e556bb0b
Loading
Loading
Loading
Loading
+12 −22
Original line number Diff line number Diff line
@@ -4,14 +4,17 @@
, jaxlib
, pythonRelaxDepsHook
, setuptools-scm
, cloudpickle
, jax
, matplotlib
, msgpack
, numpy
, optax
, pyyaml
, rich
, tensorstore
, typing-extensions
, matplotlib
, cloudpickle
, einops
, keras
, pytest-xdist
, pytestCheckHook
@@ -37,24 +40,27 @@ buildPythonPackage rec {
  ];

  propagatedBuildInputs = [
    cloudpickle
    jax
    matplotlib
    msgpack
    numpy
    optax
    pyyaml
    rich
    tensorstore
    typing-extensions
  ];

  # See https://github.com/google/flax/pull/2882.
  pythonRemoveDeps = [ "orbax" ];
  passthru.optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [
    "flax"
  ];

  nativeCheckInputs = [
    cloudpickle
    einops
    keras
    pytest-xdist
    pytestCheckHook
@@ -85,22 +91,6 @@ buildPythonPackage rec {
    "tests/checkpoints_test.py"
  ];

  disabledTests = [
    # See https://github.com/google/flax/issues/2554.
    "test_async_save_checkpoints"
    "test_jax_array0"
    "test_jax_array1"
    "test_keep0"
    "test_keep1"
    "test_optimized_lstm_cell_matches_regular"
    "test_overwrite_checkpoints"
    "test_save_restore_checkpoints_target_empty"
    "test_save_restore_checkpoints_target_none"
    "test_save_restore_checkpoints_target_singular"
    "test_save_restore_checkpoints_w_float_steps"
    "test_save_restore_checkpoints"
  ];

  meta = with lib; {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";