Unverified Commit 9f38e1c7 authored by Jonas Chevalier's avatar Jonas Chevalier Committed by GitHub
Browse files

Merge pull request #323154 from zimbatm/jax-fixes

Jax fixes
parents 84ef8494 039cf118
Loading
Loading
Loading
Loading
+13 −8
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@
  fetchFromGitHub,
  jaxlib,
  jaxlib-bin,
  jaxlib-build,
  hypothesis,
  lapack,
  matplotlib,
@@ -23,10 +24,6 @@

let
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
  # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
  # fine. jaxlib is only used in the checkPhase, so switching backends does not
  # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
  jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
in
buildPythonPackage rec {
  pname = "jax";
@@ -61,7 +58,7 @@ buildPythonPackage rec {

  nativeCheckInputs = [
    hypothesis
    jaxlib'
    jaxlib
    matplotlib
    pytestCheckHook
    pytest-xdist
@@ -130,7 +127,11 @@ buildPythonPackage rec {
      "testQdwhWithOnRankDeficientInput5"
    ];

  disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
  disabledTestPaths = [
    # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
    "tests/linalg_test.py"
  ]
  ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
    # RuntimeWarning: invalid value encountered in cast
    "tests/lax_test.py"
  ];
@@ -147,7 +148,7 @@ buildPythonPackage rec {
  #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
  passthru.tests = {
    test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
      jaxlib = jaxlib.override { cudaSupport = true; };
      jaxlib = jaxlib-build.override { cudaSupport = true; };
    };
    test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
      jaxlib = jaxlib-bin.override { cudaSupport = true; };
@@ -158,7 +159,11 @@ buildPythonPackage rec {
  passthru.skipBulkUpdate = true;

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    longDescription = ''
      This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
      e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
    '';
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
+1 −1
Original line number Diff line number Diff line
@@ -225,7 +225,7 @@ buildPythonPackage {
  inherit (jaxlib-build) pythonImportsCheck;

  meta = with lib; {
    description = "XLA library for JAX";
    description = "Prebuilt jaxlib backend from PyPi";
    homepage = "https://github.com/google/jax";
    sourceProvenance = with sourceTypes; [ binaryNativeCode ];
    license = licenses.asl20;
+5 −4
Original line number Diff line number Diff line
@@ -67,16 +67,17 @@ let
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research";
    description = "Source-built JAX backend. JAX is Autograd and XLA, brought together for high-performance machine learning research";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];
    platforms = platforms.unix;

    # Make this platforms.unix once Darwin is supported.
    # The top-level jaxlib now falls back to jaxlib-bin on unsupported platforms.
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
    # however even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    # NOTE: We always build with NCCL; if it is unsupported, then our build is broken.
    broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
    platforms = platforms.linux;
  };

  # Bazel wants a merged cudnn at configuration time
+4 −3
Original line number Diff line number Diff line
@@ -6121,13 +6121,14 @@ self: super: with self; {
    IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
  };
  jaxlib = self.jaxlib-build;
  # Use the -bin on macOS since the source build doesn't support it (see #323154)
  jaxlib = if jaxlib-build.meta.unsupported then jaxlib-bin else jaxlib-build;
  jaxlibWithCuda = self.jaxlib-build.override {
  jaxlibWithCuda = self.jaxlib.override {
    cudaSupport = true;
  };
  jaxlibWithoutCuda = self.jaxlib-build.override {
  jaxlibWithoutCuda = self.jaxlib.override {
    cudaSupport = false;
  };