Unverified Commit bb849179 authored by Someone's avatar Someone Committed by GitHub
Browse files

python3Packages.tensorflow-bin: Use CUDA 12, add Jetson support (#334996)

parents 268bb509 f2a4ee0a
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ buildPythonPackage rec {
  pythonImportsCheck = [ "geoip2" ];

  disabledTests =
    lib.optionals (pythonAtLeast "3.11") [
    lib.optionals (pythonAtLeast "3.10") [
      # https://github.com/maxmind/GeoIP2-python/pull/136
      "TestAsyncClient"
    ]
+22 −9
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@
  distutils,
  wheel,
  jax,
  ml-dtypes,
  opt-einsum,
  tensorflow-estimator-bin,
  tensorboard,
@@ -38,9 +39,10 @@
  typing-extensions,
}:

# We keep this binary build for two reasons:
# We keep this binary build for three reasons:
# - the source build doesn't work on Darwin.
# - the source build is currently brittle and not easy to maintain
# - the source build doesn't work on NVIDIA Jetson platforms

# unsupported combination
assert !(stdenv.isDarwin && cudaSupport);
@@ -48,20 +50,23 @@ assert !(stdenv.isDarwin && cudaSupport);
let
  packages = import ./binary-hashes.nix;
  inherit (cudaPackages) cudatoolkit cudnn;

  isCudaJetson = cudaSupport && cudaPackages.cudaFlags.isJetsonBuild;
  isCudaX64 = cudaSupport && stdenv.hostPlatform.isx86_64;
in
buildPythonPackage {
  pname = "tensorflow" + lib.optionalString cudaSupport "-gpu";
  inherit (packages) version;
  version = packages."${"version" + lib.optionalString isCudaJetson "_jetson"}";
  format = "wheel";

  src =
    let
      pyVerNoDot = lib.strings.stringAsChars (x: lib.optionalString (x != ".") x) python.pythonVersion;
      platform = stdenv.system;
      cuda = lib.optionalString cudaSupport "_gpu";
      cuda = lib.optionalString cudaSupport (if isCudaJetson then "_jetson" else "_gpu");
      key = "${platform}_${pyVerNoDot}${cuda}";
    in
    fetchurl (packages.${key} or (throw "tensoflow-bin: unsupported system: ${stdenv.system}"));
    fetchurl (packages.${key} or (throw "tensoflow-bin: unsupported configuration: ${key}"));

  buildInputs = [ llvmPackages.openmp ];

@@ -73,7 +78,7 @@ buildPythonPackage {
    protobuf
    numpy
    scipy
    jax
    (if isCudaX64 then jax else ml-dtypes)
    termcolor
    grpcio
    six
@@ -90,10 +95,13 @@ buildPythonPackage {
    h5py
  ] ++ lib.optional (!isPy3k) mock;

  build-system = [
  build-system =
    [
      distutils
      wheel
  ] ++ lib.optionals cudaSupport [ addDriverRunpath ];
    ]
    ++ lib.optionals cudaSupport [ addDriverRunpath ]
    ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ];

  preConfigure = ''
    unset SOURCE_DATE_EPOCH
@@ -103,6 +111,11 @@ buildPythonPackage {

    pushd dist

    for f in tensorflow-*+nv*.whl; do
      # e.g. *nv24.07* -> *nv24.7*
      mv "$f" "$(sed -E 's/(nv[0-9]+)\.0*([0-9]+)/\1.\2/' <<< "$f")"
    done

    wheel unpack --dest unpacked ./*.whl
    rm ./*.whl
    (
+5 −0
Original line number Diff line number Diff line
{
  version = "2.17.0";
  version_jetson = "2.16.1+nv24.07";
  x86_64-linux_39 = {
    url = "https://storage.googleapis.com/tensorflow/versions/2.17.0/tensorflow_cpu-2.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl";
    sha256 = "1aacn68b88bnnmpl1q0irih0avzm2lfyhwr3wldg144n5zljlrbx";
@@ -48,6 +49,10 @@
    url = "https://storage.googleapis.com/tensorflow/versions/2.17.0/tensorflow-2.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl";
    sha256 = "1zrscms9qkfpiscnl8c7ibfipwpw8jrdfvwh4wb69p9rxvqgxbbj";
  };
  aarch64-linux_310_jetson = {
    url = "https://developer.download.nvidia.com/compute/redist/jp/v60/tensorflow/tensorflow-2.16.1+nv24.07-cp310-cp310-linux_aarch64.whl";
    sha256 = "1ymdknl5v41z6z0wg068diici30am8vysg6b6sqxr8w6yk4aib42";
  };
  aarch64-darwin_39 = {
    url = "https://storage.googleapis.com/tensorflow/versions/2.17.0/tensorflow-2.17.0-cp39-cp39-macosx_12_0_arm64.whl";
    sha256 = "01a3hjnrgjp2i0ciwyy0gki41cy32prvjhr20zhlcjwbssarxy4p";
+4 −0
Original line number Diff line number Diff line
#!/usr/bin/env bash

version="2.17.0"
version_jetson="2.16.1+nv24.07"

bucket="https://storage.googleapis.com/tensorflow/versions/${version}"
bucket_jetson="https://developer.download.nvidia.com/compute/redist/jp/v60/tensorflow"

# List of binary wheels for Tensorflow.  The most recent versions can be found
# on the following page:
@@ -20,6 +22,7 @@ url_and_key_list=(
"aarch64-linux_310 $bucket/tensorflow-${version}-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl"
"aarch64-linux_311 $bucket/tensorflow-${version}-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl"
"aarch64-linux_312 $bucket/tensorflow-${version}-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl"
"aarch64-linux_310_jetson $bucket_jetson/tensorflow-${version_jetson}-cp310-cp310-linux_aarch64.whl"
"aarch64-darwin_39 $bucket/tensorflow-${version}-cp39-cp39-macosx_12_0_arm64.whl"
"aarch64-darwin_310 $bucket/tensorflow-${version}-cp310-cp310-macosx_12_0_arm64.whl"
"aarch64-darwin_311 $bucket/tensorflow-${version}-cp311-cp311-macosx_12_0_arm64.whl"
@@ -30,6 +33,7 @@ hashfile=binary-hashes.nix
rm -f $hashfile
echo "{" >> $hashfile
echo "version = \"$version\";" >> $hashfile
echo "version_jetson = \"$version_jetson\";" >> $hashfile

for url_and_key in "${url_and_key_list[@]}"; do
  key=$(echo "$url_and_key" | cut -d' ' -f1)
+1 −1
Original line number Diff line number Diff line
@@ -15246,7 +15246,7 @@ self: super: with self; {
  tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix {
    inherit (pkgs.config) cudaSupport;
    # https://www.tensorflow.org/install/source#gpu
    cudaPackages = pkgs.cudaPackages_11;
    cudaPackages = pkgs.cudaPackages_12;
  };
  tensorflow-build = let