Commit 5a39beac authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files
parent 3fed5dff
Loading
Loading
Loading
Loading
+68 −11
Original line number Diff line number Diff line
@@ -17,12 +17,26 @@
  torch,

  # optional-dependencies
  ale-py,
  gym,
  pygame,
  torchsnapshot,
  # atari
  gymnasium,
  # checkpointing
  torchsnapshot,
  # gym-continuous
  mujoco,
  # llm
  accelerate,
  datasets,
  einops,
  immutabledict,
  langdetect,
  nltk,
  playwright,
  protobuf,
  safetensors,
  sentencepiece,
  transformers,
  vllm,
  # offline-data
  h5py,
  huggingface-hub,
  minari,
@@ -32,7 +46,9 @@
  scikit-learn,
  torchvision,
  tqdm,
  # rendering
  moviepy,
  # utils
  git,
  hydra-core,
  tensorboard,
@@ -48,14 +64,14 @@

buildPythonPackage rec {
  pname = "torchrl";
  version = "0.8.1";
  version = "0.9.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    tag = "v${version}";
    hash = "sha256-ANoqIAVKSq023hG83Q71t8oLzud1LeVN5WVPYL3nOks=";
    hash = "sha256-afaWDX5lIAoGTfrBSqrktYoA1S4hv6ogBaKYHc8dQ6E=";
  };

  build-system = [
@@ -73,16 +89,26 @@ buildPythonPackage rec {
  ];

  optional-dependencies = {
    atari = [
      ale-py
      gym
      pygame
    ];
    atari = gymnasium.optional-dependencies.atari;
    checkpointing = [ torchsnapshot ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    llm = [
      accelerate
      datasets
      einops
      immutabledict
      langdetect
      nltk
      playwright
      protobuf
      safetensors
      sentencepiece
      transformers
      vllm
    ];
    offline-data = [
      h5py
      huggingface-hub
@@ -131,10 +157,31 @@ buildPythonPackage rec {
    ]
    ++ optional-dependencies.atari
    ++ optional-dependencies.gym-continuous
    ++ optional-dependencies.llm
    ++ optional-dependencies.rendering;

  disabledTests =
    [
      # Require network
      "test_create_or_load_dataset"
      "test_from_text_env_tokenizer"
      "test_from_text_env_tokenizer_catframes"
      "test_from_text_rb_slicesampler"
      "test_generate"
      "test_get_dataloader"
      "test_get_scores"
      "test_preproc_data"
      "test_prompt_tensordict_tokenizer"
      "test_reward_model"
      "test_tensordict_tokenizer"
      "test_transform_compose"
      "test_transform_model"
      "test_transform_no_env"
      "test_transform_rb"

      # ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment
      "TestRayCollector"

      # torchrl is incompatible with gymnasium>=1.0
      # https://github.com/pytorch/rl/discussions/2483
      "test_resetting_strategies"
@@ -194,6 +241,16 @@ buildPythonPackage rec {
      "test_vecnorm_parallel_auto"
    ];

  disabledTestPaths = [
    # ERROR collecting test/smoke_test.py
    # import file mismatch:
    # imported module 'smoke_test' has this __file__ attribute:
    #   /build/source/test/llm/smoke_test.py
    # which is not the same as the test file we want to collect:
    #   /build/source/test/smoke_test.py
    "test/llm"
  ];

  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";