Commit 23f1e637 authored by Kierán Meinhardt's avatar Kierán Meinhardt
Browse files

nixos/test-driver: add support for nspawn containers

parent 799cafcc
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -19,7 +19,9 @@
  qemu_test,
  setuptools,
  socat,
  systemd,
  tesseract4,
  util-linux,
  vde2,

  enableOCR ? false,
@@ -51,7 +53,9 @@ buildPythonApplication {
    netpbm
    qemu_pkg
    socat
    util-linux
    vde2
    systemd
  ]
  ++ lib.optionals enableOCR [
    imagemagick_light
+24 −0
Original line number Diff line number Diff line
@@ -86,6 +86,22 @@ def main() -> None:
        nargs="*",
        help="start scripts for participating virtual machines",
    )
    arg_parser.add_argument(
        "--container-names",
        metavar="CONTAINER-NAME",
        action=EnvDefault,
        envvar="containerNames",
        nargs="*",
        help="names of participating containers",
    )
    arg_parser.add_argument(
        "--container-start-scripts",
        metavar="CONTAINER-START-SCRIPT",
        action=EnvDefault,
        envvar="containerStartScripts",
        nargs="*",
        help="start scripts for participating containers",
    )
    arg_parser.add_argument(
        "--vlans",
        metavar="VLAN",
@@ -150,10 +166,16 @@ def main() -> None:
        assert len(args.vm_names) == len(args.vm_start_scripts), (
            f"the number of vm names and vm start scripts must be the same: {args.vm_names} vs. {args.vm_start_scripts}"
        )
    if args.container_names is not None and args.container_start_scripts is not None:
        assert len(args.container_names) == len(args.container_start_scripts), (
            f"the number of container names and container start scripts must be the same: {args.container_names} vs. {args.container_start_scripts}"
        )

    with Driver(
        vm_names=args.vm_names,
        vm_start_scripts=args.vm_start_scripts or [],
        container_names=args.container_names,
        container_start_scripts=args.container_start_scripts or [],
        vlans=args.vlans,
        tests=args.testscript.read_text(),
        out_dir=output_directory,
@@ -187,6 +209,8 @@ def generate_driver_symbols() -> None:
    d = Driver(
        vm_names=[],
        vm_start_scripts=[],
        container_names=[],
        container_start_scripts=[],
        vlans=[],
        tests="",
        out_dir=Path(),
+84 −6
Original line number Diff line number Diff line
import os
import re
import signal
import subprocess
import sys
import tempfile
import threading
@@ -16,7 +17,12 @@ from colorama import Style
from test_driver.debug import DebugAbstract, DebugNop
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger
from test_driver.machine import BaseMachine, QemuMachine, retry
from test_driver.machine import (
    BaseMachine,
    NspawnMachine,
    QemuMachine,
    retry,
)
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan

@@ -64,6 +70,7 @@ class Driver:
    tests: str
    vlans: list[VLan]
    vm_machines: list[QemuMachine]
    container_machines: list[NspawnMachine]
    polling_conditions: list[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
@@ -74,6 +81,8 @@ class Driver:
        self,
        vm_names: list[str] | None,
        vm_start_scripts: list[str],
        container_names: list[str] | None,
        container_start_scripts: list[str],
        vlans: list[int],
        tests: str,
        out_dir: Path,
@@ -112,10 +121,75 @@ class Driver:
            )
        ]

        if len(container_start_scripts) > 0:
            self._init_nspawn_environment()

        self.container_machines = [
            NspawnMachine(
                name=name,
                start_command=container_start_script,
                tmp_dir=tmp_dir,
                logger=self.logger,
                keep_vm_state=keep_vm_state,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for name, container_start_script in zip(
                container_names or (len(container_start_scripts) * [None]),
                container_start_scripts,
            )
        ]

    def _init_nspawn_environment(self) -> None:
        assert os.geteuid() == 0, (
            f"systemd-nspawn requires root to work. You are {os.geteuid()}"
        )

        # set up prerequisites for systemd-nspawn containers.
        # these are not guaranteed to be set up in the Nix sandbox.
        # if running interactively as root, these will already be set up.

        # check if /run is writable by root
        if not os.access("/run", os.W_OK):
            Path("/run").mkdir(parents=True, exist_ok=True)
            subprocess.run(["mount", "-t", "tmpfs", "none", "/run"], check=True)
            Path("/run/netns").mkdir(parents=True, exist_ok=True)

        # check if /var/run is a symlink to /run
        if not (os.path.exists("/var/run") and os.path.samefile("/var/run", "/run")):
            Path("/var").mkdir(parents=True, exist_ok=True)
            subprocess.run(["ln", "-s", "/run", "/var/run"], check=True)

        # check if /sys/fs/cgroup is mounted as cgroup2
        with open("/proc/mounts", encoding="utf-8") as mounts:
            for line in mounts:
                parts = line.split()
                if len(parts) >= 3 and parts[1] == "/sys/fs/cgroup":
                    if parts[2] == "cgroup2":
                        break
            else:
                Path("/sys/fs/cgroup").mkdir(parents=True, exist_ok=True)
                subprocess.run(
                    ["mount", "-t", "cgroup2", "none", "/sys/fs/cgroup"], check=True
                )

        # ensure /etc/os-release exists
        if not os.path.isfile("/etc/os-release"):
            subprocess.run(["touch", "/etc/os-release"], check=True)

        # ensure /etc/machine-id exists and is non-empty
        if (
            not os.path.isfile("/etc/machine-id")
            or os.path.getsize("/etc/machine-id") == 0
        ):
            subprocess.run(
                ["systemd-machine-id-setup"], check=True
            )  # set up /etc/machine-id

    @property
    def machines(self) -> list[QemuMachine]:
        machines = self.vm_machines
        # Sort the machines by name for consistency with `nodes` in <nixos/lib/testing/network.nix>.
    def machines(self) -> list[QemuMachine | NspawnMachine]:
        machines = self.vm_machines + self.container_machines
        # Sort the machines by name for consistency with `nodesAndContainers` in <nixos/lib/testing/network.nix>.
        machines.sort(key=lambda machine: machine.name)
        return machines

@@ -155,6 +229,7 @@ class Driver:
            start_all=self.start_all,
            test_script=self.test_script,
            vm_machines=self.vm_machines,
            container_machines=self.container_machines,
            vlans=self.vlans,
            driver=self,
            log=self.logger,
@@ -286,13 +361,16 @@ class Driver:
        *,
        name: str | None = None,
        keep_vm_state: bool = False,
    ) -> QemuMachine:
    ) -> BaseMachine:
        """
        Create a `QemuMachine`. This currently only supports qemu "nodes", not containers.
        """
        tmp_dir = get_tmp_dir()

        return QemuMachine(
            start_command=start_command,
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=start_command,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
+177 −18
Original line number Diff line number Diff line
import base64
import io
import json
import os
import platform
import queue
@@ -16,6 +17,7 @@ import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import _GeneratorContextManager, contextmanager, nullcontext
from functools import cached_property
from pathlib import Path
from queue import Queue
from typing import Any
@@ -218,7 +220,6 @@ class BaseMachine(ABC):
    name: str
    callbacks: list[Callable]
    tmp_dir: Path

    keep_vm_state: bool

    def __repr__(self) -> str:
@@ -239,7 +240,7 @@ class BaseMachine(ABC):
        self.callbacks = callbacks if callbacks is not None else []
        self.tmp_dir = tmp_dir

        # Note: "vm" is a bit of a misnomer here.
        # Note: "vm" is a bit of a misnomer here as we support both QEMU vms and nspawn containers.
        # Consider renaming to something more generic ("machine"?)
        self.keep_vm_state = keep_vm_state

@@ -269,26 +270,13 @@ class BaseMachine(ABC):
        return self.logger.nested(msg, my_attrs)

    @abstractmethod
    def is_up(self) -> bool:
        """
        Check whether the machine is running.
        """
        pass
    def is_up(self) -> bool: ...

    @abstractmethod
    def start(self) -> None:
        """
        Start the machine.
        """
        pass
    def start(self) -> None: ...

    @abstractmethod
    def wait_for_shutdown(self) -> None:
        """
        Wait for the machine to power off. This does *not* initiate a shutdown;
        that's usually done via `shutdown()`.
        """
        pass
    def wait_for_shutdown(self) -> None: ...

    def systemctl(self, q: str, user: str | None = None) -> tuple[int, str]:
        """
@@ -1384,3 +1372,174 @@ class QemuMachine(BaseMachine):
        )
        self.connected = False
        self.connect()


class NspawnMachine(BaseMachine):
    """
    A handle to a systemd-nspawn container machine with this name, that also
    knows how to manage the machine lifecycle with the help of a start script / command.
    """

    start_command: str
    tmp_dir: Path
    process: subprocess.Popen | None
    pid: int | None

    @staticmethod
    def machine_name_from_start_command(start_command: str) -> str:
        match = re.search("run-(.+)-nspawn", os.path.basename(start_command))
        assert match is not None, f"Could not extract node name from {start_command}"
        return match.group(1)

    def __init__(
        self,
        out_dir: Path,
        name: str | None,
        start_command: str,
        tmp_dir: Path,
        logger: AbstractLogger,
        callbacks: list[Callable] | None = None,
        keep_vm_state: bool = False,
    ):
        # TODO: don't compute `name` from `start_command` path, instead thread it down explicitly.
        # See analogous TODO in `QemuStartCommand::machine_name`.
        super().__init__(
            out_dir=out_dir,
            name=name or self.machine_name_from_start_command(start_command),
            logger=logger,
            callbacks=callbacks,
            tmp_dir=tmp_dir,
            keep_vm_state=keep_vm_state,
        )

        self.start_command = start_command
        self.process = None
        self.pid = None

    def ssh_backdoor_command(self, index: int) -> str:
        # get IP from `ip addr` inside the container:
        ip_status, ip_output = self._execute("ip -j addr show")
        assert ip_status == 0, "Failed to get IP addresses from container"
        ip_output_data = json.loads(ip_output)
        ip_addresses = [
            addr_info.get("local")
            for iface in ip_output_data
            if iface.get("ifname") != "lo"
            for addr_info in iface.get("addr_info", [])
            if addr_info.get("family") == "inet"
        ]

        return "\n".join(f"ssh -o User=root {addr}" for addr in ip_addresses)

    def release(self) -> None:
        if self.pid is None:
            return

        self.logger.info(f"kill NspawnMachine (pid {self.pid})")
        assert self.process is not None
        self.process.terminate()
        self.process = None

    def is_up(self) -> bool:
        return self.process is not None

    @cached_property
    def get_systemd_process(self) -> int:
        assert self.process is not None, "Machine not started"
        assert self.process.stdout is not None, "Machine has no stdout"

        systemd_nspawn_pid = None
        for line_bytes in self.process.stdout:
            line = line_bytes.decode()
            print(line, end="")

            systemd_nspawn_pid_prefix = "systemd-nspawn's PID is "
            if line.startswith(systemd_nspawn_pid_prefix):
                systemd_nspawn_pid = int(line.removeprefix(systemd_nspawn_pid_prefix))

            if (
                line.startswith("systemd[1]: Startup finished in")
                or "Welcome to NixOS" in line
            ):
                assert systemd_nspawn_pid is not None, "Must find systemd-nspawn PID"
                break
        else:
            raise RuntimeError(f"Failed to start container {self.name}")

        childs = (
            Path(f"/proc/{systemd_nspawn_pid}/task/{systemd_nspawn_pid}/children")
            .read_text()
            .split()
        )
        assert len(childs) == 1, (
            f"Expected exactly one child process for systemd-nspawn, got {childs}"
        )
        (child,) = childs

        try:
            return int(child)
        except ValueError as e:
            raise RuntimeError(f"Failed to parse child process id {child}") from e

    def _execute(
        self,
        command: str,
        check_return: bool = True,
        check_output: bool = True,
        timeout: int | None = 900,
    ) -> tuple[int, str]:
        self.start()

        container_pid = self.get_systemd_process
        nsenter = shutil.which("nsenter")
        assert nsenter is not None

        # Pull in /etc/profile, and some shell sanity.
        command = f"set -eo pipefail; source /etc/profile; set -xu; {command}"
        cp = subprocess.run(
            [
                nsenter,
                "--target",
                str(container_pid),
                "--mount",
                "--uts",
                "--ipc",
                "--net",
                "--pid",
                "--cgroup",
                "/bin/sh",
                "-c",
                command,
            ],
            env={},
            timeout=timeout,
            stdout=subprocess.PIPE,
            text=True,
        )
        return (cp.returncode, cp.stdout)

    def start(self) -> None:
        if self.process is not None:
            return

        self.process = subprocess.Popen(
            [self.start_command],
            env={
                "RUN_NSPAWN_ROOT_DIR": str(self.state_dir),
                "RUN_NSPAWN_SHARED_DIR": str(self.shared_dir),
            },
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )

        self.pid = self.process.pid

        self.log(f"system-nspawn running (pid {self.pid})")

    def wait_for_shutdown(self) -> None:
        if self.process is None:
            return

        with self.nested("waiting for the container to power off"):
            self.process.wait()
            self.process = None
+1 −1
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@
from test_driver.debug import DebugAbstract
from test_driver.driver import Driver
from test_driver.vlan import VLan
from test_driver.machine import BaseMachine, QemuMachine
from test_driver.machine import BaseMachine, NspawnMachine, QemuMachine
from test_driver.logger import AbstractLogger
from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union
from typing_extensions import Protocol
Loading