Unverified Commit 11ff96a6 authored by Maximilian Bosch's avatar Maximilian Bosch
Browse files

nixos/test-driver: use RequestedAssertionFailed/TestScriptError in Machine class



I think it's reasonable to also have this kind of visual distinction
here between test failures and actual errors from the test framework.

A failing `machine.require_unit_state` now lookgs like this for
instance:

    !!! Traceback (most recent call last):
    !!!   File "<string>", line 3, in <module>
    !!!     machine.require_unit_state("postgresql","active")
    !!!
    !!! RequestedAssertionFailed: Expected unit 'postgresql' to to be in state 'active' but it is in state 'inactive'

Co-authored-by: default avatarBenoit de Chezelles <bew@users.noreply.github.com>
parent a1dfaf51
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ target-version = "py312"
line-length = 88

lint.select = ["E", "F", "I", "U", "N"]
lint.ignore = ["E501"]
lint.ignore = ["E501", "N818"]

# xxx: we can import https://pypi.org/project/types-colorama/ here
[[tool.mypy.overrides]]
+19 −15
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Any
from unittest import TestCase

from test_driver.errors import RequestedAssertionFailed, TestScriptError
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
@@ -19,6 +20,18 @@ from test_driver.vlan import VLan
SENTINEL = object()


class AssertionTester(TestCase):
    """
    Subclass of `unittest.TestCase` which is used in the
    `testScript` to perform assertions.

    It throws a custom exception whose parent class
    gets special treatment in the logs.
    """

    failureException = RequestedAssertionFailed


def get_tmp_dir() -> Path:
    """Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
    Raises an exception in case the retrieved temporary directory is not writeable
@@ -41,14 +54,6 @@ def pythonize_name(name: str) -> str:
    return re.sub(r"^[^A-z_]|[^A-z0-9_]", "_", name)


class NixOSAssertionError(AssertionError):
    pass


class Tester(TestCase):
    failureException = NixOSAssertionError


class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""
@@ -126,7 +131,7 @@ class Driver:
            try:
                yield
            except Exception as e:
                self.logger.error(f'Test "{name}" failed with error: "{e}"')
                self.logger.log_test_error(f'Test "{name}" failed with error: "{e}"')
                raise e

    def test_symbols(self) -> dict[str, Any]:
@@ -151,7 +156,7 @@ class Driver:
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
            t=Tester(),
            t=AssertionTester(),
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
@@ -177,7 +182,7 @@ class Driver:
            symbols = self.test_symbols()  # call eagerly
            try:
                exec(self.tests, symbols, None)
            except NixOSAssertionError:
            except TestScriptError:
                exc_type, exc, tb = sys.exc_info()
                filtered = [
                    frame
@@ -186,13 +191,12 @@ class Driver:
                ]

                self.logger.log_test_error("Traceback (most recent call last):")

                code = self.tests.splitlines()
                for frame, line in zip(filtered, traceback.format_list(filtered)):
                    self.logger.log_test_error(line.rstrip())
                    if lineno := frame.lineno:
                        self.logger.log_test_error(
                            f"    {code[lineno - 1].strip()}",
                        )
                        self.logger.log_test_error(f"    {code[lineno - 1].strip()}")

                self.logger.log_test_error("")  # blank line for readability
                exc_prefix = exc_type.__name__ if exc_type is not None else "Error"
+23 −0
Original line number Diff line number Diff line
class TestScriptError(Exception):
    """
    The base error class to indicate that the test script failed.
    This (and its subclasses) get special treatment, i.e. only stack
    frames from `testScript` are printed and the error gets prefixed
    with `!!!` to make it easier to spot between other log-lines.

    This class is used for errors that aren't an actual test failure,
    but also not a bug in the driver, e.g. failing OCR.
    """


class RequestedAssertionFailed(TestScriptError):
    """
    Subclass of `TestScriptError` that gets special treatment.

    Exception raised when a requested assertion fails,
    e.g. `machine.succeed(...)` or `t.assertEqual(...)`.

    This is separate from the base error class, to have a dedicated class name
    that better represents this kind of failures.
    (better readability in test output)
    """
+19 −12
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ from pathlib import Path
from queue import Queue
from typing import Any

from test_driver.errors import RequestedAssertionFailed, TestScriptError
from test_driver.logger import AbstractLogger

from .qmp import QMPSession
@@ -128,7 +129,7 @@ def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
    )

    if ret.returncode != 0:
        raise Exception(
        raise TestScriptError(
            f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
        )

@@ -139,7 +140,7 @@ def _perform_ocr_on_screenshot(
    screenshot_path: str, model_ids: Iterable[int]
) -> list[str]:
    if shutil.which("tesseract") is None:
        raise Exception("OCR requested but enableOCR is false")
        raise TestScriptError("OCR requested but enableOCR is false")

    processed_image = _preprocess_screenshot(screenshot_path, negate=False)
    processed_negative = _preprocess_screenshot(screenshot_path, negate=True)
@@ -162,7 +163,7 @@ def _perform_ocr_on_screenshot(
                capture_output=True,
            )
            if ret.returncode != 0:
                raise Exception(f"OCR failed with exit code {ret.returncode}")
                raise TestScriptError(f"OCR failed with exit code {ret.returncode}")
            model_results.append(ret.stdout.decode("utf-8"))

    return model_results
@@ -179,7 +180,9 @@ def retry(fn: Callable, timeout: int = 900) -> None:
        time.sleep(1)

    if not fn(True):
        raise Exception(f"action timed out after {timeout} seconds")
        raise RequestedAssertionFailed(
            f"action timed out after {timeout} tries with one-second pause in-between"
        )


class StartCommand:
@@ -402,14 +405,14 @@ class Machine:
        def check_active(_last_try: bool) -> bool:
            state = self.get_unit_property(unit, "ActiveState", user)
            if state == "failed":
                raise Exception(f'unit "{unit}" reached state "{state}"')
                raise RequestedAssertionFailed(f'unit "{unit}" reached state "{state}"')

            if state == "inactive":
                status, jobs = self.systemctl("list-jobs --full 2>&1", user)
                if "No jobs" in jobs:
                    info = self.get_unit_info(unit, user)
                    if info["ActiveState"] == state:
                        raise Exception(
                        raise RequestedAssertionFailed(
                            f'unit "{unit}" is inactive and there are no pending jobs'
                        )

@@ -424,7 +427,7 @@ class Machine:
    def get_unit_info(self, unit: str, user: str | None = None) -> dict[str, str]:
        status, lines = self.systemctl(f'--no-pager show "{unit}"', user)
        if status != 0:
            raise Exception(
            raise RequestedAssertionFailed(
                f'retrieving systemctl info for unit "{unit}"'
                + ("" if user is None else f' under user "{user}"')
                + f" failed with exit code {status}"
@@ -454,7 +457,7 @@ class Machine:
            user,
        )
        if status != 0:
            raise Exception(
            raise RequestedAssertionFailed(
                f'retrieving systemctl property "{property}" for unit "{unit}"'
                + ("" if user is None else f' under user "{user}"')
                + f" failed with exit code {status}"
@@ -502,7 +505,7 @@ class Machine:
            info = self.get_unit_info(unit)
            state = info["ActiveState"]
            if state != require_state:
                raise Exception(
                raise RequestedAssertionFailed(
                    f"Expected unit '{unit}' to to be in state "
                    f"'{require_state}' but it is in state '{state}'"
                )
@@ -656,7 +659,9 @@ class Machine:
                (status, out) = self.execute(command, timeout=timeout)
                if status != 0:
                    self.log(f"output: {out}")
                    raise Exception(f"command `{command}` failed (exit code {status})")
                    raise RequestedAssertionFailed(
                        f"command `{command}` failed (exit code {status})"
                    )
                output += out
        return output

@@ -670,7 +675,9 @@ class Machine:
            with self.nested(f"must fail: {command}"):
                (status, out) = self.execute(command, timeout=timeout)
                if status == 0:
                    raise Exception(f"command `{command}` unexpectedly succeeded")
                    raise RequestedAssertionFailed(
                        f"command `{command}` unexpectedly succeeded"
                    )
                output += out
        return output

@@ -915,7 +922,7 @@ class Machine:
            ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True)
            os.unlink(tmp)
            if ret.returncode != 0:
                raise Exception("Cannot convert screenshot")
                raise TestScriptError("Cannot convert screenshot")

    def copy_from_host_via_shell(self, source: str, target: str) -> None:
        """Copy a file from the host into the guest by piping it over the