Unverified Commit f6c1156c authored by Sandro Jäckel's avatar Sandro Jäckel Committed by GitHub
Browse files

python312Packages.flax: 0.9.0 -> 0.10.1 (#353862)

parents 4a8f0d12 981dbd7f
Loading
Loading
Loading
Loading
+19 −4
Original line number Diff line number Diff line
@@ -21,26 +21,31 @@
  # checks
  cloudpickle,
  einops,
  flaxlib,
  keras,
  pytest-xdist,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,
  treescope,

  # optional-dependencies
  matplotlib,

  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.9.0";
  version = "0.10.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    rev = "refs/tags/v${version}";
    hash = "sha256-iDWuUJKO7V4QrbVsS4ALgy6fbllOC43o7W4mhjtZ9xc=";
    hash = "sha256-+URbQGnmqmSNgucEyWvI5DMnzXjpmJzLA+Pho2lX+S4=";
  };

  build-system = [
@@ -69,9 +74,11 @@ buildPythonPackage rec {
  nativeCheckInputs = [
    cloudpickle
    einops
    flaxlib
    keras
    pytest-xdist
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
    treescope
  ];
@@ -108,6 +115,14 @@ buildPythonPackage rec {
    "test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
  ];

  passthru = {
    updateScript = writeScript "update.sh" ''
      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .something.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
+65 −0
Original line number Diff line number Diff line
{
  lib,
  buildPythonPackage,
  flax,
  tomlq,
  rustPlatform,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "flaxlib";
  version = "0.0.1-a1";
  pyproject = true;

  inherit (flax) src;

  sourceRoot = "${src.name}/flaxlib";

  postPatch = ''
    expected_version="$version"
    actual_version=$(${lib.getExe tomlq} --file Cargo.toml "package.version")

    if [ "$actual_version" != "$expected_version" ]; then
      echo -e "\n\tERROR:"
      echo -e "\tThe version of the flaxlib python package ($expected_version) does not match the one in its Cargo.toml file ($actual_version)"
      echo -e "\tPlease update the version attribute of the nix python3Packages.flaxlib package."
      exit 1
    fi
  '';

  cargoDeps = rustPlatform.fetchCargoTarball {
    inherit
      pname
      version
      src
      sourceRoot
      ;
    hash = "sha256-RPbMHnRdJaWKLU9Rkz39lmfibO20dnfZmLZqehHM3w4=";
  };

  nativeBuildInputs = [
    rustPlatform.maturinBuildHook
    rustPlatform.cargoSetupHook
  ];

  pythonImportsCheck = [ "flaxlib" ];

  nativeCheckInputs = [
    pytestCheckHook
  ];

  # This package does not have tests (yet ?)
  doCheck = false;

  passthru = {
    inherit (flax) updateScript;
  };

  meta = {
    description = "Rust library used internally by flax";
    homepage = "https://github.com/google/flax/tree/main/flaxlib";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}
+2 −0
Original line number Diff line number Diff line
@@ -4663,6 +4663,8 @@ self: super: with self; {
  flax = callPackage ../development/python-modules/flax { };
  flaxlib = callPackage ../development/python-modules/flaxlib { };
  fleep = callPackage ../development/python-modules/fleep { };
  flet = callPackage ../development/python-modules/flet { };