Unverified Commit 2e5b118a authored by Vladimír Čunát's avatar Vladimír Čunát Committed by GitHub
Browse files

staging-nixos -> staging-next (#501226)

parents 7d1a0856 d444bfb1
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -19,9 +19,12 @@
  qemu_test,
  setuptools,
  socat,
  systemd,
  tesseract4,
  util-linux,
  vde2,

  enableNspawn ? false,
  enableOCR ? false,
  extraPythonPackages ? (_: [ ]),
}:
@@ -51,8 +54,12 @@ buildPythonApplication {
    netpbm
    qemu_pkg
    socat
    util-linux
    vde2
  ]
  ++ lib.optionals enableNspawn [
    systemd
  ]
  ++ lib.optionals enableOCR [
    imagemagick_light
    tesseract4
+1 −1
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ def main() -> None:

    class_definitions = (node for node in module.body if isinstance(node, ast.ClassDef))

    machine_class = next(filter(lambda x: x.name == "Machine", class_definitions))
    machine_class = next(filter(lambda x: x.name == "BaseMachine", class_definitions))
    assert machine_class is not None

    function_definitions = [
+77 −18
Original line number Diff line number Diff line
import argparse
import os
import sys
import time
import warnings
from pathlib import Path

import ptpython.ipython
@@ -16,7 +18,7 @@ from test_driver.logger import (


class EnvDefault(argparse.Action):
    """An argpars Action that takes values from the specified
    """An argparse Action that takes values from the specified
    environment variable as the flags default value.
    """

@@ -55,9 +57,15 @@ def writeable_dir(arg: str) -> Path:
def main() -> None:
    arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
    arg_parser.add_argument(
        "-K",
        "--keep-vm-state",
        help="re-use a VM state coming from a previous run",
        help=argparse.SUPPRESS,
        dest="keep_machine_state",
        action="store_true",
    )
    arg_parser.add_argument(
        "-K",
        "--keep-machine-state",
        help="re-use a machine state coming from a previous run",
        action="store_true",
    )
    arg_parser.add_argument(
@@ -71,13 +79,37 @@ def main() -> None:
        help="Enable interactive debugging breakpoints for sandboxed runs",
    )
    arg_parser.add_argument(
        "--start-scripts",
        metavar="START-SCRIPT",
        "--vm-names",
        metavar="VM-NAME",
        action=EnvDefault,
        envvar="vmNames",
        nargs="*",
        help="names of participating virtual machines",
    )
    arg_parser.add_argument(
        "--vm-start-scripts",
        metavar="VM-START-SCRIPT",
        action=EnvDefault,
        envvar="startScripts",
        envvar="vmStartScripts",
        nargs="*",
        help="start scripts for participating virtual machines",
    )
    arg_parser.add_argument(
        "--container-names",
        metavar="CONTAINER-NAME",
        action=EnvDefault,
        envvar="containerNames",
        nargs="*",
        help="names of participating containers",
    )
    arg_parser.add_argument(
        "--container-start-scripts",
        metavar="CONTAINER-START-SCRIPT",
        action=EnvDefault,
        envvar="containerStartScripts",
        nargs="*",
        help="start scripts for participating containers",
    )
    arg_parser.add_argument(
        "--vlans",
        metavar="VLAN",
@@ -97,8 +129,8 @@ def main() -> None:
    arg_parser.add_argument(
        "-o",
        "--output_directory",
        help="""The path to the directory where outputs copied from the VM will be placed.
                By e.g. Machine.copy_from_vm or Machine.screenshot""",
        help="""The path to the directory where outputs copied from the machine will be placed.
                By e.g. NspawnMachine.copy_from_machine or QemuMachine.screenshot""",
        default=Path.cwd(),
        type=writeable_dir,
    )
@@ -122,6 +154,12 @@ def main() -> None:

    args = arg_parser.parse_args()

    if "--keep-vm-state" in sys.argv:
        warnings.warn(
            "The flag '--keep-vm-state' is deprecated. Use '--keep-machine-state' instead.",
            DeprecationWarning,
        )

    output_directory = args.output_directory.resolve()
    logger = CompositeLogger([TerminalLogger()])

@@ -131,21 +169,33 @@ def main() -> None:
    if args.junit_xml:
        logger.add_logger(JunitXMLLogger(output_directory / args.junit_xml))

    if not args.keep_vm_state:
        logger.info("Machine state will be reset. To keep it, pass --keep-vm-state")
    if not args.keep_machine_state:
        logger.info(
            "Machine state will be reset. To keep it, pass --keep-machine-state"
        )

    debugger: DebugAbstract = DebugNop()
    if args.debug_hook_attach is not None:
        debugger = Debug(logger, args.debug_hook_attach)

    assert len(args.vm_names) == len(args.vm_start_scripts), (
        f"the number of vm names and vm start scripts must be the same: {args.vm_names} vs. {args.vm_start_scripts}"
    )
    assert len(args.container_names) == len(args.container_start_scripts), (
        f"the number of container names and container start scripts must be the same: {args.container_names} vs. {args.container_start_scripts}"
    )

    with Driver(
        args.start_scripts,
        args.vlans,
        args.testscript.read_text(),
        output_directory,
        logger,
        args.keep_vm_state,
        args.global_timeout,
        vm_names=args.vm_names,
        vm_start_scripts=args.vm_start_scripts,
        container_names=args.container_names,
        container_start_scripts=args.container_start_scripts,
        vlans=args.vlans,
        tests=args.testscript.read_text(),
        out_dir=output_directory,
        logger=logger,
        keep_machine_state=args.keep_machine_state,
        global_timeout=args.global_timeout,
        debug=debugger,
    ) as driver:
        if offset := args.dump_vsocks:
@@ -170,7 +220,16 @@ def generate_driver_symbols() -> None:
    in user's test scripts. That list is then used by pyflakes to lint those
    scripts.
    """
    d = Driver([], [], "", Path(), CompositeLogger([]))
    d = Driver(
        vm_names=[],
        vm_start_scripts=[],
        container_names=[],
        container_start_scripts=[],
        vlans=[],
        tests="",
        out_dir=Path(),
        logger=CompositeLogger([]),
    )
    test_symbols = d.test_symbols()
    with open("driver-symbols", "w") as fp:
        fp.write(",".join(test_symbols.keys()))
+119 −30
Original line number Diff line number Diff line
import os
import re
import signal
import subprocess
import sys
import tempfile
import threading
@@ -16,7 +17,12 @@ from colorama import Style
from test_driver.debug import DebugAbstract, DebugNop
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.machine import (
    BaseMachine,
    NspawnMachine,
    QemuMachine,
    retry,
)
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan

@@ -63,7 +69,8 @@ class Driver:

    tests: str
    vlans: list[VLan]
    machines: list[Machine]
    machines_qemu: list[QemuMachine]
    machines_nspawn: list[NspawnMachine]
    polling_conditions: list[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
@@ -72,12 +79,15 @@ class Driver:

    def __init__(
        self,
        start_scripts: list[str],
        vm_names: list[str],
        vm_start_scripts: list[str],
        container_names: list[str],
        container_start_scripts: list[str],
        vlans: list[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        keep_machine_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
        debug: DebugAbstract = DebugNop(),
    ):
@@ -94,25 +104,95 @@ class Driver:
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]

        def cmd(scripts: list[str]) -> Iterator[NixStartScript]:
            for s in scripts:
                yield NixStartScript(s)

        self.polling_conditions = []

        self.machines = [
            Machine(
                start_command=cmd,
                keep_vm_state=keep_vm_state,
                name=cmd.machine_name,
        self.machines_qemu = [
            QemuMachine(
                name=name,
                start_command=vm_start_script,
                keep_machine_state=keep_machine_state,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for cmd in cmd(start_scripts)
            for name, vm_start_script in zip(vm_names, vm_start_scripts)
        ]

        if len(container_start_scripts) > 0:
            self._init_nspawn_environment()

        self.machines_nspawn = [
            NspawnMachine(
                name=name,
                start_command=container_start_script,
                tmp_dir=tmp_dir,
                logger=self.logger,
                keep_machine_state=keep_machine_state,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for name, container_start_script in zip(
                container_names,
                container_start_scripts,
            )
        ]

    def _init_nspawn_environment(self) -> None:
        assert os.geteuid() == 0, (
            f"systemd-nspawn requires root to work. You are {os.geteuid()}"
        )

        # set up prerequisites for systemd-nspawn containers.
        # these are not guaranteed to be set up in the Nix sandbox.
        # if running interactively as root, these will already be set up.

        # check if /run is writable by root
        if not os.access("/run", os.W_OK):
            Path("/run").mkdir(parents=True, exist_ok=True)
            subprocess.run(["mount", "-t", "tmpfs", "none", "/run"], check=True)
            Path("/run/netns").mkdir(parents=True, exist_ok=True)

        # check if /var/run is a symlink to /run
        if not (os.path.exists("/var/run") and os.path.samefile("/var/run", "/run")):
            Path("/var").mkdir(parents=True, exist_ok=True)
            subprocess.run(["ln", "-s", "/run", "/var/run"], check=True)

        # check if /sys/fs/cgroup is mounted as cgroup2
        with open("/proc/mounts", encoding="utf-8") as mounts:
            for line in mounts:
                parts = line.split()
                if len(parts) >= 3 and parts[1] == "/sys/fs/cgroup":
                    if parts[2] == "cgroup2":
                        break
            else:
                Path("/sys/fs/cgroup").mkdir(parents=True, exist_ok=True)
                subprocess.run(
                    ["mount", "-t", "cgroup2", "none", "/sys/fs/cgroup"], check=True
                )

        # systemd-nspawn requires that /etc/os-release exists
        # It supports SYSTEMD_NSPAWN_CHECK_OS_RELEASE=0, but that
        # would try to "fix" it by bind mounting, which is worse.
        if not os.path.isfile("/etc/os-release"):
            subprocess.run(["touch", "/etc/os-release"], check=True)

        # ensure /etc/machine-id exists and is non-empty
        if (
            not os.path.isfile("/etc/machine-id")
            or os.path.getsize("/etc/machine-id") == 0
        ):
            subprocess.run(
                ["systemd-machine-id-setup"], check=True
            )  # set up /etc/machine-id

    @property
    def machines(self) -> list[QemuMachine | NspawnMachine]:
        machines = self.machines_qemu + self.machines_nspawn
        # Sort the machines by name for consistency with `nodesAndContainers` in <nixos/lib/testing/network.nix>.
        machines.sort(key=lambda machine: machine.name)
        return machines

    def __enter__(self) -> "Driver":
        return self

@@ -148,7 +228,8 @@ class Driver:
        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            machines_qemu=self.machines_qemu,
            machines_nspawn=self.machines_nspawn,
            vlans=self.vlans,
            driver=self,
            log=self.logger,
@@ -161,7 +242,7 @@ class Driver:
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
            BaseMachine=BaseMachine,  # for typing
            t=AssertionTester(),
            debug=self.debug,
        )
@@ -186,14 +267,14 @@ class Driver:
    def dump_machine_ssh(self, offset: int) -> None:
        print("SSH backdoor enabled, the machines can be accessed like this:")
        print(
            f"{Style.BRIGHT}Note:{Style.RESET_ALL} this requires {Style.BRIGHT}systemd-ssh-proxy(1){Style.RESET_ALL} to be enabled (default on NixOS 25.05 and newer)."
            f"{Style.BRIGHT}Note:{Style.RESET_ALL} vsocks require {Style.BRIGHT}systemd-ssh-proxy(1){Style.RESET_ALL} to be enabled (default on NixOS 25.05 and newer)."
        )
        names = [machine.name for machine in self.machines]
        longest_name = len(max(names, key=len))
        for num, name in enumerate(names, start=offset + 1):
        longest_name = len(max((machine.name for machine in self.machines), key=len))
        for index, machine in enumerate(self.machines, start=offset + 1):
            name = machine.name
            spaces = " " * (longest_name - len(name) + 2)
            print(
                f"    {name}:{spaces}{Style.BRIGHT}ssh -o User=root vsock/{num}{Style.RESET_ALL}"
                f"    {name}:{spaces}{Style.BRIGHT}{machine.ssh_backdoor_command(index)}{Style.RESET_ALL}"
            )

    def test_script(self) -> None:
@@ -252,8 +333,16 @@ class Driver:
    def start_all(self) -> None:
        """Start all machines"""
        with self.logger.nested("start all VMs"):
            threads = []
            for machine in self.machines:
                machine.start()
                # Create a thread for each machine's start method
                t = threading.Thread(target=machine.start, name=f"start-{machine.name}")
                threads.append(t)
                t.start()

            # Wait for all startup threads to complete before proceeding
            for t in threads:
                t.join()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
@@ -279,19 +368,19 @@ class Driver:
        start_command: str,
        *,
        name: str | None = None,
        keep_vm_state: bool = False,
    ) -> Machine:
        keep_machine_state: bool = False,
    ) -> BaseMachine:
        """
        Create a `QemuMachine`. This currently only supports qemu "nodes", not containers.
        """
        tmp_dir = get_tmp_dir()

        cmd = NixStartScript(start_command)
        name = name or cmd.machine_name

        return Machine(
        return QemuMachine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            start_command=start_command,
            name=name,
            keep_vm_state=keep_vm_state,
            keep_machine_state=keep_machine_state,
            logger=self.logger,
        )

+764 −426

File changed.

Preview size limit exceeded, changes collapsed.

Loading