Unverified Commit 35501f3e authored by Weijia Wang's avatar Weijia Wang Committed by GitHub
Browse files

Merge pull request #247925 from GaetanLepage/equinox

python3Packages.equinox: init at 0.10.11
parents 6bf8b8b6 64407861
Loading
Loading
Loading
Loading
+49 −0
Original line number Diff line number Diff line
{ lib
, buildPythonPackage
, fetchFromGitHub
, hatchling
, jax
, jaxlib
, jaxtyping
, typing-extensions
, beartype
, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "equinox";
  version = "0.10.11";
  format = "pyproject";

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-JffuPplIROPog29FBsWH9cQHSkrFKuXjaTjjEwIqW/0=";
  };

  nativeBuildInputs = [
    hatchling
  ];

  propagatedBuildInputs = [
    jax
    jaxlib
    jaxtyping
    typing-extensions
  ];

  nativeCheckInputs = [
    beartype
    pytestCheckHook
  ];

  pythonImportsCheck = [ "equinox" ];

  meta = with lib; {
    description = "A JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
    homepage = "https://github.com/patrick-kidger/equinox";
    license = licenses.asl20;
    maintainers = with maintainers; [ GaetanLepage ];
  };
}
+64 −0
Original line number Diff line number Diff line
{ lib
, buildPythonPackage
, fetchFromGitHub
, hatchling
, numpy
, typeguard
, typing-extensions
, cloudpickle
, equinox
, jax
, jaxlib
, torch
, pytestCheckHook
}:

let
  jaxtyping = buildPythonPackage rec {
    pname = "jaxtyping";
    version = "0.2.20";
    format = "pyproject";

    src = fetchFromGitHub {
      owner = "google";
      repo = pname;
      rev = "refs/tags/v${version}";
      hash = "sha256-q/KQGV7I7w5p7VP8C9BDUHfPsuCMf2v304qiH+XCzyU=";
    };

    nativeBuildInputs = [
      hatchling
    ];

    propagatedBuildInputs = [
      numpy
      typeguard
      typing-extensions
    ];

    nativeCheckInputs = [
      cloudpickle
      equinox
      jax
      jaxlib
      pytestCheckHook
      torch
    ];

    doCheck = false;

    # Enable tests via passthru to avoid cyclic dependency with equinox.
    passthru.tests = {
      check = jaxtyping.overridePythonAttrs { doCheck = true; };
    };

    pythonImportsCheck = [ "jaxtyping" ];

    meta = with lib; {
      description = "Type annotations and runtime checking for JAX arrays and PyTrees";
      homepage = "https://github.com/google/jaxtyping";
      license = licenses.mit;
      maintainers = with maintainers; [ GaetanLepage ];
    };
  };
 in jaxtyping
+4 −0
Original line number Diff line number Diff line
@@ -3410,6 +3410,8 @@ self: super: with self; {

  epson-projector = callPackage ../development/python-modules/epson-projector { };

  equinox = callPackage ../development/python-modules/equinox { };

  eradicate = callPackage ../development/python-modules/eradicate { };

  es-client = callPackage ../development/python-modules/es-client { };
@@ -5340,6 +5342,8 @@ self: super: with self; {

  jaxopt = callPackage ../development/python-modules/jaxopt { };

  jaxtyping = callPackage ../development/python-modules/jaxtyping { };

  jaydebeapi = callPackage ../development/python-modules/jaydebeapi { };

  jc = callPackage ../development/python-modules/jc { };