Unverified Commit aaa5b43a authored by Someone's avatar Someone Committed by GitHub
Browse files

mistral-rs: make free again, refactor (#398403)

parents 270ec701 8ead8d01
Loading
Loading
Loading
Loading
+9 −9
Original line number Diff line number Diff line
@@ -34,6 +34,8 @@
}:

let
  inherit (stdenv) hostPlatform;

  accelIsValid = builtins.elem acceleration [
    null
    false
@@ -67,7 +69,7 @@ let
  metalSupport =
    assert accelIsValid;
    (acceleration == "metal")
    || (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64 && (acceleration == null));
    || (hostPlatform.isDarwin && hostPlatform.isAarch64 && (acceleration == null));

in
rustPlatform.buildRustPackage (finalAttrs: {
@@ -119,7 +121,7 @@ rustPlatform.buildRustPackage (finalAttrs: {
  buildFeatures =
    lib.optionals cudaSupport [ "cuda" ]
    ++ lib.optionals mklSupport [ "mkl" ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && metalSupport) [ "metal" ];
    ++ lib.optionals (hostPlatform.isDarwin && metalSupport) [ "metal" ];

  env =
    {
@@ -149,7 +151,7 @@ rustPlatform.buildRustPackage (finalAttrs: {
      CUDA_TOOLKIT_ROOT_DIR = lib.getDev cudaPackages.cuda_cudart;
    });

  appendRunpaths = [
  appendRunpaths = lib.optionals cudaSupport [
    (lib.makeLibraryPath [
      cudaPackages.libcublas
      cudaPackages.libcurand
@@ -159,7 +161,7 @@ rustPlatform.buildRustPackage (finalAttrs: {
  # swagger-ui will once more be copied in the target directory during the check phase
  # Not deleting the existing unpacked archive leads to a `PermissionDenied` error
  preCheck = ''
    rm -rf target/${stdenv.hostPlatform.config}/release/build/
    rm -rf target/${hostPlatform.config}/release/build/
  '';

  # Prevent checkFeatures from inheriting buildFeatures because
@@ -185,13 +187,11 @@ rustPlatform.buildRustPackage (finalAttrs: {
    tests = {
      version = testers.testVersion { package = mistral-rs; };

      withMkl = lib.optionalAttrs (stdenv.hostPlatform == "x86_64-linux") (
      withMkl = lib.optionalAttrs (hostPlatform.isLinux && hostPlatform.isx86_64) (
        mistral-rs.override { acceleration = "mkl"; }
      );
      withCuda = lib.optionalAttrs stdenv.hostPlatform.isLinux (
        mistral-rs.override { acceleration = "cuda"; }
      );
      withMetal = lib.optionalAttrs (stdenv.hostPlatform == "aarch64-darwin") (
      withCuda = lib.optionalAttrs hostPlatform.isLinux (mistral-rs.override { acceleration = "cuda"; });
      withMetal = lib.optionalAttrs (hostPlatform.isDarwin && hostPlatform.isAarch64) (
        mistral-rs.override { acceleration = "metal"; }
      );
    };