Unverified Commit 707f5c1e authored by Peder Bergebakken Sundt's avatar Peder Bergebakken Sundt Committed by GitHub
Browse files

python3Packages.torcheval: now supports unix platforms (#396559)

parents a0647dc4 0a0d1526
Loading
Loading
Loading
Loading
+60 −37
Original line number Diff line number Diff line
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

@@ -71,7 +72,8 @@ buildPythonPackage {
    torchvision
  ];

  pytestFlagsArray = [
  pytestFlagsArray =
    [
      "-v"
      "tests/"

@@ -104,6 +106,27 @@ buildPythonPackage {
      # AssertionError: Scalars are not close!
      # Expected -640.4547729492188 but got -640.4707641601562
      "--deselect=tests/metrics/regression/test_mean_squared_error.py::TestMeanSquaredError::test_mean_squared_error_class_update_input_shape_different"
    ]

    # These tests error on darwin platforms.
    # NotImplementedError: The operator 'c10d::allgather_' is not currently implemented for the mps device
    #
    # Applying the suggested environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1;` causes the tests to fail,
    # as using the CPU instead of the MPS causes the tensors to be on the wrong device:
    # RuntimeError: ProcessGroupGloo::allgather: invalid tensor type at index 0;
    # Expected TensorOptions(dtype=float, device=cpu, ...), got TensorOptions(dtype=float, device=mps:0, ...)
    ++ lib.optional stdenv.hostPlatform.isDarwin [
      # -- tests/metrics/test_synclib.py --
      "--deselect=tests/metrics/test_synclib.py::SynclibTest::test_complex_mixed_state_sync"
      "--deselect=tests/metrics/test_synclib.py::SynclibTest::test_complex_mixed_state_sync"
      "--deselect=tests/metrics/test_synclib.py::SynclibTest::test_empty_tensor_list_sync_state"
      "--deselect=tests/metrics/test_synclib.py::SynclibTest::test_sync_dtype_and_shape"
      "--deselect=tests/metrics/test_synclib.py::SynclibTest::test_tensor_list_sync_states"
      "--deselect=tests/metrics/test_synclib.py::SynclibTest::test_tensor_dict_sync_states"
      "--deselect=tests/metrics/test_synclib.py::SynclibTest::test_tensor_sync_states"
      # -- tests/metrics/test_toolkit.py --
      "--deselect=tests/metrics/test_toolkit.py::MetricToolkitTest::test_metric_sync"
      "--deselect=tests/metrics/test_toolkit.py::MetricCollectionToolkitTest::test_metric_collection_sync"
    ];

  meta = {
@@ -111,8 +134,8 @@ buildPythonPackage {
    homepage = "https://pytorch.org/torcheval";
    changelog = "https://github.com/pytorch/torcheval/releases/tag/${version}";

    platforms = lib.platforms.linux;
    license = with lib.licenses; [ bsd3 ];
    maintainers = with lib.maintainers; [ bengsparks ];
    platforms = lib.platforms.unix;
    license = [ lib.licenses.bsd3 ];
    maintainers = [ lib.maintainers.bengsparks ];
  };
}