Commit d07866cd authored by Stefan Hertrampf's avatar Stefan Hertrampf
Browse files

nixos/test-driver: rm global logger

We remove the global rootlog in favor of instantiating the logger as
required in the __init__.py and pass it down as a parameter (of our
AbstractLogger type).
parent 303618c7
Loading
Loading
Loading
Loading
+13 −6
Original line number Diff line number Diff line
@@ -6,7 +6,12 @@ from pathlib import Path
import ptpython.repl

from test_driver.driver import Driver
from test_driver.logger import JunitXMLLogger, XMLLogger, rootlog
from test_driver.logger import (
    CompositeLogger,
    JunitXMLLogger,
    TerminalLogger,
    XMLLogger,
)


class EnvDefault(argparse.Action):
@@ -108,21 +113,23 @@ def main() -> None:
    args = arg_parser.parse_args()

    output_directory = args.output_directory.resolve()
    logger = CompositeLogger([TerminalLogger()])

    if "LOGFILE" in os.environ.keys():
        rootlog.add_logger(XMLLogger(os.environ["LOGFILE"]))
        logger.add_logger(XMLLogger(os.environ["LOGFILE"]))

    if args.junit_xml:
        rootlog.add_logger(JunitXMLLogger(output_directory / args.junit_xml))
        logger.add_logger(JunitXMLLogger(output_directory / args.junit_xml))

    if not args.keep_vm_state:
        rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
        logger.info("Machine state will be reset. To keep it, pass --keep-vm-state")

    with Driver(
        args.start_scripts,
        args.vlans,
        args.testscript.read_text(),
        output_directory,
        logger,
        args.keep_vm_state,
        args.global_timeout,
    ) as driver:
@@ -138,7 +145,7 @@ def main() -> None:
            tic = time.time()
            driver.run_tests()
            toc = time.time()
            rootlog.info(f"test script finished in {(toc-tic):.2f}s")
            logger.info(f"test script finished in {(toc-tic):.2f}s")


def generate_driver_symbols() -> None:
@@ -147,7 +154,7 @@ def generate_driver_symbols() -> None:
    in user's test scripts. That list is then used by pyflakes to lint those
    scripts.
    """
    d = Driver([], [], "", Path())
    d = Driver([], [], "", Path(), CompositeLogger([]))
    test_symbols = d.test_symbols()
    with open("driver-symbols", "w") as fp:
        fp.write(",".join(test_symbols.keys()))
+26 −18
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional

from colorama import Fore, Style

from test_driver.logger import rootlog
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan
@@ -49,6 +49,7 @@ class Driver:
    polling_conditions: List[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
    logger: AbstractLogger

    def __init__(
        self,
@@ -56,6 +57,7 @@ class Driver:
        vlans: List[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
    ):
@@ -63,12 +65,13 @@ class Driver:
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger

        tmp_dir = get_tmp_dir()

        with rootlog.nested("start all VLans"):
        with self.logger.nested("start all VLans"):
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]

        def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
            for s in scripts:
@@ -84,6 +87,7 @@ class Driver:
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for cmd in cmd(start_scripts)
        ]
@@ -92,19 +96,19 @@ class Driver:
        return self

    def __exit__(self, *_: Any) -> None:
        with rootlog.nested("cleanup"):
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name"""
        with rootlog.subtest(name):
        with self.logger.subtest(name):
            try:
                yield
                return True
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                self.logger.error(f'Test "{name}" failed with error: "{e}"')
                raise e

    def test_symbols(self) -> Dict[str, Any]:
@@ -118,7 +122,7 @@ class Driver:
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            log=self.logger,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
@@ -150,13 +154,13 @@ class Driver:

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
        with self.logger.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        rootlog.info(
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
@@ -168,13 +172,13 @@ class Driver:

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
        with self.logger.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with rootlog.nested("wait for all VMs to finish"):
        with self.logger.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
            self.race_timer.cancel()
@@ -182,7 +186,7 @@ class Driver:
    def terminate_test(self) -> None:
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with rootlog.nested("timeout reached; test terminating..."):
        with self.logger.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                machine.release()
            # As we cannot `sys.exit` from another thread
@@ -227,7 +231,7 @@ class Driver:
                    f"Unsupported arguments passed to create_machine: {args}"
                )

            rootlog.warning(
            self.logger.warning(
                Fore.YELLOW
                + Style.BRIGHT
                + "WARNING: Using create_machine with a single dictionary argument is deprecated and will be removed in NixOS 24.11"
@@ -246,13 +250,14 @@ class Driver:
            start_command=cmd,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
        )

    def serial_stdout_on(self) -> None:
        rootlog.print_serial_logs(True)
        self.logger.print_serial_logs(True)

    def serial_stdout_off(self) -> None:
        rootlog.print_serial_logs(False)
        self.logger.print_serial_logs(False)

    def check_polling_conditions(self) -> None:
        for condition in self.polling_conditions:
@@ -271,6 +276,7 @@ class Driver:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    driver.logger,
                    seconds_interval,
                    description,
                )
@@ -285,15 +291,17 @@ class Driver:
            def wait(self, timeout: int = 900) -> None:
                def condition(last: bool) -> bool:
                    if last:
                        rootlog.info(f"Last chance for {self.condition.description}")
                        driver.logger.info(
                            f"Last chance for {self.condition.description}"
                        )
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        rootlog.info(
                        driver.logger.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with rootlog.nested(f"waiting for {self.condition.description}"):
                with driver.logger.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
+0 −3
Original line number Diff line number Diff line
@@ -307,6 +307,3 @@ class XMLLogger(AbstractLogger):
        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")

        self.xml.endElement("nest")


rootlog: CompositeLogger = CompositeLogger([TerminalLogger()])
+9 −7
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ from pathlib import Path
from queue import Queue
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from test_driver.logger import rootlog
from test_driver.logger import AbstractLogger

from .qmp import QMPSession

@@ -270,6 +270,7 @@ class Machine:
        out_dir: Path,
        tmp_dir: Path,
        start_command: StartCommand,
        logger: AbstractLogger,
        name: str = "machine",
        keep_vm_state: bool = False,
        callbacks: Optional[List[Callable]] = None,
@@ -280,6 +281,7 @@ class Machine:
        self.name = name
        self.start_command = start_command
        self.callbacks = callbacks if callbacks is not None else []
        self.logger = logger

        # set up directories
        self.shared_dir = self.tmp_dir / "shared-xchg"
@@ -307,15 +309,15 @@ class Machine:
        return self.booted and self.connected

    def log(self, msg: str) -> None:
        rootlog.log(msg, {"machine": self.name})
        self.logger.log(msg, {"machine": self.name})

    def log_serial(self, msg: str) -> None:
        rootlog.log_serial(msg, self.name)
        self.logger.log_serial(msg, self.name)

    def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
        my_attrs = {"machine": self.name}
        my_attrs.update(attrs)
        return rootlog.nested(msg, my_attrs)
        return self.logger.nested(msg, my_attrs)

    def wait_for_monitor_prompt(self) -> str:
        assert self.monitor is not None
@@ -1113,8 +1115,8 @@ class Machine:

    def cleanup_statedir(self) -> None:
        shutil.rmtree(self.state_dir)
        rootlog.log(f"deleting VM state directory {self.state_dir}")
        rootlog.log("if you want to keep the VM state, pass --keep-vm-state")
        self.logger.log(f"deleting VM state directory {self.state_dir}")
        self.logger.log("if you want to keep the VM state, pass --keep-vm-state")

    def shutdown(self) -> None:
        """
@@ -1221,7 +1223,7 @@ class Machine:
    def release(self) -> None:
        if self.pid is None:
            return
        rootlog.info(f"kill machine (pid {self.pid})")
        self.logger.info(f"kill machine (pid {self.pid})")
        assert self.process
        assert self.shell
        assert self.monitor
+7 −4
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ import time
from math import isfinite
from typing import Callable, Optional

from .logger import rootlog
from test_driver.logger import AbstractLogger


class PollingConditionError(Exception):
@@ -13,6 +13,7 @@ class PollingCondition:
    condition: Callable[[], bool]
    seconds_interval: float
    description: Optional[str]
    logger: AbstractLogger

    last_called: float
    entry_count: int
@@ -20,11 +21,13 @@ class PollingCondition:
    def __init__(
        self,
        condition: Callable[[], Optional[bool]],
        logger: AbstractLogger,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ):
        self.condition = condition  # type: ignore
        self.seconds_interval = seconds_interval
        self.logger = logger

        if description is None:
            if condition.__doc__:
@@ -41,7 +44,7 @@ class PollingCondition:
        if (self.entered or not self.overdue) and not force:
            return True

        with self, rootlog.nested(self.nested_message):
        with self, self.logger.nested(self.nested_message):
            time_since_last = time.monotonic() - self.last_called
            last_message = (
                f"Time since last: {time_since_last:.2f}s"
@@ -49,13 +52,13 @@ class PollingCondition:
                else "(not called yet)"
            )

            rootlog.info(last_message)
            self.logger.info(last_message)
            try:
                res = self.condition()  # type: ignore
            except Exception:
                res = False
            res = res is None or res
            rootlog.info(self.status_message(res))
            self.logger.info(self.status_message(res))
            return res

    def maybe_raise(self) -> None:
Loading