Unverified Commit 98b89f97 authored by natsukium's avatar natsukium
Browse files

python312Packages.pot: migrate to optional-dependencies

parent 3e5e84c8
Loading
Loading
Loading
Loading
+42 −20
Original line number Diff line number Diff line
@@ -3,9 +3,10 @@
  autograd,
  buildPythonPackage,
  fetchFromGitHub,
  cupy,
  cvxopt,
  cython,
  jax,
  jaxlib,
  matplotlib,
  numpy,
  pymanopt,
@@ -14,8 +15,8 @@
  scikit-learn,
  scipy,
  setuptools,
  enableDimensionalityReduction ? false,
  enableGPU ? false,
  tensorflow,
  torch,
}:

buildPythonPackage rec {
@@ -38,24 +39,47 @@ buildPythonPackage rec {
    numpy
  ];

  dependencies =
    [
  dependencies = [
    numpy
    scipy
    ]
    ++ lib.optionals enableGPU [ cupy ]
    ++ lib.optionals enableDimensionalityReduction [
      autograd
      pymanopt
  ];

  nativeCheckInputs = [
    cvxopt
    matplotlib
    numpy
  optional-dependencies = {
    backend-numpy = [ ];
    backend-jax = [
      jax
      jaxlib
    ];
    backend-cupy = [ ];
    backend-tf = [ tensorflow ];
    backend-torch = [ torch ];
    cvxopt = [ cvxopt ];
    dr = [
      scikit-learn
    pytestCheckHook
      pymanopt
      autograd
    ];
    gnn = [
      torch
      # torch-geometric
    ];
    plot = [ matplotlib ];
    all =
      with optional-dependencies;
      (
        backend-numpy
        ++ backend-jax
        ++ backend-cupy
        ++ backend-tf
        ++ backend-torch
        ++ optional-dependencies.cvxopt
        ++ dr
        ++ gnn
        ++ plot
      );
  };

  nativeCheckInputs = [ pytestCheckHook ];

  postPatch = ''
    substituteInPlace setup.cfg \
@@ -108,8 +132,6 @@ buildPythonPackage rec {
    "test_emd1d_device_tf"
  ];

  disabledTestPaths = lib.optionals (!enableDimensionalityReduction) [ "test/test_dr.py" ];

  pythonImportsCheck = [
    "ot"
    "ot.lp"