Commit f1bcb617 authored by Jacek Galowicz's avatar Jacek Galowicz
Browse files

nixos-test-driver: use configuration file instead of scattered env vars

parent b998b489
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@
  ipython,
  junit-xml,
  ptpython,
  pydantic,
  python,
  remote-pdb,
  ruff,
@@ -46,6 +47,7 @@ buildPythonApplication {
    ipython
    junit-xml
    ptpython
    pydantic
    remote-pdb
  ]
  ++ extraPythonPackages python.pkgs;
+18 −83
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from pathlib import Path
import ptpython.ipython

from test_driver.debug import Debug, DebugAbstract, DebugNop
from test_driver.driver import Driver
from test_driver.driver import Driver, DriverConfiguration, load_driver_configuration
from test_driver.logger import (
    CompositeLogger,
    JunitXMLLogger,
@@ -57,6 +57,13 @@ def writeable_dir(arg: str) -> Path:

def main() -> None:
    arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
    arg_parser.add_argument(
        "-c",
        "--config",
        help="the test driver configuration file",
        type=Path,
        required=True,
    )
    arg_parser.add_argument(
        "--keep-vm-state",
        help=argparse.SUPPRESS,
@@ -79,54 +86,6 @@ def main() -> None:
        "--debug-hook-attach",
        help="Enable interactive debugging breakpoints for sandboxed runs",
    )
    arg_parser.add_argument(
        "--vm-names",
        metavar="VM-NAME",
        action=EnvDefault,
        envvar="vmNames",
        nargs="*",
        help="names of participating virtual machines",
    )
    arg_parser.add_argument(
        "--vm-start-scripts",
        metavar="VM-START-SCRIPT",
        action=EnvDefault,
        envvar="vmStartScripts",
        nargs="*",
        help="start scripts for participating virtual machines",
    )
    arg_parser.add_argument(
        "--container-names",
        metavar="CONTAINER-NAME",
        action=EnvDefault,
        envvar="containerNames",
        nargs="*",
        help="names of participating containers",
    )
    arg_parser.add_argument(
        "--container-start-scripts",
        metavar="CONTAINER-START-SCRIPT",
        action=EnvDefault,
        envvar="containerStartScripts",
        nargs="*",
        help="start scripts for participating containers",
    )
    arg_parser.add_argument(
        "--vlans",
        metavar="VLAN",
        action=EnvDefault,
        envvar="vlans",
        nargs="*",
        help="vlans to span by the driver",
    )
    arg_parser.add_argument(
        "--global-timeout",
        type=int,
        metavar="GLOBAL_TIMEOUT",
        action=EnvDefault,
        envvar="globalTimeout",
        help="Timeout in seconds for the whole test",
    )
    arg_parser.add_argument(
        "-o",
        "--output_directory",
@@ -140,18 +99,6 @@ def main() -> None:
        help="Enable JunitXML report generation to the given path",
        type=Path,
    )
    arg_parser.add_argument(
        "testscript",
        action=EnvDefault,
        envvar="testScript",
        help="the test script to run",
        type=Path,
    )
    arg_parser.add_argument(
        "--enable-ssh-backdoor",
        help="indicates that the interactive SSH backdoor is active and dumps information about it on start",
        action="store_true",
    )
    log_level_map = {level.name.lower(): level for level in LogLevel}
    arg_parser.add_argument(
        "--log-level",
@@ -191,28 +138,14 @@ def main() -> None:
    if args.debug_hook_attach is not None:
        debugger = Debug(logger, args.debug_hook_attach)

    assert len(args.vm_names) == len(args.vm_start_scripts), (
        f"the number of vm names and vm start scripts must be the same: {args.vm_names} vs. {args.vm_start_scripts}"
    )
    assert len(args.container_names) == len(args.container_start_scripts), (
        f"the number of container names and container start scripts must be the same: {args.container_names} vs. {args.container_start_scripts}"
    )

    with Driver(
        vm_names=args.vm_names,
        vm_start_scripts=args.vm_start_scripts,
        container_names=args.container_names,
        container_start_scripts=args.container_start_scripts,
        vlans=args.vlans,
        tests=args.testscript.read_text(),
        config=load_driver_configuration(args.config),
        out_dir=output_directory,
        logger=logger,
        keep_machine_state=args.keep_machine_state,
        global_timeout=args.global_timeout,
        debug=debugger,
        enable_ssh_backdoor=args.enable_ssh_backdoor,
    ) as driver:
        if args.enable_ssh_backdoor:
        if driver.config.enable_ssh_backdoor:
            driver.dump_machine_ssh()
        if args.interactive:
            history_dir = os.getcwd()
@@ -235,12 +168,14 @@ def generate_driver_symbols() -> None:
    scripts.
    """
    d = Driver(
        vm_names=[],
        vm_start_scripts=[],
        container_names=[],
        container_start_scripts=[],
        config=DriverConfiguration(
            vms=dict(),
            containers=dict(),
            vlans=[],
        tests="",
            global_timeout=0,
            enable_ssh_backdoor=False,
            test_script=Path("testScriptWithTypes"),
        ),
        out_dir=Path(),
        logger=CompositeLogger([]),
    )
+42 −36
Original line number Diff line number Diff line
import json
import os
import re
import signal
@@ -14,6 +15,7 @@ from typing import Any
from unittest import TestCase

from colorama import Style
from pydantic import BaseModel

from test_driver.debug import DebugAbstract, DebugNop
from test_driver.errors import MachineError, RequestedAssertionFailed
@@ -28,6 +30,26 @@ from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan


class NodeConfiguration(BaseModel):
    name: str
    start_script: Path


class DriverConfiguration(BaseModel):
    vms: dict[str, NodeConfiguration]
    containers: dict[str, NodeConfiguration]
    vlans: list[int]
    global_timeout: int
    enable_ssh_backdoor: bool
    test_script: Path


def load_driver_configuration(file_path: str) -> DriverConfiguration:
    with open(file_path) as f:
        data = json.load(f)
    return DriverConfiguration.model_validate(data)


class AssertionTester(TestCase):
    """
    Subclass of `unittest.TestCase` which is used in the
@@ -113,71 +135,55 @@ class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    config: DriverConfiguration
    tests: str
    vlans: list[VLan] = []
    machines_qemu: list[QemuMachine] = []
    machines_nspawn: list[NspawnMachine] = []
    polling_conditions: list[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
    vm_start_scripts: dict[str, str]
    container_start_scripts: dict[str, str]
    vlan_ids: list[int]
    keep_machine_state: bool
    logger: AbstractLogger
    debug: DebugAbstract
    vhost_vsock: VHostDeviceVsock | None = None
    enable_ssh_backdoor: bool

    def __init__(
        self,
        vm_names: list[str],
        vm_start_scripts: list[str],
        container_names: list[str],
        container_start_scripts: list[str],
        vlans: list[int],
        tests: str,
        config: DriverConfiguration,
        out_dir: Path,
        logger: AbstractLogger,
        keep_machine_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
        debug: DebugAbstract = DebugNop(),
        enable_ssh_backdoor: bool = False,
    ):
        self.tests = tests
        self.config = config
        self.tests = config.test_script.read_text()
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        self.logger = logger
        self.debug = debug
        self.vlan_ids = list(set(vlans))
        self.polling_conditions = []
        self.keep_machine_state = keep_machine_state
        self.global_timeout = global_timeout
        self.vm_start_scripts = dict(zip(vm_names, vm_start_scripts))
        self.container_start_scripts = dict(
            zip(container_names, container_start_scripts)
        )
        self.enable_ssh_backdoor = enable_ssh_backdoor

    def __enter__(self) -> "Driver":
        self.race_timer = threading.Timer(self.global_timeout, self.terminate_test)
        self.race_timer = threading.Timer(
            self.config.global_timeout, self.terminate_test
        )
        tmp_dir = get_tmp_dir()

        with self.logger.nested("start all VLans"):
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in self.vlan_ids]
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in self.config.vlans]

        self.polling_conditions = []

        if self.enable_ssh_backdoor and self.vm_start_scripts:
        if self.config.enable_ssh_backdoor and self.config.vms:
            with self.logger.nested("start vhost-device-vsock"):
                self.vhost_vsock = VHostDeviceVsock(
                    tmp_dir, list(self.vm_start_scripts.keys())
                    tmp_dir, list(self.config.vms.keys())
                )

        self.machines_qemu = [
            QemuMachine(
                name=name,
                start_command=vm_start_script,
                start_command=vm_config.start_script.as_posix(),
                keep_machine_state=self.keep_machine_state,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
@@ -194,23 +200,23 @@ class Driver:
                    else None
                ),
            )
            for name, vm_start_script in self.vm_start_scripts.items()
            for name, vm_config in self.config.vms.items()
        ]

        if len(self.container_start_scripts) > 0 and in_nix_sandbox():
        if self.config.containers and in_nix_sandbox():
            self._init_nspawn_environment()

        self.machines_nspawn = [
            NspawnMachine(
                name=name,
                start_command=container_start_script,
                start_command=container_config.start_script.as_posix(),
                tmp_dir=tmp_dir,
                logger=self.logger,
                keep_machine_state=self.keep_machine_state,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for name, container_start_script in self.container_start_scripts.items()
            for name, container_config in self.config.containers.items()
        ]

        return self
@@ -285,7 +291,7 @@ class Driver:
                except Exception as e:
                    self.logger.error(f"Error during cleanup of vlan{vlan.nr}: {e}")

            if self.enable_ssh_backdoor:
            if self.config.enable_ssh_backdoor:
                try:
                    del self.vhost_vsock
                except Exception as e:
@@ -309,7 +315,7 @@ class Driver:

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            test_script=self.config.test_script,
            machines=self.machines,
            machines_qemu=self.machines_qemu,
            machines_nspawn=self.machines_nspawn,
@@ -351,7 +357,7 @@ class Driver:
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def dump_machine_ssh(self) -> None:
        if not self.enable_ssh_backdoor:
        if not self.config.enable_ssh_backdoor:
            return

        assert self.vhost_vsock is not None
@@ -417,7 +423,7 @@ class Driver:
    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
            f"Test will time out and terminate in {self.config.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
@@ -489,7 +495,7 @@ class Driver:
        """
        tmp_dir = get_tmp_dir()

        if self.enable_ssh_backdoor:
        if self.config.enable_ssh_backdoor:
            self.logger.warning(
                f"create_machine({name}): not enabling SSH backdoor, this is not supported for VMs created with create_machine!"
            )
+1 −0
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ let
  testModules = [
    ./call-test.nix
    ./driver.nix
    ./driver-configuration.nix
    ./interactive.nix
    ./legacy.nix
    ./meta.nix
+83 −0
Original line number Diff line number Diff line
{
  config,
  lib,
  pkgs,
  ...
}:
let
  inherit (lib) types;

  nodeConfigurationAttrs = lib.mkOption {
    internal = true;
    type = types.attrsOf (
      types.submodule {
        options = {
          name = lib.mkOption {
            internal = true;
            type = types.str;
          };
          start_script = lib.mkOption {
            internal = true;
            type = types.path;
          };
        };
      }
    );
  };
in
{
  options = {
    driverConfiguration = lib.mkOption {
      description = "Configuration attribute set for test driver invocation";
      internal = true;
      type = types.submodule {
        options = {
          vms = nodeConfigurationAttrs;
          containers = nodeConfigurationAttrs;
          vlans = lib.mkOption {
            internal = true;
            type = types.listOf types.ints.unsigned;
          };
          global_timeout = lib.mkOption {
            internal = true;
            type = types.ints.unsigned;
          };
          enable_ssh_backdoor = lib.mkOption {
            internal = true;
            type = types.bool;
          };
          test_script = lib.mkOption {
            internal = true;
            type = types.path;
          };
        };
      };
    };
    driverConfigurationFile = lib.mkOption {
      internal = true;
      type = types.path;
    };
  };

  config = {
    driverConfiguration = {
      vms = lib.mapAttrs (name: value: {
        inherit name;
        start_script = lib.getExe value.system.build.vm;
      }) config.nodes;
      containers = lib.mapAttrs (name: value: {
        inherit name;
        start_script = lib.getExe value.system.build.nspawn;
      }) config.containers;
      vlans = lib.unique (
        lib.concatMap (
          m: (m.virtualisation.vlans ++ (lib.mapAttrsToList (_: v: v.vlan) m.virtualisation.interfaces))
        ) (lib.attrValues config.nodes ++ lib.attrValues config.containers)
      );
      global_timeout = config.globalTimeout;
      test_script = pkgs.writeText "test-script" config.testScriptString;
      enable_ssh_backdoor = config.sshBackdoor.enable;
    };
    driverConfigurationFile = pkgs.writers.writeJSON "driverConfiguration.json" config.driverConfiguration;
  };
}
Loading