Commit aa8014b5 authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files
parent 1783f05a
Loading
Loading
Loading
Loading
+19 −30
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@
  fetchFromGitHub,

  # build-system
  jaxlib,
  setuptools,
  setuptools-scm,

  # dependencies
@@ -19,42 +19,43 @@
  tensorstore,
  typing-extensions,

  # checks
  # optional-dependencies
  matplotlib,

  # dependencies
  cloudpickle,
  keras,
  einops,
  flaxlib,
  keras,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,
  treescope,

  # optional-dependencies
  matplotlib,

  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.10.1";
  version = "0.10.3";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-+URbQGnmqmSNgucEyWvI5DMnzXjpmJzLA+Pho2lX+S4=";
    hash = "sha256-PRKdtltiBVX9p6Sjw4sCDghqxYRxq4L9TLle1vy5dkk=";
  };

  build-system = [
    jaxlib
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
@@ -63,6 +64,7 @@ buildPythonPackage rec {
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

@@ -74,24 +76,18 @@ buildPythonPackage rec {

  nativeCheckInputs = [
    cloudpickle
    einops
    flaxlib
    keras
    einops
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
    treescope
  ];

  pytestFlagsArray = [
    "-W ignore::FutureWarning"
    "-W ignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
@@ -99,19 +95,12 @@ buildPythonPackage rec {
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"

    # See https://github.com/google/flax/issues/3232.
    "tests/jax_utils_test.py"
  ];

  disabledTests =
    [
      # Failing with AssertionError since the jax update to 0.5.0
      "test_basic_demo_single"
      "test_batch_norm_multi_init"
      "test_multimetric"
      "test_split_merge"
    ]
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
  disabledTests = lib.optionals stdenv.hostPlatform.isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    "test_ref_changed"
    "test_structure_changed"
+37 −25
Original line number Diff line number Diff line
{
  lib,
  stdenv,
  buildPythonPackage,
  flax,
  tomlq,
  rustPlatform,
  pytestCheckHook,
  python,

  # build-system
  meson-python,
  nanobind,
  ninja,

  # nativeBuildInputs
  cmake,
  pkg-config,
}:

let
  nanobind-wrapper = stdenv.mkDerivation {
    pname = "nanobind-wrapper";
    inherit (nanobind) version;

    src = ./nanobind-wrapper;

    nativeBuildInputs = [
      cmake
    ];

    buildFlags = [ "nanobind-static" ];

    env.CMAKE_PREFIX_PATH = "${nanobind}/${python.sitePackages}/nanobind";
  };
in
buildPythonPackage rec {
  pname = "flaxlib";
  version = "0.0.1-a1";
@@ -14,7 +39,7 @@ buildPythonPackage rec {

  inherit (flax) src;

  sourceRoot = "${src.name}/flaxlib";
  sourceRoot = "${src.name}/flaxlib_src";

  postPatch = ''
    expected_version="$version"
@@ -28,34 +53,21 @@ buildPythonPackage rec {
    fi
  '';

  cargoDeps = rustPlatform.fetchCargoVendor {
    inherit
      pname
      version
      src
      sourceRoot
      ;
    hash = "sha256-CN/ZbDxdCQPEuLfxPh/m+JtlFDkerO8aWgAaUwhixjQ=";
  };
  dontUseCmakeConfigure = true;

  build-system = [
    meson-python
    nanobind
    ninja
  ];
  nativeBuildInputs = [
    rustPlatform.maturinBuildHook
    rustPlatform.cargoSetupHook
    cmake
    pkg-config
  ];
  buildInputs = [ nanobind-wrapper ];

  pythonImportsCheck = [ "flaxlib" ];

  nativeCheckInputs = [
    pytestCheckHook
  ];

  env = {
    # https://github.com/google/flax/issues/4491
    # Upstream should update Cargo.lock
    # Enabling `PYO3_USE_ABI3_FORWARD_COMPATIBILITY` allows us to temporarily avoid the issue
    PYO3_USE_ABI3_FORWARD_COMPATIBILITY = true;
  };

  # This package does not have tests (yet ?)
  doCheck = false;

+9 −0
Original line number Diff line number Diff line
cmake_minimum_required(VERSION 3.31)
project(nanobind-wrapper)

find_package(Python COMPONENTS Interpreter Development REQUIRED)
find_package(nanobind CONFIG REQUIRED)
nanobind_build_library(nanobind-static)
set_property(TARGET nanobind-static PROPERTY EXPORT_NAME nanobind)
install(TARGETS nanobind-static EXPORT nanobind-static)
install(EXPORT nanobind-static FILE nanobindConfig.cmake DESTINATION lib/nanobind/cmake)