Commit 7ebfca6c authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

python312Packages.distrax: adapt to jax 0.6.0 new API changes

parent 81d5c7c9
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  chex,
  jaxlib,
  numpy,
@@ -23,6 +24,15 @@ buildPythonPackage rec {
    hash = "sha256-A1aCL/I89Blg9sNmIWQru4QJteUTN6+bhgrEJPmCrM0=";
  };

  patches = [
    # TODO: remove at the next release (already on master)
    (fetchpatch {
      name = "fix-jax-0.6.0-compat";
      url = "https://github.com/google-deepmind/distrax/commit/c02708ac46518fac00ab2945311e0f2ee32c672c.patch";
      hash = "sha256-hFNXKoA1b5I6dzhwTRXp/SnkHv89GI6tYwlnBBHwG78=";
    })
  ];

  dependencies = [
    chex
    jaxlib
@@ -71,6 +81,10 @@ buildPythonPackage rec {
  ];

  disabledTestPaths = [
    # Since jax 0.6.0:
    # TypeError: <lambda>() got an unexpected keyword argument 'accuracy'
    "distrax/_src/bijectors/lambda_bijector_test.py"

    # TypeErrors
    "distrax/_src/bijectors/tfp_compatible_bijector_test.py"
    "distrax/_src/distributions/distribution_from_tfp_test.py"