Unverified Commit c1c21ed5 authored by Luna Nova's avatar Luna Nova
Browse files

rocmPackages.migraphx: add migraphx-driver impureTest using small resnet

parent 4781916d
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
{
  lib,
  stdenv,
  callPackage,
  fetchFromGitHub,
  rocmUpdateScript,
  pkg-config,
@@ -183,6 +184,12 @@ stdenv.mkDerivation (finalAttrs: {
      patchelf $test/bin/test_* --shrink-rpath --allowed-rpath-prefixes "$NIX_STORE"
    '';

  passthru.impureTests = {
    # NIXPKGS_ALLOW_UNFREE=1 bash $(nix-build -A rocmPackages.migraphx.impureTests.migraphx-driver)
    migraphx-driver = callPackage ./test-migraphx-driver.nix {
      migraphx = finalAttrs.finalPackage;
    };
  };
  passthru.updateScript = rocmUpdateScript {
    name = finalAttrs.pname;
    inherit (finalAttrs.src) owner;
+47 −0
Original line number Diff line number Diff line
{
  lib,
  fetchurl,
  makeImpureTest,
  writableTmpDirAsHomeHook,
  migraphx,
  clr,
  rocm-smi,
}:

# Verify that a ≈50MiB resnet onnx can run with migraphx
let
  resnet18 = fetchurl {
    url = "https://huggingface.co/onnxmodelzoo/resnet18_Opset18_timm/resolve/main/resnet18_Opset18_timm.onnx";
    hash = "sha256-u2Io20n72qoA9atRsFIWb0zHF1WdJYgHQdMWfJhJGHA=";
    meta.license = lib.licenses.unfree;
  };
in
makeImpureTest {
  name = "migraphx-driver";
  testedPackage = "rocmPackages.migraphx";

  sandboxPaths = [
    "/sys"
    "/dev/dri"
    "/dev/kfd"
  ];

  nativeBuildInputs = [
    writableTmpDirAsHomeHook
    migraphx
    clr
    rocm-smi
  ];

  # FIXME(@LunNova): tol values are set too high - was seeing high divergence on iGPU
  # want this test to be useful for verifying workloads run at all
  # and will investigate what's broken for accuracy
  testScript = ''
    rocm-smi
    migraphx-driver verify -O --rms-tol 0.03 --atol 1.0 --rtol 0.01 ${resnet18}
  '';

  meta = {
    teams = [ lib.teams.rocm ];
  };
}