Commit 8ead8d01 authored by Gaetan Lepage's avatar Gaetan Lepage
Browse files

mistral-rs: refactor use of stdenv.hostPlatform

parent 83cb45b0
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -34,6 +34,8 @@
}:

let
  inherit (stdenv) hostPlatform;

  accelIsValid = builtins.elem acceleration [
    null
    false
@@ -67,7 +69,7 @@ let
  metalSupport =
    assert accelIsValid;
    (acceleration == "metal")
    || (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64 && (acceleration == null));
    || (hostPlatform.isDarwin && hostPlatform.isAarch64 && (acceleration == null));

in
rustPlatform.buildRustPackage (finalAttrs: {
@@ -119,7 +121,7 @@ rustPlatform.buildRustPackage (finalAttrs: {
  buildFeatures =
    lib.optionals cudaSupport [ "cuda" ]
    ++ lib.optionals mklSupport [ "mkl" ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && metalSupport) [ "metal" ];
    ++ lib.optionals (hostPlatform.isDarwin && metalSupport) [ "metal" ];

  env =
    {
@@ -159,7 +161,7 @@ rustPlatform.buildRustPackage (finalAttrs: {
  # swagger-ui will once more be copied in the target directory during the check phase
  # Not deleting the existing unpacked archive leads to a `PermissionDenied` error
  preCheck = ''
    rm -rf target/${stdenv.hostPlatform.config}/release/build/
    rm -rf target/${hostPlatform.config}/release/build/
  '';

  # Prevent checkFeatures from inheriting buildFeatures because
@@ -185,13 +187,11 @@ rustPlatform.buildRustPackage (finalAttrs: {
    tests = {
      version = testers.testVersion { package = mistral-rs; };

      withMkl = lib.optionalAttrs (stdenv.hostPlatform == "x86_64-linux") (
      withMkl = lib.optionalAttrs (hostPlatform.isLinux && hostPlatform.isx86_64) (
        mistral-rs.override { acceleration = "mkl"; }
      );
      withCuda = lib.optionalAttrs stdenv.hostPlatform.isLinux (
        mistral-rs.override { acceleration = "cuda"; }
      );
      withMetal = lib.optionalAttrs (stdenv.hostPlatform == "aarch64-darwin") (
      withCuda = lib.optionalAttrs hostPlatform.isLinux (mistral-rs.override { acceleration = "cuda"; });
      withMetal = lib.optionalAttrs (hostPlatform.isDarwin && hostPlatform.isAarch64) (
        mistral-rs.override { acceleration = "metal"; }
      );
    };