Loading pkgs/build-support/build-bazel-package/default.nix +29 −10 Original line number Diff line number Diff line Loading @@ -10,9 +10,12 @@ args@{ , bazelFlags ? [] , bazelBuildFlags ? [] , bazelTestFlags ? [] , bazelRunFlags ? [] , runTargetFlags ? [] , bazelFetchFlags ? [] , bazelTargets , bazelTargets ? [] , bazelTestTargets ? [] , bazelRunTarget ? null , buildAttrs , fetchAttrs Loading Loading @@ -46,17 +49,23 @@ args@{ let fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // { name = name; bazelFlags = bazelFlags; bazelBuildFlags = bazelBuildFlags; bazelTestFlags = bazelTestFlags; bazelFetchFlags = bazelFetchFlags; bazelTestTargets = bazelTestTargets; dontAddBazelOpts = dontAddBazelOpts; inherit name bazelFlags bazelBuildFlags bazelTestFlags bazelRunFlags runTargetFlags bazelFetchFlags bazelTargets bazelTestTargets bazelRunTarget dontAddBazelOpts ; }; fBuildAttrs = fArgs // buildAttrs; fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ]; bazelCmd = { cmd, additionalFlags, targets }: bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }: lib.optionalString (targets != [ ]) '' # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables] BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \ Loading @@ -73,7 +82,8 @@ let "''${host_linkopts[@]}" \ $bazelFlags \ ${lib.strings.concatStringsSep " " additionalFlags} \ ${lib.strings.concatStringsSep " " targets} ${lib.strings.concatStringsSep " " targets} \ ${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags} ''; # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so: # chmod: cannot operate on dangling symlink '$symlink' Loading Loading @@ -262,6 +272,15 @@ stdenv.mkDerivation (fBuildAttrs // { targets = fBuildAttrs.bazelTargets; } } ${ bazelCmd { cmd = "run"; additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ]; # Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list. targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ]; targetRunFlags = fBuildAttrs.runTargetFlags; } } runHook postBuild ''; }) Loading pkgs/development/python-modules/jax/default.nix +6 −16 Original line number Diff line number Diff line Loading @@ -8,6 +8,7 @@ , jaxlib-bin , lapack , matplotlib , ml-dtypes , numpy , opt-einsum , pytestCheckHook Loading @@ -27,7 +28,7 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.5"; version = "0.4.12"; format = "setuptools"; disabled = pythonOlder "3.7"; Loading @@ -37,7 +38,7 @@ buildPythonPackage rec { repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are Loading @@ -46,6 +47,7 @@ buildPythonPackage rec { propagatedBuildInputs = [ absl-py etils ml-dtypes numpy opt-einsum scipy Loading Loading @@ -96,24 +98,12 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; # See https://github.com/google/jax/issues/11722. This is a temporary fix in # order to unblock etils, and upgrading jax/jaxlib to the latest version. See # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. disabledTestPaths = [ "tests/api_test.py" "tests/core_test.py" "tests/lax_numpy_indexing_test.py" "tests/lax_numpy_test.py" "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; pythonImportsCheck = [ "jax" ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading pkgs/development/python-modules/jaxlib/bin.nix +51 −33 Original line number Diff line number Diff line Loading @@ -18,11 +18,12 @@ , autoPatchelfHook , buildPythonPackage , config , cudnn ? cudaPackages.cudnn , fetchPypi , fetchurl , flatbuffers , isPy39 , jaxlib , lib , ml-dtypes , python , scipy , stdenv Loading @@ -35,46 +36,57 @@ let inherit (cudaPackages) cudatoolkit cudnn; in assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; let version = "0.4.4"; version = "0.4.12"; pythonVersion = python.pythonVersion; inherit (python) pythonVersion; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is # the correct index. cpuSrcs = { "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ="; # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the # official instructions recommend installing CPU-only versions via PyPI. cpuSrcs = let getSrcFromPypi = { platform, hash }: fetchPypi { inherit version platform hash; pname = "jaxlib"; format = "wheel"; # See the `disabled` attr comment below. dist = "cp310"; python = "cp310"; abi = "cp310"; }; "aarch64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U="; in { "x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc="; }; "x86_64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl"; hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok="; "aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo="; }; "x86_64-darwin" = getSrcFromPypi { platform = "macosx_10_14_x86_64"; hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c="; }; }; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. gpuSrc = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk="; url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo="; }; in buildPythonPackage rec { buildPythonPackage { pname = "jaxlib"; inherit version; format = "wheel"; # At the time of writing (2022-10-19), there are releases for <=3.10. # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs # python version. disabled = !(pythonVersion == "3.10"); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. Loading @@ -87,9 +99,10 @@ buildPythonPackage rec { # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc ]; buildInputs = [ stdenv.cc.cc.lib ]; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or Loading @@ -113,7 +126,12 @@ buildPythonPackage rec { done ''; propagatedBuildInputs = [ absl-py flatbuffers scipy ]; propagatedBuildInputs = [ absl-py flatbuffers ml-dtypes scipy ]; # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for Loading @@ -123,7 +141,7 @@ buildPythonPackage rec { ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas ''; pythonImportsCheck = [ "jaxlib" ]; inherit (jaxlib) pythonImportsCheck; meta = with lib; { description = "XLA library for JAX"; Loading pkgs/development/python-modules/jaxlib/default.nix +32 −36 Original line number Diff line number Diff line Loading @@ -4,7 +4,7 @@ # Build-time dependencies: , addOpenGLRunpath , bazel_5 , bazel_6 , binutils , buildBazelPackage , buildPythonPackage Loading @@ -26,6 +26,7 @@ # Python dependencies: , absl-py , flatbuffers , ml-dtypes , numpy , scipy , six Loading @@ -35,7 +36,6 @@ , giflib , grpc , libjpeg_turbo , protobuf , python , snappy , zlib Loading @@ -53,7 +53,7 @@ let inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; pname = "jaxlib"; version = "0.4.4"; version = "0.4.12"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; Loading Loading @@ -138,14 +138,15 @@ let bazel-build = buildBazelPackage rec { name = "bazel-build-${pname}-${version}"; bazel = bazel_5; # See https://github.com/google/jax/blob/main/.bazelversion for the latest. bazel = bazel_6; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo="; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; }; nativeBuildInputs = [ Loading @@ -169,7 +170,7 @@ let numpy openssl pkgs.flatbuffers protobuf pkgs.protobuf pybind11 scipy six Loading @@ -188,7 +189,8 @@ let rm -f .bazelversion ''; bazelTargets = [ "//build:build_wheel" ]; bazelRunTarget = "//build:build_wheel"; runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ]; removeRulesCC = false; Loading @@ -207,7 +209,11 @@ let build --action_env=PYENV_ROOT build --python_path="${python}/bin/python" build --distinct_host_configuration=false build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include" build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) '' build --config=avx_posix '' + lib.optionalString mklSupport '' build --config=mkl_open_source_only '' + lib.optionalString cudaSupport '' build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" Loading @@ -234,7 +240,7 @@ let fetchAttrs = { TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs; # we have to force @mkl_dnn_v1 since it's not needed on darwin bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ]; bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ]; bazelFlags = bazelFlags ++ [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ Loading @@ -249,9 +255,9 @@ let sha256 = if cudaSupport then "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk=" "sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4=" else "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI="; "sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc="; }; buildAttrs = { Loading @@ -261,25 +267,13 @@ let "nsync" # fails to build on darwin ]); bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ "--config=cuda" ] ++ lib.optionals mklSupport [ "--config=mkl_open_source_only" ]; # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet. # 1) Fix pybind11 include paths. # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # loading multiple extensions in the same python program due to duplicate protobuf DBs. # 3) Patch python path in the compiler driver. preBuild = '' for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do sed -i 's@include/pybind11@pybind11@g' $src done '' + lib.optionalString cudaSupport '' # 2) Patch python path in the compiler driver. preBuild = lib.optionalString cudaSupport '' export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib" patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '' + lib.optionalString stdenv.isDarwin '' # Framework search paths aren't added by bintools hook # https://github.com/NixOS/nixpkgs/pull/41914 Loading @@ -289,16 +283,12 @@ let substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \ --replace "/usr/bin/libtool" "${cctools}/bin/libtool" '' + (if stdenv.cc.isGNU then '' sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD '' else if stdenv.cc.isClang then '' sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD '' else throw "Unsupported stdenv.cc: ${stdenv.cc}"); installPhase = '' ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch} ''; }; inherit meta; Loading Loading @@ -345,13 +335,19 @@ buildPythonPackage { grpc jsoncpp libjpeg_turbo ml-dtypes numpy scipy six snappy ]; pythonImportsCheck = [ "jaxlib" ]; pythonImportsCheck = [ "jaxlib" # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade. "jaxlib.cpu_feature_guard" "jaxlib.xla_client" ]; # Without it there are complaints about libcudart.so.11.0 not being found # because RPATH path entries added above are stripped. Loading pkgs/development/python-modules/jaxlib/prefetch.sh +15 −7 Original line number Diff line number Diff line version="$1" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)" #!/usr/bin/env bash prefetch () { expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url" url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) echo "$url" sha256=$(nix-prefetch-url "$url") nix hash to-sri --type sha256 "$sha256" echo } prefetch "x86_64-linux" "false" prefetch "aarch64-darwin" "false" prefetch "x86_64-darwin" "false" prefetch "x86_64-linux" "true" Loading
pkgs/build-support/build-bazel-package/default.nix +29 −10 Original line number Diff line number Diff line Loading @@ -10,9 +10,12 @@ args@{ , bazelFlags ? [] , bazelBuildFlags ? [] , bazelTestFlags ? [] , bazelRunFlags ? [] , runTargetFlags ? [] , bazelFetchFlags ? [] , bazelTargets , bazelTargets ? [] , bazelTestTargets ? [] , bazelRunTarget ? null , buildAttrs , fetchAttrs Loading Loading @@ -46,17 +49,23 @@ args@{ let fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // { name = name; bazelFlags = bazelFlags; bazelBuildFlags = bazelBuildFlags; bazelTestFlags = bazelTestFlags; bazelFetchFlags = bazelFetchFlags; bazelTestTargets = bazelTestTargets; dontAddBazelOpts = dontAddBazelOpts; inherit name bazelFlags bazelBuildFlags bazelTestFlags bazelRunFlags runTargetFlags bazelFetchFlags bazelTargets bazelTestTargets bazelRunTarget dontAddBazelOpts ; }; fBuildAttrs = fArgs // buildAttrs; fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ]; bazelCmd = { cmd, additionalFlags, targets }: bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }: lib.optionalString (targets != [ ]) '' # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables] BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \ Loading @@ -73,7 +82,8 @@ let "''${host_linkopts[@]}" \ $bazelFlags \ ${lib.strings.concatStringsSep " " additionalFlags} \ ${lib.strings.concatStringsSep " " targets} ${lib.strings.concatStringsSep " " targets} \ ${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags} ''; # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so: # chmod: cannot operate on dangling symlink '$symlink' Loading Loading @@ -262,6 +272,15 @@ stdenv.mkDerivation (fBuildAttrs // { targets = fBuildAttrs.bazelTargets; } } ${ bazelCmd { cmd = "run"; additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ]; # Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list. targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ]; targetRunFlags = fBuildAttrs.runTargetFlags; } } runHook postBuild ''; }) Loading
pkgs/development/python-modules/jax/default.nix +6 −16 Original line number Diff line number Diff line Loading @@ -8,6 +8,7 @@ , jaxlib-bin , lapack , matplotlib , ml-dtypes , numpy , opt-einsum , pytestCheckHook Loading @@ -27,7 +28,7 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.5"; version = "0.4.12"; format = "setuptools"; disabled = pythonOlder "3.7"; Loading @@ -37,7 +38,7 @@ buildPythonPackage rec { repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are Loading @@ -46,6 +47,7 @@ buildPythonPackage rec { propagatedBuildInputs = [ absl-py etils ml-dtypes numpy opt-einsum scipy Loading Loading @@ -96,24 +98,12 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; # See https://github.com/google/jax/issues/11722. This is a temporary fix in # order to unblock etils, and upgrading jax/jaxlib to the latest version. See # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. disabledTestPaths = [ "tests/api_test.py" "tests/core_test.py" "tests/lax_numpy_indexing_test.py" "tests/lax_numpy_test.py" "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; pythonImportsCheck = [ "jax" ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading
pkgs/development/python-modules/jaxlib/bin.nix +51 −33 Original line number Diff line number Diff line Loading @@ -18,11 +18,12 @@ , autoPatchelfHook , buildPythonPackage , config , cudnn ? cudaPackages.cudnn , fetchPypi , fetchurl , flatbuffers , isPy39 , jaxlib , lib , ml-dtypes , python , scipy , stdenv Loading @@ -35,46 +36,57 @@ let inherit (cudaPackages) cudatoolkit cudnn; in assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; let version = "0.4.4"; version = "0.4.12"; pythonVersion = python.pythonVersion; inherit (python) pythonVersion; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is # the correct index. cpuSrcs = { "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ="; # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the # official instructions recommend installing CPU-only versions via PyPI. cpuSrcs = let getSrcFromPypi = { platform, hash }: fetchPypi { inherit version platform hash; pname = "jaxlib"; format = "wheel"; # See the `disabled` attr comment below. dist = "cp310"; python = "cp310"; abi = "cp310"; }; "aarch64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U="; in { "x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc="; }; "x86_64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl"; hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok="; "aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo="; }; "x86_64-darwin" = getSrcFromPypi { platform = "macosx_10_14_x86_64"; hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c="; }; }; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. gpuSrc = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk="; url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo="; }; in buildPythonPackage rec { buildPythonPackage { pname = "jaxlib"; inherit version; format = "wheel"; # At the time of writing (2022-10-19), there are releases for <=3.10. # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs # python version. disabled = !(pythonVersion == "3.10"); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. Loading @@ -87,9 +99,10 @@ buildPythonPackage rec { # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc ]; buildInputs = [ stdenv.cc.cc.lib ]; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or Loading @@ -113,7 +126,12 @@ buildPythonPackage rec { done ''; propagatedBuildInputs = [ absl-py flatbuffers scipy ]; propagatedBuildInputs = [ absl-py flatbuffers ml-dtypes scipy ]; # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for Loading @@ -123,7 +141,7 @@ buildPythonPackage rec { ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas ''; pythonImportsCheck = [ "jaxlib" ]; inherit (jaxlib) pythonImportsCheck; meta = with lib; { description = "XLA library for JAX"; Loading
pkgs/development/python-modules/jaxlib/default.nix +32 −36 Original line number Diff line number Diff line Loading @@ -4,7 +4,7 @@ # Build-time dependencies: , addOpenGLRunpath , bazel_5 , bazel_6 , binutils , buildBazelPackage , buildPythonPackage Loading @@ -26,6 +26,7 @@ # Python dependencies: , absl-py , flatbuffers , ml-dtypes , numpy , scipy , six Loading @@ -35,7 +36,6 @@ , giflib , grpc , libjpeg_turbo , protobuf , python , snappy , zlib Loading @@ -53,7 +53,7 @@ let inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; pname = "jaxlib"; version = "0.4.4"; version = "0.4.12"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; Loading Loading @@ -138,14 +138,15 @@ let bazel-build = buildBazelPackage rec { name = "bazel-build-${pname}-${version}"; bazel = bazel_5; # See https://github.com/google/jax/blob/main/.bazelversion for the latest. bazel = bazel_6; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo="; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; }; nativeBuildInputs = [ Loading @@ -169,7 +170,7 @@ let numpy openssl pkgs.flatbuffers protobuf pkgs.protobuf pybind11 scipy six Loading @@ -188,7 +189,8 @@ let rm -f .bazelversion ''; bazelTargets = [ "//build:build_wheel" ]; bazelRunTarget = "//build:build_wheel"; runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ]; removeRulesCC = false; Loading @@ -207,7 +209,11 @@ let build --action_env=PYENV_ROOT build --python_path="${python}/bin/python" build --distinct_host_configuration=false build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include" build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) '' build --config=avx_posix '' + lib.optionalString mklSupport '' build --config=mkl_open_source_only '' + lib.optionalString cudaSupport '' build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" Loading @@ -234,7 +240,7 @@ let fetchAttrs = { TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs; # we have to force @mkl_dnn_v1 since it's not needed on darwin bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ]; bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ]; bazelFlags = bazelFlags ++ [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ Loading @@ -249,9 +255,9 @@ let sha256 = if cudaSupport then "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk=" "sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4=" else "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI="; "sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc="; }; buildAttrs = { Loading @@ -261,25 +267,13 @@ let "nsync" # fails to build on darwin ]); bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ "--config=cuda" ] ++ lib.optionals mklSupport [ "--config=mkl_open_source_only" ]; # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet. # 1) Fix pybind11 include paths. # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # loading multiple extensions in the same python program due to duplicate protobuf DBs. # 3) Patch python path in the compiler driver. preBuild = '' for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do sed -i 's@include/pybind11@pybind11@g' $src done '' + lib.optionalString cudaSupport '' # 2) Patch python path in the compiler driver. preBuild = lib.optionalString cudaSupport '' export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib" patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '' + lib.optionalString stdenv.isDarwin '' # Framework search paths aren't added by bintools hook # https://github.com/NixOS/nixpkgs/pull/41914 Loading @@ -289,16 +283,12 @@ let substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \ --replace "/usr/bin/libtool" "${cctools}/bin/libtool" '' + (if stdenv.cc.isGNU then '' sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD '' else if stdenv.cc.isClang then '' sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD '' else throw "Unsupported stdenv.cc: ${stdenv.cc}"); installPhase = '' ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch} ''; }; inherit meta; Loading Loading @@ -345,13 +335,19 @@ buildPythonPackage { grpc jsoncpp libjpeg_turbo ml-dtypes numpy scipy six snappy ]; pythonImportsCheck = [ "jaxlib" ]; pythonImportsCheck = [ "jaxlib" # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade. "jaxlib.cpu_feature_guard" "jaxlib.xla_client" ]; # Without it there are complaints about libcudart.so.11.0 not being found # because RPATH path entries added above are stripped. Loading
pkgs/development/python-modules/jaxlib/prefetch.sh +15 −7 Original line number Diff line number Diff line version="$1" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)" #!/usr/bin/env bash prefetch () { expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url" url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) echo "$url" sha256=$(nix-prefetch-url "$url") nix hash to-sri --type sha256 "$sha256" echo } prefetch "x86_64-linux" "false" prefetch "aarch64-darwin" "false" prefetch "x86_64-darwin" "false" prefetch "x86_64-linux" "true"