Unverified Commit fe6c67bf authored by Nick Cao's avatar Nick Cao Committed by GitHub
Browse files

python3Packages.flax: 0.12.0 -> 0.12.1 (#463316)

parents e3b90fb7 3aae229d
Loading
Loading
Loading
Loading
+2 −20
Original line number Diff line number Diff line
@@ -2,7 +2,6 @@
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,

  # build-system
  setuptools,
@@ -39,27 +38,16 @@

buildPythonPackage rec {
  pname = "flax";
  version = "0.12.0";
  version = "0.12.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-ioMj8+TuOFX3t9p3oVaywaOQPFBgvNcy7b/2WX/yvXA=";
    hash = "sha256-AUgNU1ww1Ic+lfdHtdP4fdFuvIatAXqs7AX615aVPKM=";
  };

  patches = [
    # Fixes TypeError: shard_map() got an unexpected keyword argument 'auto'
    # TODO: remove when updating to the next release
    # https://github.com/google/flax/pull/5020
    (fetchpatch {
      name = "jax-0.8.0-compat";
      url = "https://github.com/google/flax/commit/5bf9b35ff03130f440a93a812fd8b47ec6a49add.patch";
      hash = "sha256-KYpa1wDQMt77XIDGQEg/VuU/OPPNp2enGSA986TZSLQ=";
    })
  ];

  build-system = [
    setuptools
    setuptools-scm
@@ -95,12 +83,6 @@ buildPythonPackage rec {
    tensorflow
  ];

  pytestFlags = [
    # DeprecationWarning: Triggering of __jax_array__() during abstractification is deprecated.
    # To avoid this error, either explicitly convert your object using jax.numpy.array(), or register your object as a pytree.
    "-Wignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"