Loading pkgs/development/python-modules/jax/default.nix +4 −5 Original line number Diff line number Diff line Loading @@ -21,7 +21,7 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.3.16"; version = "0.3.23"; format = "setuptools"; disabled = pythonOlder "3.7"; Loading @@ -30,7 +30,7 @@ buildPythonPackage rec { owner = "google"; repo = pname; rev = "jax-v${version}"; hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I="; hash = "sha256-ruXOwpBwpi1G8jgH9nhbWbs14JupwWkjh+Wzrj8HVU4="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are Loading Loading @@ -92,9 +92,8 @@ buildPythonPackage rec { "tests/sparse_test.py" ]; pythonImportsCheck = [ "jax" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading pkgs/development/python-modules/jaxlib/bin.nix +23 −44 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ # https://storage.googleapis.com/jax-releases/libtpu_releases.html. # For future reference, the easiest way to test the GPU backend is to run # NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }" # NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }" # export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 # python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'" # python -c "from jax import random; random.PRNGKey(0)" Loading Loading @@ -35,46 +35,32 @@ let inherit (cudaPackages) cudatoolkit cudnn; in # There are no jaxlib wheels targeting cudnn <8.0.5, and although there are # wheels for cudatoolkit <11.1, we don't support them. assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; let version = "0.3.0"; version = "0.3.22"; pythonVersion = python.pythonVersion; # Find new releases at https://storage.googleapis.com/jax-releases. When # upgrading, you can get these hashes from prefetch.sh. # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is # the correct index. cpuSrcs = { "3.9" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q="; "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4="; }; "3.10" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl"; hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ="; "aarch64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0="; }; }; gpuSrcs = { "3.9-805" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw="; }; "3.9-82" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4="; }; "3.10-805" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl"; hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U="; }; "3.10-82" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl"; hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go="; }; gpuSrc = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs="; }; in buildPythonPackage rec { Loading @@ -82,23 +68,16 @@ buildPythonPackage rec { inherit version; format = "wheel"; # At the time of writing (2022-03-03), there are releases for <=3.10. # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs # python3 version, and 3.10. disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10"); # At the time of writing (2022-10-19), there are releases for <=3.10. # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs # python version. disabled = !(pythonVersion == "3.10"); src = if !cudaSupport then cpuSrcs."${pythonVersion}" else let # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and # 8.2. Try to use 8.2 whenever possible. cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805"; in gpuSrcs."${pythonVersion}-${cudnnVersion}"; src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath; nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc ]; Loading Loading @@ -142,6 +121,6 @@ buildPythonPackage rec { sourceProvenance = with sourceTypes; [ binaryNativeCode ]; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; platforms = [ "x86_64-linux" ]; platforms = [ "aarch64-darwin" "x86_64-linux" ]; }; } pkgs/development/python-modules/jaxlib/default.nix +8 −6 Original line number Diff line number Diff line Loading @@ -53,7 +53,7 @@ let inherit (cudaPackages) cudatoolkit cudnn nccl; pname = "jaxlib"; version = "0.3.15"; version = "0.3.22"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; Loading Loading @@ -96,7 +96,7 @@ let owner = "google"; repo = "jax"; rev = "${pname}-v${version}"; sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s="; hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs="; }; nativeBuildInputs = [ Loading Loading @@ -235,11 +235,11 @@ let fetchAttrs = { sha256 = if cudaSupport then "sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk=" "sha256-Z9GDWGv+1YFyJjudyshZfeRJsKShoA1kIbNR3h3GxPQ=" else if stdenv.isDarwin then "sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34=" "sha256-i3wiJHD4+pgTvDMhnYiQo9pdxxKItgYnc4/4wGt2NXM=" else "sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A="; "sha256-liRxmjwm0OmVMfgoGXx+nGBdW2fzzP/d4zmK6A59HAM="; }; buildAttrs = { Loading Loading @@ -293,7 +293,9 @@ buildPythonPackage { inherit meta pname version; format = "wheel"; src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl"; src = let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}"; in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl"; # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for Loading pkgs/top-level/python-packages.nix +0 −6 Original line number Diff line number Diff line Loading @@ -4780,9 +4780,6 @@ in { jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { cudaSupport = pkgs.config.cudaSupport or false; # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we # pin to `cudaPackages_11_6` instead. cudaPackages = pkgs.cudaPackages_11_6; }; jaxlib-build = callPackage ../development/python-modules/jaxlib rec { Loading @@ -4792,9 +4789,6 @@ in { }; # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. cudaSupport = pkgs.config.cudaSupport or false; # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we # pin to `cudaPackages_11_6` instead. cudaPackages = pkgs.cudaPackages_11_6; IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21 }; Loading Loading
pkgs/development/python-modules/jax/default.nix +4 −5 Original line number Diff line number Diff line Loading @@ -21,7 +21,7 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.3.16"; version = "0.3.23"; format = "setuptools"; disabled = pythonOlder "3.7"; Loading @@ -30,7 +30,7 @@ buildPythonPackage rec { owner = "google"; repo = pname; rev = "jax-v${version}"; hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I="; hash = "sha256-ruXOwpBwpi1G8jgH9nhbWbs14JupwWkjh+Wzrj8HVU4="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are Loading Loading @@ -92,9 +92,8 @@ buildPythonPackage rec { "tests/sparse_test.py" ]; pythonImportsCheck = [ "jax" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading
pkgs/development/python-modules/jaxlib/bin.nix +23 −44 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ # https://storage.googleapis.com/jax-releases/libtpu_releases.html. # For future reference, the easiest way to test the GPU backend is to run # NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }" # NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib-bin.override { cudaSupport = true; }" # export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 # python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'" # python -c "from jax import random; random.PRNGKey(0)" Loading Loading @@ -35,46 +35,32 @@ let inherit (cudaPackages) cudatoolkit cudnn; in # There are no jaxlib wheels targeting cudnn <8.0.5, and although there are # wheels for cudatoolkit <11.1, we don't support them. assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; let version = "0.3.0"; version = "0.3.22"; pythonVersion = python.pythonVersion; # Find new releases at https://storage.googleapis.com/jax-releases. When # upgrading, you can get these hashes from prefetch.sh. # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is # the correct index. cpuSrcs = { "3.9" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; hash = "sha256-AfBVqoqChEXlEC5PgbtQ5rQzcbwo558fjqCjSPEmN5Q="; "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4="; }; "3.10" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl"; hash = "sha256-9uBkFOO8LlRpO6AP+S8XK9/d2yRdyHxQGlbAjShqHRQ="; "aarch64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0="; }; }; gpuSrcs = { "3.9-805" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; hash = "sha256-CArIhzM5FrQi3TkdqpUqCeDQYyDMVXlzKFgjNXjLJXw="; }; "3.9-82" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; hash = "sha256-Q0plVnA9pUNQ+gCHSXiLNs4i24xCg8gBGfgfYe3bot4="; }; "3.10-805" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl"; hash = "sha256-JopevCEAs0hgDngIId6NqbLam5YfcS8Lr9cEffBKp1U="; }; "3.10-82" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl"; hash = "sha256-2f5TwbdP7EfQNRM3ZcJXCAkS2VXBwNYH6gwT9pdu3Go="; }; gpuSrc = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs="; }; in buildPythonPackage rec { Loading @@ -82,23 +68,16 @@ buildPythonPackage rec { inherit version; format = "wheel"; # At the time of writing (2022-03-03), there are releases for <=3.10. # Supporting all of them is a pain, so we focus on 3.9, the current nixpkgs # python3 version, and 3.10. disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10"); # At the time of writing (2022-10-19), there are releases for <=3.10. # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs # python version. disabled = !(pythonVersion == "3.10"); src = if !cudaSupport then cpuSrcs."${pythonVersion}" else let # jaxlib wheels are currently provided for cudnn versions at least 8.0.5 and # 8.2. Try to use 8.2 whenever possible. cudnnVersion = if (lib.versionAtLeast cudnn.version "8.2") then "82" else "805"; in gpuSrcs."${pythonVersion}-${cudnnVersion}"; src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath; nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc ]; Loading Loading @@ -142,6 +121,6 @@ buildPythonPackage rec { sourceProvenance = with sourceTypes; [ binaryNativeCode ]; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; platforms = [ "x86_64-linux" ]; platforms = [ "aarch64-darwin" "x86_64-linux" ]; }; }
pkgs/development/python-modules/jaxlib/default.nix +8 −6 Original line number Diff line number Diff line Loading @@ -53,7 +53,7 @@ let inherit (cudaPackages) cudatoolkit cudnn nccl; pname = "jaxlib"; version = "0.3.15"; version = "0.3.22"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; Loading Loading @@ -96,7 +96,7 @@ let owner = "google"; repo = "jax"; rev = "${pname}-v${version}"; sha256 = "sha256-pIl7zzl82w5HHnJadH2vtCT4mYFd5YmM9iHC2GoJD6s="; hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs="; }; nativeBuildInputs = [ Loading Loading @@ -235,11 +235,11 @@ let fetchAttrs = { sha256 = if cudaSupport then "sha256-tdO4YjO985zbittb16RFWgxgUBrHYQfv5gRsA4IAkTk=" "sha256-Z9GDWGv+1YFyJjudyshZfeRJsKShoA1kIbNR3h3GxPQ=" else if stdenv.isDarwin then "sha256-+XYxfXBCASueqDGg0Zqcmpf7zmemYM6xCE+x0rl3j34=" "sha256-i3wiJHD4+pgTvDMhnYiQo9pdxxKItgYnc4/4wGt2NXM=" else "sha256-La1wC8X5aGK5mXvYy/kO8n4J+zaRZEc/DAX5zaH1D5A="; "sha256-liRxmjwm0OmVMfgoGXx+nGBdW2fzzP/d4zmK6A59HAM="; }; buildAttrs = { Loading Loading @@ -293,7 +293,9 @@ buildPythonPackage { inherit meta pname version; format = "wheel"; src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-${platformTag}.whl"; src = let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}"; in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl"; # Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for Loading
pkgs/top-level/python-packages.nix +0 −6 Original line number Diff line number Diff line Loading @@ -4780,9 +4780,6 @@ in { jaxlib-bin = callPackage ../development/python-modules/jaxlib/bin.nix { cudaSupport = pkgs.config.cudaSupport or false; # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we # pin to `cudaPackages_11_6` instead. cudaPackages = pkgs.cudaPackages_11_6; }; jaxlib-build = callPackage ../development/python-modules/jaxlib rec { Loading @@ -4792,9 +4789,6 @@ in { }; # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. cudaSupport = pkgs.config.cudaSupport or false; # At the time of writing (2022-04-18), `cudaPackages.nccl` is broken, so we # pin to `cudaPackages_11_6` instead. cudaPackages = pkgs.cudaPackages_11_6; IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21 }; Loading