Unverified Commit 92c9c263 authored by Sandro Jäckel's avatar Sandro Jäckel Committed by GitHub
Browse files

Merge pull request #215047 from bcdarwin/blackjax

parents 69f36ea4 36acad3d
Loading
Loading
Loading
Loading
+8 −4
Original line number Diff line number Diff line
@@ -4,9 +4,11 @@
, fetchFromGitHub
, pytestCheckHook
, arviz
, blackjax
, formulae
, graphviz
, numpy
, numpyro
, pandas
, pymc
, scipy
@@ -35,14 +37,16 @@ buildPythonPackage rec {

  preCheck = ''export HOME=$(mktemp -d)'';

  nativeCheckInputs = [ graphviz pytestCheckHook ];
  nativeCheckInputs = [
    blackjax
    graphviz
    numpyro
    pytestCheckHook
  ];
  disabledTests = [
    # attempt to fetch data:
    "test_data_is_copied"
    "test_predict_offset"
    # require blackjax (not in Nixpkgs), numpyro, and jax:
    "test_logistic_regression_categoric_alternative_samplers"
    "test_regression_alternative_samplers"
  ];

  pythonImportsCheck = [ "bambi" ];
+62 −0
Original line number Diff line number Diff line
{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, fetchpatch
, pytestCheckHook
, fastprogress
, jax
, jaxlib
, jaxopt
, optax
, typing-extensions
}:

buildPythonPackage rec {
  pname = "blackjax";
  version = "0.9.6";
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "blackjax-devs";
    repo = pname;
    rev = "refs/tags/${version}";
    hash = "sha256-EieDu9SJxi2cp1bHlxX4vvFZeDGMGIm24GoR8nSyjvE=";
  };

  patches = [
    # remove in next release
    (fetchpatch {
      name = "fix-lbfgs-args";
      url = "https://github.com/blackjax-devs/blackjax/commit/1aaa6f64bbcb0557b658604b2daba826e260cbc6.patch";
      hash = "sha256-XyjorXPH5Ap35Tv1/lTeTWamjplJF29SsvOq59ypftE=";
    })
  ];

  propagatedBuildInputs = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  nativeCheckInputs = [ pytestCheckHook ];
  disabledTestPaths = [ "tests/test_benchmarks.py" ];
  disabledTests = [
    # too slow
    "test_adaptive_tempered_smc"
  ];

  pythonImportsCheck = [
    "blackjax"
  ];

  meta = with lib; {
    homepage = "https://blackjax-devs.github.io/blackjax";
    description = "Sampling library designed for ease of use, speed and modularity";
    license = licenses.asl20;
    maintainers = with maintainers; [ bcdarwin ];
  };
}
+59 −0
Original line number Diff line number Diff line
{ lib
, buildPythonPackage
, pythonOlder
, fetchFromGitHub
, pytestCheckHook
, absl-py
, cvxpy
, jax
, jaxlib
, matplotlib
, numpy
, optax
, scipy
, scikitlearn
}:

buildPythonPackage rec {
  pname = "jaxopt";
  version = "0.5.5";
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-WOsr/Dvguu9/qX6+LMlAKM3EANtYPtDu8Uo2157+bs0=";
  };

  propagatedBuildInputs = [
    absl-py
    jax
    jaxlib
    matplotlib
    numpy
    scipy
  ];

  nativeCheckInputs = [
    pytestCheckHook
    cvxpy
    optax
    scikitlearn
  ];

  pythonImportsCheck = [
    "jaxopt"
    "jaxopt.implicit_diff"
    "jaxopt.linear_solve"
    "jaxopt.loss"
    "jaxopt.tree_util"
  ];

  meta = with lib; {
    homepage = "https://jaxopt.github.io";
    description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
    license = licenses.asl20;
    maintainers = with maintainers; [ bcdarwin ];
  };
}
+4 −0
Original line number Diff line number Diff line
@@ -1300,6 +1300,8 @@ self: super: with self; {

  black = callPackage ../development/python-modules/black { };

  blackjax = callPackage ../development/python-modules/blackjax { };

  black-macchiato = callPackage ../development/python-modules/black-macchiato { };

  bleach = callPackage ../development/python-modules/bleach { };
@@ -4868,6 +4870,8 @@ self: super: with self; {
    cudaSupport = false;
  };

  jaxopt = callPackage ../development/python-modules/jaxopt { };

  JayDeBeApi = callPackage ../development/python-modules/JayDeBeApi { };

  jc = callPackage ../development/python-modules/jc { };