Loading pkgs/build-support/build-bazel-package/default.nix +10 −29 Original line number Diff line number Diff line Loading @@ -10,12 +10,9 @@ args@{ , bazelFlags ? [] , bazelBuildFlags ? [] , bazelTestFlags ? [] , bazelRunFlags ? [] , runTargetFlags ? [] , bazelFetchFlags ? [] , bazelTargets ? [] , bazelTargets , bazelTestTargets ? [] , bazelRunTarget ? null , buildAttrs , fetchAttrs Loading Loading @@ -49,23 +46,17 @@ args@{ let fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // { inherit name bazelFlags bazelBuildFlags bazelTestFlags bazelRunFlags runTargetFlags bazelFetchFlags bazelTargets bazelTestTargets bazelRunTarget dontAddBazelOpts ; name = name; bazelFlags = bazelFlags; bazelBuildFlags = bazelBuildFlags; bazelTestFlags = bazelTestFlags; bazelFetchFlags = bazelFetchFlags; bazelTestTargets = bazelTestTargets; dontAddBazelOpts = dontAddBazelOpts; }; fBuildAttrs = fArgs // buildAttrs; fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ]; bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }: bazelCmd = { cmd, additionalFlags, targets }: lib.optionalString (targets != [ ]) '' # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables] BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \ Loading @@ -82,8 +73,7 @@ let "''${host_linkopts[@]}" \ $bazelFlags \ ${lib.strings.concatStringsSep " " additionalFlags} \ ${lib.strings.concatStringsSep " " targets} \ ${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags} ${lib.strings.concatStringsSep " " targets} ''; # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so: # chmod: cannot operate on dangling symlink '$symlink' Loading Loading @@ -272,15 +262,6 @@ stdenv.mkDerivation (fBuildAttrs // { targets = fBuildAttrs.bazelTargets; } } ${ bazelCmd { cmd = "run"; additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ]; # Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list. targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ]; targetRunFlags = fBuildAttrs.runTargetFlags; } } runHook postBuild ''; }) Loading pkgs/development/python-modules/jax/default.nix +16 −6 Original line number Diff line number Diff line Loading @@ -8,7 +8,6 @@ , jaxlib-bin , lapack , matplotlib , ml-dtypes , numpy , opt-einsum , pytestCheckHook Loading @@ -28,7 +27,7 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.12"; version = "0.4.5"; format = "setuptools"; disabled = pythonOlder "3.7"; Loading @@ -38,7 +37,7 @@ buildPythonPackage rec { repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are Loading @@ -47,7 +46,6 @@ buildPythonPackage rec { propagatedBuildInputs = [ absl-py etils ml-dtypes numpy opt-einsum scipy Loading Loading @@ -98,12 +96,24 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # See https://github.com/google/jax/issues/11722. This is a temporary fix in # order to unblock etils, and upgrading jax/jaxlib to the latest version. See # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. disabledTestPaths = [ "tests/api_test.py" "tests/core_test.py" "tests/lax_numpy_indexing_test.py" "tests/lax_numpy_test.py" "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; pythonImportsCheck = [ "jax" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading pkgs/development/python-modules/jaxlib/bin.nix +33 −51 Original line number Diff line number Diff line Loading @@ -18,12 +18,11 @@ , autoPatchelfHook , buildPythonPackage , config , fetchPypi , cudnn ? cudaPackages.cudnn , fetchurl , flatbuffers , jaxlib , isPy39 , lib , ml-dtypes , python , scipy , stdenv Loading @@ -36,57 +35,46 @@ let inherit (cudaPackages) cudatoolkit cudnn; in assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; let version = "0.4.12"; version = "0.4.4"; inherit (python) pythonVersion; pythonVersion = python.pythonVersion; # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the # official instructions recommend installing CPU-only versions via PyPI. cpuSrcs = let getSrcFromPypi = { platform, hash }: fetchPypi { inherit version platform hash; pname = "jaxlib"; format = "wheel"; # See the `disabled` attr comment below. dist = "cp310"; python = "cp310"; abi = "cp310"; }; in { "x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc="; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is # the correct index. cpuSrcs = { "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ="; }; "aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo="; "aarch64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U="; }; "x86_64-darwin" = getSrcFromPypi { platform = "macosx_10_14_x86_64"; hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c="; "x86_64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl"; hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok="; }; }; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. gpuSrc = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo="; url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk="; }; in buildPythonPackage { buildPythonPackage rec { pname = "jaxlib"; inherit version; format = "wheel"; # At the time of writing (2022-10-19), there are releases for <=3.10. # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs # python version. disabled = !(pythonVersion == "3.10"); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. Loading @@ -99,10 +87,9 @@ buildPythonPackage { # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc.lib ]; buildInputs = [ stdenv.cc.cc ]; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or Loading @@ -126,12 +113,7 @@ buildPythonPackage { done ''; propagatedBuildInputs = [ absl-py flatbuffers ml-dtypes scipy ]; propagatedBuildInputs = [ absl-py flatbuffers scipy ]; # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for Loading @@ -141,7 +123,7 @@ buildPythonPackage { ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas ''; inherit (jaxlib) pythonImportsCheck; pythonImportsCheck = [ "jaxlib" ]; meta = with lib; { description = "XLA library for JAX"; Loading pkgs/development/python-modules/jaxlib/default.nix +36 −32 Original line number Diff line number Diff line Loading @@ -4,7 +4,7 @@ # Build-time dependencies: , addOpenGLRunpath , bazel_6 , bazel_5 , binutils , buildBazelPackage , buildPythonPackage Loading @@ -26,7 +26,6 @@ # Python dependencies: , absl-py , flatbuffers , ml-dtypes , numpy , scipy , six Loading @@ -36,6 +35,7 @@ , giflib , grpc , libjpeg_turbo , protobuf , python , snappy , zlib Loading @@ -53,7 +53,7 @@ let inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; pname = "jaxlib"; version = "0.4.12"; version = "0.4.4"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; Loading Loading @@ -138,15 +138,14 @@ let bazel-build = buildBazelPackage rec { name = "bazel-build-${pname}-${version}"; # See https://github.com/google/jax/blob/main/.bazelversion for the latest. bazel = bazel_6; bazel = bazel_5; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo="; }; nativeBuildInputs = [ Loading @@ -170,7 +169,7 @@ let numpy openssl pkgs.flatbuffers pkgs.protobuf protobuf pybind11 scipy six Loading @@ -189,8 +188,7 @@ let rm -f .bazelversion ''; bazelRunTarget = "//build:build_wheel"; runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ]; bazelTargets = [ "//build:build_wheel" ]; removeRulesCC = false; Loading @@ -209,11 +207,7 @@ let build --action_env=PYENV_ROOT build --python_path="${python}/bin/python" build --distinct_host_configuration=false build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) '' build --config=avx_posix '' + lib.optionalString mklSupport '' build --config=mkl_open_source_only build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include" '' + lib.optionalString cudaSupport '' build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" Loading @@ -240,7 +234,7 @@ let fetchAttrs = { TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs; # we have to force @mkl_dnn_v1 since it's not needed on darwin bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ]; bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ]; bazelFlags = bazelFlags ++ [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ Loading @@ -255,9 +249,9 @@ let sha256 = if cudaSupport then "sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4=" "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk=" else "sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc="; "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI="; }; buildAttrs = { Loading @@ -267,13 +261,25 @@ let "nsync" # fails to build on darwin ]); bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ "--config=cuda" ] ++ lib.optionals mklSupport [ "--config=mkl_open_source_only" ]; # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet. # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # 1) Fix pybind11 include paths. # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # loading multiple extensions in the same python program due to duplicate protobuf DBs. # 2) Patch python path in the compiler driver. preBuild = lib.optionalString cudaSupport '' # 3) Patch python path in the compiler driver. preBuild = '' for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do sed -i 's@include/pybind11@pybind11@g' $src done '' + lib.optionalString cudaSupport '' export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib" patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '' + lib.optionalString stdenv.isDarwin '' # Framework search paths aren't added by bintools hook # https://github.com/NixOS/nixpkgs/pull/41914 Loading @@ -283,12 +289,16 @@ let substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \ --replace "/usr/bin/libtool" "${cctools}/bin/libtool" '' + (if stdenv.cc.isGNU then '' sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD '' else if stdenv.cc.isClang then '' sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD '' else throw "Unsupported stdenv.cc: ${stdenv.cc}"); installPhase = '' ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch} ''; }; inherit meta; Loading Loading @@ -335,19 +345,13 @@ buildPythonPackage { grpc jsoncpp libjpeg_turbo ml-dtypes numpy scipy six snappy ]; pythonImportsCheck = [ "jaxlib" # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade. "jaxlib.cpu_feature_guard" "jaxlib.xla_client" ]; pythonImportsCheck = [ "jaxlib" ]; # Without it there are complaints about libcudart.so.11.0 not being found # because RPATH path entries added above are stripped. Loading pkgs/development/python-modules/jaxlib/prefetch.sh +7 −15 Original line number Diff line number Diff line #!/usr/bin/env bash prefetch () { expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url" url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) echo "$url" sha256=$(nix-prefetch-url "$url") nix hash to-sri --type sha256 "$sha256" echo } prefetch "x86_64-linux" "false" prefetch "aarch64-darwin" "false" prefetch "x86_64-darwin" "false" prefetch "x86_64-linux" "true" version="$1" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)" Loading
pkgs/build-support/build-bazel-package/default.nix +10 −29 Original line number Diff line number Diff line Loading @@ -10,12 +10,9 @@ args@{ , bazelFlags ? [] , bazelBuildFlags ? [] , bazelTestFlags ? [] , bazelRunFlags ? [] , runTargetFlags ? [] , bazelFetchFlags ? [] , bazelTargets ? [] , bazelTargets , bazelTestTargets ? [] , bazelRunTarget ? null , buildAttrs , fetchAttrs Loading Loading @@ -49,23 +46,17 @@ args@{ let fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // { inherit name bazelFlags bazelBuildFlags bazelTestFlags bazelRunFlags runTargetFlags bazelFetchFlags bazelTargets bazelTestTargets bazelRunTarget dontAddBazelOpts ; name = name; bazelFlags = bazelFlags; bazelBuildFlags = bazelBuildFlags; bazelTestFlags = bazelTestFlags; bazelFetchFlags = bazelFetchFlags; bazelTestTargets = bazelTestTargets; dontAddBazelOpts = dontAddBazelOpts; }; fBuildAttrs = fArgs // buildAttrs; fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ]; bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }: bazelCmd = { cmd, additionalFlags, targets }: lib.optionalString (targets != [ ]) '' # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables] BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \ Loading @@ -82,8 +73,7 @@ let "''${host_linkopts[@]}" \ $bazelFlags \ ${lib.strings.concatStringsSep " " additionalFlags} \ ${lib.strings.concatStringsSep " " targets} \ ${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags} ${lib.strings.concatStringsSep " " targets} ''; # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so: # chmod: cannot operate on dangling symlink '$symlink' Loading Loading @@ -272,15 +262,6 @@ stdenv.mkDerivation (fBuildAttrs // { targets = fBuildAttrs.bazelTargets; } } ${ bazelCmd { cmd = "run"; additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ]; # Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list. targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ]; targetRunFlags = fBuildAttrs.runTargetFlags; } } runHook postBuild ''; }) Loading
pkgs/development/python-modules/jax/default.nix +16 −6 Original line number Diff line number Diff line Loading @@ -8,7 +8,6 @@ , jaxlib-bin , lapack , matplotlib , ml-dtypes , numpy , opt-einsum , pytestCheckHook Loading @@ -28,7 +27,7 @@ let in buildPythonPackage rec { pname = "jax"; version = "0.4.12"; version = "0.4.5"; format = "setuptools"; disabled = pythonOlder "3.7"; Loading @@ -38,7 +37,7 @@ buildPythonPackage rec { repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are Loading @@ -47,7 +46,6 @@ buildPythonPackage rec { propagatedBuildInputs = [ absl-py etils ml-dtypes numpy opt-einsum scipy Loading Loading @@ -98,12 +96,24 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # See https://github.com/google/jax/issues/11722. This is a temporary fix in # order to unblock etils, and upgrading jax/jaxlib to the latest version. See # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. disabledTestPaths = [ "tests/api_test.py" "tests/core_test.py" "tests/lax_numpy_indexing_test.py" "tests/lax_numpy_test.py" "tests/nn_test.py" "tests/random_test.py" "tests/sparse_test.py" ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; pythonImportsCheck = [ "jax" ]; # As of 0.3.22, `import jax` does not work without jaxlib being installed. pythonImportsCheck = [ ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; Loading
pkgs/development/python-modules/jaxlib/bin.nix +33 −51 Original line number Diff line number Diff line Loading @@ -18,12 +18,11 @@ , autoPatchelfHook , buildPythonPackage , config , fetchPypi , cudnn ? cudaPackages.cudnn , fetchurl , flatbuffers , jaxlib , isPy39 , lib , ml-dtypes , python , scipy , stdenv Loading @@ -36,57 +35,46 @@ let inherit (cudaPackages) cudatoolkit cudnn; in assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; let version = "0.4.12"; version = "0.4.4"; inherit (python) pythonVersion; pythonVersion = python.pythonVersion; # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the # official instructions recommend installing CPU-only versions via PyPI. cpuSrcs = let getSrcFromPypi = { platform, hash }: fetchPypi { inherit version platform hash; pname = "jaxlib"; format = "wheel"; # See the `disabled` attr comment below. dist = "cp310"; python = "cp310"; abi = "cp310"; }; in { "x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc="; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is # the correct index. cpuSrcs = { "x86_64-linux" = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ="; }; "aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo="; "aarch64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U="; }; "x86_64-darwin" = getSrcFromPypi { platform = "macosx_10_14_x86_64"; hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c="; "x86_64-darwin" = fetchurl { url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl"; hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok="; }; }; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. gpuSrc = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo="; url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk="; }; in buildPythonPackage { buildPythonPackage rec { pname = "jaxlib"; inherit version; format = "wheel"; # At the time of writing (2022-10-19), there are releases for <=3.10. # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs # python version. disabled = !(pythonVersion == "3.10"); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. Loading @@ -99,10 +87,9 @@ buildPythonPackage { # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc.lib ]; buildInputs = [ stdenv.cc.cc ]; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or Loading @@ -126,12 +113,7 @@ buildPythonPackage { done ''; propagatedBuildInputs = [ absl-py flatbuffers ml-dtypes scipy ]; propagatedBuildInputs = [ absl-py flatbuffers scipy ]; # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for Loading @@ -141,7 +123,7 @@ buildPythonPackage { ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas ''; inherit (jaxlib) pythonImportsCheck; pythonImportsCheck = [ "jaxlib" ]; meta = with lib; { description = "XLA library for JAX"; Loading
pkgs/development/python-modules/jaxlib/default.nix +36 −32 Original line number Diff line number Diff line Loading @@ -4,7 +4,7 @@ # Build-time dependencies: , addOpenGLRunpath , bazel_6 , bazel_5 , binutils , buildBazelPackage , buildPythonPackage Loading @@ -26,7 +26,6 @@ # Python dependencies: , absl-py , flatbuffers , ml-dtypes , numpy , scipy , six Loading @@ -36,6 +35,7 @@ , giflib , grpc , libjpeg_turbo , protobuf , python , snappy , zlib Loading @@ -53,7 +53,7 @@ let inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; pname = "jaxlib"; version = "0.4.12"; version = "0.4.4"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; Loading Loading @@ -138,15 +138,14 @@ let bazel-build = buildBazelPackage rec { name = "bazel-build-${pname}-${version}"; # See https://github.com/google/jax/blob/main/.bazelversion for the latest. bazel = bazel_6; bazel = bazel_5; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! rev = "refs/tags/${pname}-v${version}"; hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo="; }; nativeBuildInputs = [ Loading @@ -170,7 +169,7 @@ let numpy openssl pkgs.flatbuffers pkgs.protobuf protobuf pybind11 scipy six Loading @@ -189,8 +188,7 @@ let rm -f .bazelversion ''; bazelRunTarget = "//build:build_wheel"; runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ]; bazelTargets = [ "//build:build_wheel" ]; removeRulesCC = false; Loading @@ -209,11 +207,7 @@ let build --action_env=PYENV_ROOT build --python_path="${python}/bin/python" build --distinct_host_configuration=false build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) '' build --config=avx_posix '' + lib.optionalString mklSupport '' build --config=mkl_open_source_only build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include" '' + lib.optionalString cudaSupport '' build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" Loading @@ -240,7 +234,7 @@ let fetchAttrs = { TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs; # we have to force @mkl_dnn_v1 since it's not needed on darwin bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ]; bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ]; bazelFlags = bazelFlags ++ [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ Loading @@ -255,9 +249,9 @@ let sha256 = if cudaSupport then "sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4=" "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk=" else "sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc="; "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI="; }; buildAttrs = { Loading @@ -267,13 +261,25 @@ let "nsync" # fails to build on darwin ]); bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ "--config=cuda" ] ++ lib.optionals mklSupport [ "--config=mkl_open_source_only" ]; # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet. # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # 1) Fix pybind11 include paths. # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # loading multiple extensions in the same python program due to duplicate protobuf DBs. # 2) Patch python path in the compiler driver. preBuild = lib.optionalString cudaSupport '' # 3) Patch python path in the compiler driver. preBuild = '' for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do sed -i 's@include/pybind11@pybind11@g' $src done '' + lib.optionalString cudaSupport '' export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib" patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '' + lib.optionalString stdenv.isDarwin '' # Framework search paths aren't added by bintools hook # https://github.com/NixOS/nixpkgs/pull/41914 Loading @@ -283,12 +289,16 @@ let substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \ --replace "/usr/bin/libtool" "${cctools}/bin/libtool" '' + (if stdenv.cc.isGNU then '' sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD '' else if stdenv.cc.isClang then '' sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD '' else throw "Unsupported stdenv.cc: ${stdenv.cc}"); installPhase = '' ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch} ''; }; inherit meta; Loading Loading @@ -335,19 +345,13 @@ buildPythonPackage { grpc jsoncpp libjpeg_turbo ml-dtypes numpy scipy six snappy ]; pythonImportsCheck = [ "jaxlib" # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade. "jaxlib.cpu_feature_guard" "jaxlib.xla_client" ]; pythonImportsCheck = [ "jaxlib" ]; # Without it there are complaints about libcudart.so.11.0 not being found # because RPATH path entries added above are stripped. Loading
pkgs/development/python-modules/jaxlib/prefetch.sh +7 −15 Original line number Diff line number Diff line #!/usr/bin/env bash prefetch () { expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url" url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) echo "$url" sha256=$(nix-prefetch-url "$url") nix hash to-sri --type sha256 "$sha256" echo } prefetch "x86_64-linux" "false" prefetch "aarch64-darwin" "false" prefetch "x86_64-darwin" "false" prefetch "x86_64-linux" "true" version="$1" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"