Commit 0bbb2099 authored by ferres's avatar ferres
Browse files

python3Packages.xformers: support aarch64-darwin

parent 8dc663ea
Loading
Loading
Loading
Loading
+22 −20
Original line number Diff line number Diff line
@@ -27,6 +27,7 @@
  transformers,
  timm,
  #, flash-attn
  openmp,
}:
let
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;
@@ -66,7 +67,9 @@ buildPythonPackage {

  stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;

  buildInputs = lib.optionals cudaSupport (
  buildInputs =
    lib.optional stdenv.hostPlatform.isDarwin openmp
    ++ lib.optionals cudaSupport (
      with cudaPackages;
      [
        # flash-attn build
@@ -79,10 +82,13 @@ buildPythonPackage {
      ]
    );

  nativeBuildInputs = [
  nativeBuildInputs =
    [
      ninja
      which
  ] ++ lib.optionals cudaSupport (with cudaPackages; [ cuda_nvcc ]);
    ]
    ++ lib.optionals cudaSupport (with cudaPackages; [ cuda_nvcc ])
    ++ lib.optional stdenv.hostPlatform.isDarwin openmp.dev;

  dependencies = [
    numpy
@@ -123,9 +129,5 @@ buildPythonPackage {
    changelog = "https://github.com/facebookresearch/xformers/blob/${version}/CHANGELOG.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ happysalada ];
    badPlatforms = [
      # fatal error: 'omp.h' file not found
      lib.systems.inspect.patterns.isDarwin
    ];
  };
}
+3 −1
Original line number Diff line number Diff line
@@ -19574,7 +19574,9 @@ self: super: with self; {
  xen = toPythonModule (pkgs.xen.override { python3Packages = self; });
  xformers = callPackage ../development/python-modules/xformers { };
  xformers = callPackage ../development/python-modules/xformers {
    inherit (pkgs.llvmPackages) openmp;
  };
  xgboost = callPackage ../development/python-modules/xgboost { inherit (pkgs) xgboost; };