Unverified Commit ffe5c6cc authored by Peder Bergebakken Sundt's avatar Peder Bergebakken Sundt Committed by GitHub
Browse files

python312Packages.flax: 0.8.5 -> 0.9.0 (#342970)

parents a0d059de 588d188d
Loading
Loading
Loading
Loading
+10 −6
Original line number Diff line number Diff line
{
  lib,
  buildPythonPackage,
  pythonOlder,
  fetchFromGitHub,

  # build-system
@@ -26,6 +25,7 @@
  pytest-xdist,
  pytestCheckHook,
  tensorflow,
  treescope,

  # optional-dependencies
  matplotlib,
@@ -33,16 +33,14 @@

buildPythonPackage rec {
  pname = "flax";
  version = "0.8.5";
  version = "0.9.0";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-6WOFq0758gtNdrlWqSQBlKmWVIGe5e4PAaGrvHoGjr0=";
    hash = "sha256-iDWuUJKO7V4QrbVsS4ALgy6fbllOC43o7W4mhjtZ9xc=";
  };

  build-system = [
@@ -75,6 +73,7 @@ buildPythonPackage rec {
    pytest-xdist
    pytestCheckHook
    tensorflow
    treescope
  ];

  pytestFlagsArray = [
@@ -95,13 +94,18 @@ buildPythonPackage rec {
    "flax/nnx/examples/*"
    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"
    # Requires tree
    # Too old version of tensorflow:
    # ModuleNotFoundError: No module named 'keras.api._v2'
    "tests/tensorboard_test.py"
  ];

  disabledTests = [
    # ValueError: Checkpoint path should be absolute
    "test_overwrite_checkpoints0"
    # Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
    # TODO: Re-enable when jax>0.4.28 will be available in nixpkgs
    "test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
    "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
  ];

  meta = {
+65 −0
Original line number Diff line number Diff line
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  numpy,

  # optional-dependencies
  ipython,
  jax,
  palettable,

  # tests
  absl-py,
  jaxlib,
  pytestCheckHook,
  torch,
}:

buildPythonPackage rec {
  pname = "treescope";
  version = "0.1.5";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "treescope";
    rev = "refs/tags/v${version}";
    hash = "sha256-+Hm60O9tEXIiE0av1O0BsOdMln4e1s7ijb3WNiQ74jE=";
  };

  build-system = [ flit-core ];

  dependencies = [ numpy ];

  optional-dependencies = {
    notebook = [
      ipython
      jax
      palettable
    ];
  };

  pythonImportsCheck = [ "treescope" ];

  nativeCheckInputs = [
    absl-py
    jax
    jaxlib
    pytestCheckHook
    torch
  ];

  meta = {
    description = "An interactive HTML pretty-printer for machine learning research in IPython notebooks";
    homepage = "https://github.com/google-deepmind/treescope";
    changelog = "https://github.com/google-deepmind/treescope/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}
+2 −0
Original line number Diff line number Diff line
@@ -15778,6 +15778,8 @@ self: super: with self; {
  treeo = callPackage ../development/python-modules/treeo { };
  treescope = callPackage ../development/python-modules/treescope { };
  treex = callPackage ../development/python-modules/treex { };
  treq = callPackage ../development/python-modules/treq { };