Commit 3aa42474 authored by hacker1024's avatar hacker1024
Browse files

python310Packages.tensorflow-bin: Allow building for Jetsons

parent 059696a8
Loading
Loading
Loading
Loading
+15 −8
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@
  distutils,
  wheel,
  jax,
  ml-dtypes,
  opt-einsum,
  tensorflow-estimator-bin,
  tensorboard,
@@ -38,9 +39,10 @@
  typing-extensions,
}:

# We keep this binary build for two reasons:
# We keep this binary build for three reasons:
# - the source build doesn't work on Darwin.
# - the source build is currently brittle and not easy to maintain
# - the source build doesn't work on NVIDIA Jetson platforms

# unsupported combination
assert !(stdenv.isDarwin && cudaSupport);
@@ -49,19 +51,19 @@ let
  packages = import ./binary-hashes.nix;
  inherit (cudaPackages) cudatoolkit cudnn;
in
buildPythonPackage {
buildPythonPackage rec {
  pname = "tensorflow" + lib.optionalString cudaSupport "-gpu";
  inherit (packages) version;
  version = packages."${"version" + lib.optionalString (cudaSupport && cudaPackages.cudaFlags.isJetsonBuild) "_jetson"}";
  format = "wheel";

  src =
    let
      pyVerNoDot = lib.strings.stringAsChars (x: lib.optionalString (x != ".") x) python.pythonVersion;
      platform = stdenv.system;
      cuda = lib.optionalString cudaSupport "_gpu";
      cuda = lib.optionalString cudaSupport (if cudaPackages.cudaFlags.isJetsonBuild then "_jetson" else "_gpu");
      key = "${platform}_${pyVerNoDot}${cuda}";
    in
    fetchurl (packages.${key} or (throw "tensoflow-bin: unsupported system: ${stdenv.system}"));
    fetchurl (packages.${key} or (throw "tensoflow-bin: unsupported configuration: ${key}"));

  buildInputs = [ llvmPackages.openmp ];

@@ -73,7 +75,7 @@ buildPythonPackage {
    protobuf
    numpy
    scipy
    jax
    (if !cudaPackages.cudaFlags.isJetsonBuild then jax else ml-dtypes)
    termcolor
    grpcio
    six
@@ -103,6 +105,10 @@ buildPythonPackage {

    pushd dist

    for f in tensorflow-*+nv*.whl; do
      mv "$f" "$(sed -E 's/(nv[0-9]+)\.0*([0-9]+)/\1.\2/' <<< "$f")"
    done

    wheel unpack --dest unpacked ./*.whl
    rm ./*.whl
    (
@@ -134,11 +140,12 @@ buildPythonPackage {
  postFixup =
    let
      # rpaths we only need to add if CUDA is enabled.
      cudapaths = lib.optionals cudaSupport [
      cudapaths = lib.optionals cudaSupport ([
        cudatoolkit.out
        cudatoolkit.lib
      ] ++ lib.optionals (!cudaPackages.cudaFlags.isJetsonBuild) [
        cudnn
      ];
      ]);

      libpaths = [
        stdenv.cc.cc.lib