Unverified Commit 386cd594 authored by Nick Cao's avatar Nick Cao Committed by GitHub
Browse files

Merge pull request #252532 from GaetanLepage/jax

python310Packages.{jax,jaxlib,jaxlib-bin}: 0.4.14 -> 0.4.16
parents 79a23c6c 309c92d4
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -27,7 +27,7 @@ let
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.4.14";
  version = "0.4.16";
  format = "pyproject";

  disabled = pythonOlder "3.9";
@@ -37,13 +37,17 @@ buildPythonPackage rec {
    repo = pname;
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    rev = "refs/tags/${pname}-v${version}";
    hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
    hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg=";
  };

  nativeBuildInputs = [
    setuptools
  ];

  # The version is automatically set to ".dev" if this variable is not set.
  # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
  JAX_RELEASE = "1";

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU.
+5 −5
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ in
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;

let
  version = "0.4.14";
  version = "0.4.16";

  inherit (python) pythonVersion;

@@ -60,15 +60,15 @@ let
    {
      "x86_64-linux" = getSrcFromPypi {
        platform = "manylinux2014_x86_64";
        hash = "sha256-nyylSZfqHeftlvVgJZFCN1ldjluZVJIYu4ZSsVxvXf8=";
        hash = "sha256-4XyaDnKEMhAbfPEvN3RCDEjXTWbOL6tWrTlyYeiboVs=";
      };
      "aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        hash = "sha256-La3wYbGCjWTl7krBD6BaBRqyBD8R530Lckbz0AWv0FM=";
        hash = "sha256-IG2pCui/Yj+LDMbQwBVlu7yl2llqnaxMzz/MtBvBr6U=";
      };
      "x86_64-darwin" = getSrcFromPypi {
        platform = "macosx_10_14_x86_64";
        hash = "sha256-hDg5+qisgtgOrdvbjxsUgI73cW6Aah8NLjhPe4kMAsM=";
        hash = "sha256-x5DqsmHqEb7Dl7dnxT5N0l30GKt5OPZpq3HGX9MFKmo=";
      };
    };

@@ -78,7 +78,7 @@ let
  # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
  gpuSrc = fetchurl {
    url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
    hash = "sha256-CcQ5kjp4XfUX4/RwFY3T5G3kVKAeyoCTXu1Lo4O16Qo=";
    hash = "sha256-eLOprP2kv6roodwRKZXVZFQCD1wC26TSTEDJBjMu/Uo=";
  };

in
+9 −5
Original line number Diff line number Diff line
@@ -54,7 +54,7 @@ let
  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

  pname = "jaxlib";
  version = "0.4.14";
  version = "0.4.16";

  meta = with lib; {
    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -151,7 +151,7 @@ let
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
      hash = "sha256-q+8CXGxK8JX0bUMK4KJB3qV/EaLHg68D1B5UrtRz0Eg=";
    };

    nativeBuildInputs = [
@@ -203,6 +203,10 @@ let
    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    preConfigure = ''
      # dummy ldconfig
      mkdir dummy-ldconfig
@@ -260,10 +264,10 @@ let
      ];

      sha256 = (if cudaSupport then {
        x86_64-linux = "sha256-L+d4umcN8eZQJS7NtbyMhFbbGUVd0a73GOYbZx3bW9Q=";
        x86_64-linux = "sha256-6HkrEWAPjGPj4zRxahl0FLiV7WZO/6zsdCX8STfV5EE=";
      } else {
        x86_64-linux = "sha256-V1giQbu70RYjbNsqudibiCgvtFNRIJ8XG75QtIzjM4g=";
        aarch64-linux = "sha256-DRU4aT7kQffhsOllgHtSlIsYOeLF4Sy5o5RR1CaTle0=";
        x86_64-linux = "sha256-MDnuJwJ/xKnC72Qub0ETYj5uQB2r8/AgGm10oqmzzcc=";
        aarch64-linux = "sha256-aVUm612VNEsjZLDrtiOPTqSk1t+AhmOx+pOG3bZdOAw=";
      }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
    };