Commit f2a4ee0a authored by hacker1024's avatar hacker1024
Browse files

python3Packages.tensorflow-bin: Simplify CUDA platform branching

parent cca50359
Loading
Loading
Loading
Loading
+13 −8
Original line number Diff line number Diff line
@@ -50,17 +50,20 @@ assert !(stdenv.isDarwin && cudaSupport);
let
  packages = import ./binary-hashes.nix;
  inherit (cudaPackages) cudatoolkit cudnn;

  isCudaJetson = cudaSupport && cudaPackages.cudaFlags.isJetsonBuild;
  isCudaX64 = cudaSupport && stdenv.hostPlatform.isx86_64;
in
buildPythonPackage {
  pname = "tensorflow" + lib.optionalString cudaSupport "-gpu";
  version = packages."${"version" + lib.optionalString (cudaSupport && cudaPackages.cudaFlags.isJetsonBuild) "_jetson"}";
  version = packages."${"version" + lib.optionalString isCudaJetson "_jetson"}";
  format = "wheel";

  src =
    let
      pyVerNoDot = lib.strings.stringAsChars (x: lib.optionalString (x != ".") x) python.pythonVersion;
      platform = stdenv.system;
      cuda = lib.optionalString cudaSupport (if cudaPackages.cudaFlags.isJetsonBuild then "_jetson" else "_gpu");
      cuda = lib.optionalString cudaSupport (if isCudaJetson then "_jetson" else "_gpu");
      key = "${platform}_${pyVerNoDot}${cuda}";
    in
    fetchurl (packages.${key} or (throw "tensoflow-bin: unsupported configuration: ${key}"));
@@ -75,7 +78,7 @@ buildPythonPackage {
    protobuf
    numpy
    scipy
    (if !cudaPackages.cudaFlags.isJetsonBuild then jax else ml-dtypes)
    (if isCudaX64 then jax else ml-dtypes)
    termcolor
    grpcio
    six
@@ -92,11 +95,13 @@ buildPythonPackage {
    h5py
  ] ++ lib.optional (!isPy3k) mock;

  build-system = [
  build-system =
    [
      distutils
      wheel
  ] ++ lib.optionals cudaSupport ([ addDriverRunpath ]
    ++ lib.optionals cudaPackages.cudaFlags.isJetsonBuild [ cudaPackages.autoAddCudaCompatRunpath ]);
    ]
    ++ lib.optionals cudaSupport [ addDriverRunpath ]
    ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ];

  preConfigure = ''
    unset SOURCE_DATE_EPOCH