Commit 035ae46b authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files
parent a33f6d17
Loading
Loading
Loading
Loading
+24 −6
Original line number Diff line number Diff line
@@ -33,28 +33,25 @@
  pytest-xdist,
  pytestCheckHook,
  safetensors,
  torch,
}:

buildPythonPackage rec {
  pname = "orbax-checkpoint";
  version = "0.11.28";
  version = "0.11.30";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "orbax";
    tag = "v${version}";
    hash = "sha256-a7E60fZRmEXTA220mwr7EDMUc+zYbW7wG40vY7NeAOM=";
    hash = "sha256-y8l0AVGt2t5zLX+x+yuWHsEDy68agpXIkrew+zfYGXU=";
  };

  sourceRoot = "${src.name}/checkpoint";

  build-system = [ flit-core ];

  pythonRelaxDeps = [
    "jax"
  ];

  dependencies = [
    absl-py
    aiofiles
@@ -82,6 +79,7 @@ buildPythonPackage rec {
    pytest-xdist
    pytestCheckHook
    safetensors
    torch
  ];

  pythonImportsCheck = [
@@ -115,6 +113,26 @@ buildPythonPackage rec {
  ];

  disabledTestPaths = [
    # import file mismatch:
    # imported module 'sharding_test' has this __file__ attribute:
    #   /build/source/checkpoint/orbax/checkpoint/_src/arrays/sharding_test.py
    # which is not the same as the test file we want to collect:
    #   /build/source/checkpoint/orbax/checkpoint/_src/metadata/sharding_test.py
    "orbax/checkpoint/_src/metadata/sharding_test.py"

    # Circular dependency with clu
    "orbax/checkpoint/_src/testing/benchmarks/array_handler_benchmark_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_benchmark_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_perf_benchmark_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/checkpoint_policy_benchmark_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/core/config_parsing_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/core/core_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/core/metric_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/emergency_checkpoint_manager_benchmark_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/multihost_dispatchers_benchmark_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/pytree_checkpoint_benchmark_test.py"
    "orbax/checkpoint/_src/testing/benchmarks/single_replica_benchmark_test.py"

    # E   absl.flags._exceptions.DuplicateFlagError: The flag 'num_processes' is defined twice.
    # First from multiprocess_test, Second from orbax.checkpoint._src.testing.multiprocess_test.
    # Description from first occurrence: Number of processes to use.