Unverified Commit 9dcf7113 authored by Maximilian Bosch's avatar Maximilian Bosch
Browse files

nixos/test-driver: push allocation into the context-manager

The context manager's purpose is to allocate its resources in `__enter__` and
release them again in `__exit__`.

Right now, the approach is merely a hack since we allocate everything in
the constructor, but use the context-manager protocol as way to reliably
terminate everything.

This isn't a functional change, but merely a correctness change by using
the methods the way they were intended.
parent d406be44
Loading
Loading
Loading
Loading
+25 −17
Original line number Diff line number Diff line
@@ -68,12 +68,16 @@ class Driver:
    and runs the tests"""

    tests: str
    vlans: list[VLan]
    machines_qemu: list[QemuMachine]
    machines_nspawn: list[NspawnMachine]
    vlans: list[VLan] = []
    machines_qemu: list[QemuMachine] = []
    machines_nspawn: list[NspawnMachine] = []
    polling_conditions: list[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
    vm_start_scripts: dict[str, str]
    container_start_scripts: dict[str, str]
    vlan_ids: list[int]
    keep_machine_state: bool
    logger: AbstractLogger
    debug: DebugAbstract

@@ -94,15 +98,23 @@ class Driver:
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger
        self.debug = debug
        self.vlan_ids = list(set(vlans))
        self.polling_conditions = []
        self.keep_machine_state = keep_machine_state
        self.global_timeout = global_timeout
        self.vm_start_scripts = dict(zip(vm_names, vm_start_scripts))
        self.container_start_scripts = dict(
            zip(container_names, container_start_scripts)
        )

    def __enter__(self) -> "Driver":
        self.race_timer = threading.Timer(self.global_timeout, self.terminate_test)
        tmp_dir = get_tmp_dir()

        with self.logger.nested("start all VLans"):
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in self.vlan_ids]

        self.polling_conditions = []

@@ -110,16 +122,16 @@ class Driver:
            QemuMachine(
                name=name,
                start_command=vm_start_script,
                keep_machine_state=keep_machine_state,
                keep_machine_state=self.keep_machine_state,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for name, vm_start_script in zip(vm_names, vm_start_scripts)
            for name, vm_start_script in self.vm_start_scripts.items()
        ]

        if len(container_start_scripts) > 0:
        if len(self.container_start_scripts) > 0:
            self._init_nspawn_environment()

        self.machines_nspawn = [
@@ -128,16 +140,15 @@ class Driver:
                start_command=container_start_script,
                tmp_dir=tmp_dir,
                logger=self.logger,
                keep_machine_state=keep_machine_state,
                keep_machine_state=self.keep_machine_state,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for name, container_start_script in zip(
                container_names,
                container_start_scripts,
            )
            for name, container_start_script in self.container_start_scripts.items()
        ]

        return self

    def _init_nspawn_environment(self) -> None:
        assert os.geteuid() == 0, (
            f"systemd-nspawn requires root to work. You are {os.geteuid()}"
@@ -193,9 +204,6 @@ class Driver:
        machines.sort(key=lambda machine: machine.name)
        return machines

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()