Commit b5de13f7 authored by natsukium's avatar natsukium Committed by Yt
Browse files
parent 7a031b95
Loading
Loading
Loading
Loading
+28 −7
Original line number Diff line number Diff line
@@ -2,6 +2,8 @@
, lib
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, pythonAtLeast
, pythonOlder
, pytestCheckHook
, setuptools
@@ -17,7 +19,7 @@

buildPythonPackage rec {
  pname = "accelerate";
  version = "0.19.0";
  version = "0.21.0";
  format = "pyproject";
  disabled = pythonOlder "3.7";

@@ -25,9 +27,18 @@ buildPythonPackage rec {
    owner = "huggingface";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-gW4wCpkyxoWfxXu8UHZfgopSQhOoPhGgqEqFiHJ+Db4=";
    hash = "sha256-BwM3gyNhsRkxtxLNrycUGwBmXf8eq/7b56/ykMryt5w=";
  };

  patches = [
    # fix import error when torch>=2.0.1 and torch.distributed is disabled
    # https://github.com/huggingface/accelerate/pull/1800
    (fetchpatch {
      url = "https://github.com/huggingface/accelerate/commit/32701039d302d3875c50c35ab3e76c467755eae9.patch";
      hash = "sha256-Hth7qyOfx1sC8UaRdbYTnyRXD/VRKf41GtLc0ee1t2I=";
    })
  ];

  nativeBuildInputs = [ setuptools ];

  propagatedBuildInputs = [
@@ -53,15 +64,25 @@ buildPythonPackage rec {
    # try to download data:
    "FeatureExamplesTests"
    "test_infer_auto_device_map_on_t0pp"
    # known failure with Torch>2.0; see https://github.com/huggingface/accelerate/pull/1339:
    # (remove for next release)
    "test_gradient_sync_cpu_multi"
  ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
    # usual aarch64-linux RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
    "CheckpointTest"
  ] ++ lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
    # RuntimeError: torch_shm_manager: execl failed: Permission denied
    "CheckpointTest"
  ] ++ lib.optionals (pythonAtLeast "3.11") [
    # python3.11 not yet supported for torch.compile
    "test_dynamo_extract_model"
  ];

  disabledTestPaths = lib.optionals (!(stdenv.isLinux && stdenv.isx86_64)) [
    # numerous instances of torch.multiprocessing.spawn.ProcessRaisedException:
  doCheck = !stdenv.isDarwin;
    "tests/test_cpu.py"
    "tests/test_grad_sync.py"
    "tests/test_metrics.py"
    "tests/test_scheduler.py"
  ];

  pythonImportsCheck = [
    "accelerate"
  ];