Unverified Commit cf1909dd authored by nikstur's avatar nikstur Committed by GitHub
Browse files

Merge pull request #301772 from hertrste/junit-xml-prod

nixos/test-driver: Add Junit XML report creation
parents df6d4213 d07866cd
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ python3Packages.buildPythonApplication {
    coreutils
    netpbm
    python3Packages.colorama
    python3Packages.junit-xml
    python3Packages.ptpython
    qemu_pkg
    socat
+4 −0
Original line number Diff line number Diff line
@@ -31,6 +31,10 @@ ignore_missing_imports = true
module = "ptpython.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "junit_xml.*"
ignore_missing_imports = true

[tool.black]
line-length = 88
target-version = ['py39']
+25 −5
Original line number Diff line number Diff line
@@ -6,7 +6,12 @@ from pathlib import Path
import ptpython.repl

from test_driver.driver import Driver
from test_driver.logger import rootlog
from test_driver.logger import (
    CompositeLogger,
    JunitXMLLogger,
    TerminalLogger,
    XMLLogger,
)


class EnvDefault(argparse.Action):
@@ -92,6 +97,11 @@ def main() -> None:
        default=Path.cwd(),
        type=writeable_dir,
    )
    arg_parser.add_argument(
        "--junit-xml",
        help="Enable JunitXML report generation to the given path",
        type=Path,
    )
    arg_parser.add_argument(
        "testscript",
        action=EnvDefault,
@@ -102,14 +112,24 @@ def main() -> None:

    args = arg_parser.parse_args()

    output_directory = args.output_directory.resolve()
    logger = CompositeLogger([TerminalLogger()])

    if "LOGFILE" in os.environ.keys():
        logger.add_logger(XMLLogger(os.environ["LOGFILE"]))

    if args.junit_xml:
        logger.add_logger(JunitXMLLogger(output_directory / args.junit_xml))

    if not args.keep_vm_state:
        rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
        logger.info("Machine state will be reset. To keep it, pass --keep-vm-state")

    with Driver(
        args.start_scripts,
        args.vlans,
        args.testscript.read_text(),
        args.output_directory.resolve(),
        output_directory,
        logger,
        args.keep_vm_state,
        args.global_timeout,
    ) as driver:
@@ -125,7 +145,7 @@ def main() -> None:
            tic = time.time()
            driver.run_tests()
            toc = time.time()
            rootlog.info(f"test script finished in {(toc-tic):.2f}s")
            logger.info(f"test script finished in {(toc-tic):.2f}s")


def generate_driver_symbols() -> None:
@@ -134,7 +154,7 @@ def generate_driver_symbols() -> None:
    in user's test scripts. That list is then used by pyflakes to lint those
    scripts.
    """
    d = Driver([], [], "", Path())
    d = Driver([], [], "", Path(), CompositeLogger([]))
    test_symbols = d.test_symbols()
    with open("driver-symbols", "w") as fp:
        fp.write(",".join(test_symbols.keys()))
+26 −18
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional

from colorama import Fore, Style

from test_driver.logger import rootlog
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan
@@ -49,6 +49,7 @@ class Driver:
    polling_conditions: List[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
    logger: AbstractLogger

    def __init__(
        self,
@@ -56,6 +57,7 @@ class Driver:
        vlans: List[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
    ):
@@ -63,12 +65,13 @@ class Driver:
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger

        tmp_dir = get_tmp_dir()

        with rootlog.nested("start all VLans"):
        with self.logger.nested("start all VLans"):
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]

        def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
            for s in scripts:
@@ -84,6 +87,7 @@ class Driver:
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for cmd in cmd(start_scripts)
        ]
@@ -92,19 +96,19 @@ class Driver:
        return self

    def __exit__(self, *_: Any) -> None:
        with rootlog.nested("cleanup"):
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name"""
        with rootlog.nested("subtest: " + name):
        with self.logger.subtest(name):
            try:
                yield
                return True
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                self.logger.error(f'Test "{name}" failed with error: "{e}"')
                raise e

    def test_symbols(self) -> Dict[str, Any]:
@@ -118,7 +122,7 @@ class Driver:
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            log=self.logger,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
@@ -150,13 +154,13 @@ class Driver:

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
        with self.logger.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        rootlog.info(
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
@@ -168,13 +172,13 @@ class Driver:

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
        with self.logger.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with rootlog.nested("wait for all VMs to finish"):
        with self.logger.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
            self.race_timer.cancel()
@@ -182,7 +186,7 @@ class Driver:
    def terminate_test(self) -> None:
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with rootlog.nested("timeout reached; test terminating..."):
        with self.logger.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                machine.release()
            # As we cannot `sys.exit` from another thread
@@ -227,7 +231,7 @@ class Driver:
                    f"Unsupported arguments passed to create_machine: {args}"
                )

            rootlog.warning(
            self.logger.warning(
                Fore.YELLOW
                + Style.BRIGHT
                + "WARNING: Using create_machine with a single dictionary argument is deprecated and will be removed in NixOS 24.11"
@@ -246,13 +250,14 @@ class Driver:
            start_command=cmd,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
        )

    def serial_stdout_on(self) -> None:
        rootlog._print_serial_logs = True
        self.logger.print_serial_logs(True)

    def serial_stdout_off(self) -> None:
        rootlog._print_serial_logs = False
        self.logger.print_serial_logs(False)

    def check_polling_conditions(self) -> None:
        for condition in self.polling_conditions:
@@ -271,6 +276,7 @@ class Driver:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    driver.logger,
                    seconds_interval,
                    description,
                )
@@ -285,15 +291,17 @@ class Driver:
            def wait(self, timeout: int = 900) -> None:
                def condition(last: bool) -> bool:
                    if last:
                        rootlog.info(f"Last chance for {self.condition.description}")
                        driver.logger.info(
                            f"Last chance for {self.condition.description}"
                        )
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        rootlog.info(
                        driver.logger.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with rootlog.nested(f"waiting for {self.condition.description}"):
                with driver.logger.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
+227 −23
Original line number Diff line number Diff line
import atexit
import codecs
import os
import sys
import time
import unicodedata
from contextlib import contextmanager
from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Dict, Iterator
from typing import Any, Dict, Iterator, List
from xml.sax.saxutils import XMLGenerator
from xml.sax.xmlreader import AttributesImpl

from colorama import Fore, Style
from junit_xml import TestCase, TestSuite


class Logger:
class AbstractLogger(ABC):
    @abstractmethod
    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        pass

    @abstractmethod
    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        pass

    @abstractmethod
    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        pass

    @abstractmethod
    def info(self, *args, **kwargs) -> None:  # type: ignore
        pass

    @abstractmethod
    def warning(self, *args, **kwargs) -> None:  # type: ignore
        pass

    @abstractmethod
    def error(self, *args, **kwargs) -> None:  # type: ignore
        pass

    @abstractmethod
    def log_serial(self, message: str, machine: str) -> None:
        pass

    @abstractmethod
    def print_serial_logs(self, enable: bool) -> None:
        pass


class JunitXMLLogger(AbstractLogger):

    class TestCaseState:
        def __init__(self) -> None:
        self.logfile = os.environ.get("LOGFILE", "/dev/null")
        self.logfile_handle = codecs.open(self.logfile, "wb")
        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
        self.queue: "Queue[Dict[str, str]]" = Queue()
            self.stdout = ""
            self.stderr = ""
            self.failure = False

        self.xml.startDocument()
        self.xml.startElement("logfile", attrs=AttributesImpl({}))
    def __init__(self, outfile: Path) -> None:
        self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
            "main": self.TestCaseState()
        }
        self.currentSubtest = "main"
        self.outfile: Path = outfile
        self._print_serial_logs = True
        atexit.register(self.close)

    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        self.tests[self.currentSubtest].stdout += message + os.linesep

    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        old_test = self.currentSubtest
        self.tests.setdefault(name, self.TestCaseState())
        self.currentSubtest = name

        yield

        self.currentSubtest = old_test

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        self.log(message)
        yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.tests[self.currentSubtest].stdout += args[0] + os.linesep

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.tests[self.currentSubtest].stdout += args[0] + os.linesep

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.tests[self.currentSubtest].stderr += args[0] + os.linesep
        self.tests[self.currentSubtest].failure = True

    def log_serial(self, message: str, machine: str) -> None:
        if not self._print_serial_logs:
            return

        self.log(f"{machine} # {message}")

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def close(self) -> None:
        with open(self.outfile, "w") as f:
            test_cases = []
            for name, test_case_state in self.tests.items():
                tc = TestCase(
                    name,
                    stdout=test_case_state.stdout,
                    stderr=test_case_state.stderr,
                )
                if test_case_state.failure:
                    tc.add_failure_info("test case failed")

                test_cases.append(tc)
            ts = TestSuite("NixOS integration test", test_cases)
            f.write(TestSuite.to_xml_string([ts]))


class CompositeLogger(AbstractLogger):
    def __init__(self, logger_list: List[AbstractLogger]) -> None:
        self.logger_list = logger_list

    def add_logger(self, logger: AbstractLogger) -> None:
        self.logger_list.append(logger)

    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        for logger in self.logger_list:
            logger.log(message, attributes)

    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        with ExitStack() as stack:
            for logger in self.logger_list:
                stack.enter_context(logger.subtest(name, attributes))
            yield

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        with ExitStack() as stack:
            for logger in self.logger_list:
                stack.enter_context(logger.nested(message, attributes))
            yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        for logger in self.logger_list:
            logger.info(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        for logger in self.logger_list:
            logger.warning(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        for logger in self.logger_list:
            logger.error(*args, **kwargs)
        sys.exit(1)

    def print_serial_logs(self, enable: bool) -> None:
        for logger in self.logger_list:
            logger.print_serial_logs(enable)

    def log_serial(self, message: str, machine: str) -> None:
        for logger in self.logger_list:
            logger.log_serial(message, machine)


class TerminalLogger(AbstractLogger):
    def __init__(self) -> None:
        self._print_serial_logs = True

    def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    @staticmethod
    def _eprint(*args: object, **kwargs: Any) -> None:
        print(*args, file=sys.stderr, **kwargs)

    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        self._eprint(self.maybe_prefix(message, attributes))

    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        self._eprint(
            self.maybe_prefix(
                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
            )
        )

        tic = time.time()
        yield
        toc = time.time()
        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        if not self._print_serial_logs:
            return

        self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)


class XMLLogger(AbstractLogger):
    def __init__(self, outfile: str) -> None:
        self.logfile_handle = codecs.open(outfile, "wb")
        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
        self.queue: "Queue[Dict[str, str]]" = Queue()

        self._print_serial_logs = True

        self.xml.startDocument()
        self.xml.startElement("logfile", attrs=AttributesImpl({}))

    def close(self) -> None:
        self.xml.endElement("logfile")
        self.xml.endDocument()
@@ -54,17 +260,19 @@ class Logger:

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)
        sys.exit(1)

    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        self._eprint(self.maybe_prefix(message, attributes))
        self.drain_log_queue()
        self.log_line(message, attributes)

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        if not self._print_serial_logs:
            return

        self.enqueue({"msg": message, "machine": machine, "type": "serial"})
        if self._print_serial_logs:
            self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)

    def enqueue(self, item: Dict[str, str]) -> None:
        self.queue.put(item)
@@ -80,13 +288,12 @@ class Logger:
            pass

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        self._eprint(
            self.maybe_prefix(
                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
            )
        )
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        self.xml.startElement("nest", attrs=AttributesImpl({}))
        self.xml.startElement("head", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
@@ -100,6 +307,3 @@ class Logger:
        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")

        self.xml.endElement("nest")


rootlog = Logger()
Loading