Loading pkgs/development/python-modules/flax/default.nix +12 −22 Original line number Diff line number Diff line Loading @@ -4,14 +4,17 @@ , jaxlib , pythonRelaxDepsHook , setuptools-scm , cloudpickle , jax , matplotlib , msgpack , numpy , optax , pyyaml , rich , tensorstore , typing-extensions , matplotlib , cloudpickle , einops , keras , pytest-xdist , pytestCheckHook Loading @@ -37,24 +40,27 @@ buildPythonPackage rec { ]; propagatedBuildInputs = [ cloudpickle jax matplotlib msgpack numpy optax pyyaml rich tensorstore typing-extensions ]; # See https://github.com/google/flax/pull/2882. pythonRemoveDeps = [ "orbax" ]; passthru.optional-dependencies = { all = [ matplotlib ]; }; pythonImportsCheck = [ "flax" ]; nativeCheckInputs = [ cloudpickle einops keras pytest-xdist pytestCheckHook Loading Loading @@ -85,22 +91,6 @@ buildPythonPackage rec { "tests/checkpoints_test.py" ]; disabledTests = [ # See https://github.com/google/flax/issues/2554. "test_async_save_checkpoints" "test_jax_array0" "test_jax_array1" "test_keep0" "test_keep1" "test_optimized_lstm_cell_matches_regular" "test_overwrite_checkpoints" "test_save_restore_checkpoints_target_empty" "test_save_restore_checkpoints_target_none" "test_save_restore_checkpoints_target_singular" "test_save_restore_checkpoints_w_float_steps" "test_save_restore_checkpoints" ]; meta = with lib; { description = "Neural network library for JAX"; homepage = "https://github.com/google/flax"; Loading Loading
pkgs/development/python-modules/flax/default.nix +12 −22 Original line number Diff line number Diff line Loading @@ -4,14 +4,17 @@ , jaxlib , pythonRelaxDepsHook , setuptools-scm , cloudpickle , jax , matplotlib , msgpack , numpy , optax , pyyaml , rich , tensorstore , typing-extensions , matplotlib , cloudpickle , einops , keras , pytest-xdist , pytestCheckHook Loading @@ -37,24 +40,27 @@ buildPythonPackage rec { ]; propagatedBuildInputs = [ cloudpickle jax matplotlib msgpack numpy optax pyyaml rich tensorstore typing-extensions ]; # See https://github.com/google/flax/pull/2882. pythonRemoveDeps = [ "orbax" ]; passthru.optional-dependencies = { all = [ matplotlib ]; }; pythonImportsCheck = [ "flax" ]; nativeCheckInputs = [ cloudpickle einops keras pytest-xdist pytestCheckHook Loading Loading @@ -85,22 +91,6 @@ buildPythonPackage rec { "tests/checkpoints_test.py" ]; disabledTests = [ # See https://github.com/google/flax/issues/2554. "test_async_save_checkpoints" "test_jax_array0" "test_jax_array1" "test_keep0" "test_keep1" "test_optimized_lstm_cell_matches_regular" "test_overwrite_checkpoints" "test_save_restore_checkpoints_target_empty" "test_save_restore_checkpoints_target_none" "test_save_restore_checkpoints_target_singular" "test_save_restore_checkpoints_w_float_steps" "test_save_restore_checkpoints" ]; meta = with lib; { description = "Neural network library for JAX"; homepage = "https://github.com/google/flax"; Loading