Commit 588d188d authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files
parent 3fc5f627
Loading
Loading
Loading
Loading
+10 −6
Original line number Diff line number Diff line
{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,

  # build-system
@@ -26,6 +25,7 @@
  pytest-xdist,
  pytestCheckHook,
  tensorflow,
  treescope,

  # optional-dependencies
  matplotlib,
@@ -33,16 +33,14 @@

buildPythonPackage rec {
  pname = "flax";
  version = "0.8.5";
  version = "0.9.0";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-6WOFq0758gtNdrlWqSQBlKmWVIGe5e4PAaGrvHoGjr0=";
    hash = "sha256-iDWuUJKO7V4QrbVsS4ALgy6fbllOC43o7W4mhjtZ9xc=";
  };

  build-system = [
@@ -75,6 +73,7 @@ buildPythonPackage rec {
    pytest-xdist
    pytestCheckHook
    tensorflow
    treescope
  ];

  pytestFlagsArray = [
@@ -95,13 +94,18 @@ buildPythonPackage rec {
    "flax/nnx/examples/*"
    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"
    # Requires tree
    # Too old version of tensorflow:
    # ModuleNotFoundError: No module named 'keras.api._v2'
    "tests/tensorboard_test.py"
  ];

  disabledTests = [
    # ValueError: Checkpoint path should be absolute
    "test_overwrite_checkpoints0"
    # Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
    # TODO: Re-enable when jax>0.4.28 will be available in nixpkgs
    "test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
    "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
  ];

  meta = {