Unverified Commit 6f17142e authored by kirillrdy's avatar kirillrdy Committed by GitHub
Browse files

python3Packages.torchtune: init at 0.6.1 (#460933)

parents a99e109b f8108137
Loading
Loading
Loading
Loading
+51 −0
Original line number Diff line number Diff line
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # tests
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "pytest-integration";
  version = "0.2.3";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "jbwdevries";
    repo = "pytest-integration";
    tag = "v${version}";
    hash = "sha256-Ziy+GEfljYDccx3mm63p7rhDUQVDXLbk7DxUW3npjiE=";
  };

  build-system = [
    setuptools
  ];

  pythonImportsCheck = [ "pytest_integration" ];

  nativeCheckInputs = [
    pytestCheckHook
  ];

  # Tests need to discover the mock `package` module located under `example/`
  preCheck = ''
    pushd example
  '';

  postCheck = ''
    popd
  '';

  meta = {
    description = "Organizing test by unit test, quick integration or slow integration";
    homepage = "https://github.com/jbwdevries/pytest-integration";
    changelog = "https://github.com/jbwdevries/pytest-integration/blob/${src.tag}/CHANGELOG.md";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}
+123 −0
Original line number Diff line number Diff line
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  blobfile,
  datasets,
  huggingface-hub,
  kagglehub,
  numpy,
  omegaconf,
  pillow,
  psutil,
  safetensors,
  sentencepiece,
  tiktoken,
  tokenizers,
  torch,
  torchdata,
  tqdm,
  torchao,
  torchvision,

  # tests
  comet-ml,
  mlflow,
  pytest-integration,
  pytest-mock,
  pytestCheckHook,
  writableTmpDirAsHomeHook,
}:

buildPythonPackage rec {
  pname = "torchtune";
  version = "0.6.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "meta-pytorch";
    repo = "torchtune";
    tag = "v${version}";
    hash = "sha256-evhQBpZiUXriL0PAYkEzGypH21iRs37Ix6Nl5YAyeQ0=";
  };

  build-system = [
    setuptools
  ];

  dependencies = [
    blobfile
    datasets
    huggingface-hub
    kagglehub
    numpy
    omegaconf
    pillow
    psutil
    safetensors
    sentencepiece
    tiktoken
    tokenizers
    torch
    torchdata
    tqdm

    # Not explicitly listed as requirements, but effectively imported at runtime
    torchao
    torchvision
  ]
  ++ huggingface-hub.optional-dependencies.hf_transfer;

  pythonImportsCheck = [ "torchtune" ];

  nativeCheckInputs = [
    comet-ml
    mlflow
    pytest-integration
    pytest-mock
    pytestCheckHook
    writableTmpDirAsHomeHook
  ];

  disabledTests = [
    # AssertionError (tensors are not equal)
    "test_stop_tokens"
    "test_stop_tokens_batched"
    "test_stop_tokens_batched_uneven_stopping"
    "test_stop_tokens_batched_uneven_stopping_left_padded"

    # RuntimeError: not allowed to set torch.backends.cudnn flags after disable_global_flags;
    # please use flags() context manager instead
    "test_deterministic_false"
    "test_deterministic_true"

    # TypeError: exceptions must be derived from Warning, not <class 'NoneType'>
    "test_deprecated"

    # Flaky
    # AssertionError: actual: -83.3048095703125, expected: -83.15229797363281
    "test_forward"
    "test_forward_kv_cache"
    "test_forward_with_2d_pos_ids"
    "test_forward_with_curr_pos"
    "test_forward_with_packed_pos"
  ]
  ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # Fatal Python error: Segmentation fault
    "test_forward_gqa"
  ];

  meta = {
    description = "PyTorch native post-training library";
    homepage = "https://github.com/meta-pytorch/torchtune";
    changelog = "https://github.com/meta-pytorch/torchtune/releases/tag/${src.tag}";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}
+4 −0
Original line number Diff line number Diff line
@@ -14741,6 +14741,8 @@ self: super: with self; {
  pytest-instafail = callPackage ../development/python-modules/pytest-instafail { };
  pytest-integration = callPackage ../development/python-modules/pytest-integration { };
  pytest-isort = callPackage ../development/python-modules/pytest-isort { };
  pytest-json-report = callPackage ../development/python-modules/pytest-json-report { };
@@ -18853,6 +18855,8 @@ self: super: with self; {
  torchtnt = callPackage ../development/python-modules/torchtnt { };
  torchtune = callPackage ../development/python-modules/torchtune { };
  torchvision = callPackage ../development/python-modules/torchvision { };
  torchvision-bin = callPackage ../development/python-modules/torchvision/bin.nix { };