Commit 0e6724e0 authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

python3Packages.executorch: fix cuda build

parent e886994e
Loading
Loading
Loading
Loading
+18 −1
Original line number Diff line number Diff line
@@ -7,6 +7,8 @@

  # nativeBuildInputs
  gitMinimal,
  # cuda-only:
  autoPatchelfHook,

  # build-system
  certifi,
@@ -46,8 +48,11 @@
  transformers,
  writableTmpDirAsHomeHook,
  yaspin,

  cudaSupport ? torch.cudaSupport,
  cudaPackages,
}:
buildPythonPackage (finalAttrs: {
buildPythonPackage.override { inherit (torch) stdenv; } (finalAttrs: {
  pname = "executorch";
  version = "1.2.0";
  pyproject = true;
@@ -107,6 +112,9 @@ buildPythonPackage (finalAttrs: {
    #  Some binaries contain forbidden references to /build/. Check the error above!
    CMAKE_ARGS = lib.concatStringsSep " " [
      (lib.cmakeBool "CMAKE_SKIP_BUILD_RPATH" true)

      # For some cmake-tier reason, cmakeBool does not work here
      (lib.cmakeFeature "EXECUTORCH_BUILD_CUDA" (if cudaSupport then "ON" else "OFF"))
    ];
  };

@@ -122,6 +130,15 @@ buildPythonPackage (finalAttrs: {

  nativeBuildInputs = [
    gitMinimal
  ]
  ++ lib.optionals cudaSupport [
    autoPatchelfHook
    cudaPackages.cuda_nvcc
  ];

  buildInputs = lib.optionals cudaSupport [
    cudaPackages.cuda_cudart
    cudaPackages.cuda_nvrtc
  ];

  pythonRemoveDeps = [