Commit 05f96ae1 authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

python3Packages.clu: fix jax 0.10.0 compatibility

parent 5f5f0d62
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ buildPythonPackage (finalAttrs: {
  pname = "clu";
  version = "0.0.12";
  pyproject = true;
  __structuredAttrs = true;

  src = fetchFromGitHub {
    owner = "google";
@@ -39,6 +40,15 @@ buildPythonPackage (finalAttrs: {
    hash = "sha256-ntqRz3fCXMf0EDQsddT68Mdi105ECBVQpVsStzk2kvQ=";
  };

  # Fix Jax 0.10.0 compatibility
  # TypeError: clip() got an unexpected keyword argument 'a_min'
  postPatch = ''
    substituteInPlace clu/metrics.py \
      --replace-fail \
        "variance = jnp.clip(variance, a_min=0.0)" \
        "variance = jnp.clip(variance, min=0.0)"
  '';

  build-system = [
    setuptools
  ];