Unverified Commit b59a398f authored by Thiago Kenji Okada's avatar Thiago Kenji Okada Committed by GitHub
Browse files

nixos-rebuild-ng: kill underlying remote process (#403436)

parents c2e815c7 b74e861c
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import os
import sys
from pathlib import Path
from subprocess import CalledProcessError, run
from typing import assert_never
from typing import Final, assert_never

from . import nix, tmpdir
from .constants import EXECUTABLE, WITH_NIX_2_18, WITH_REEXEC, WITH_SHELL_FILES
@@ -13,7 +13,7 @@ from .models import Action, BuildAttr, Flake, ImageVariants, NRError, Profile
from .process import Remote, cleanup_ssh
from .utils import Args, LogFormatter, tabulate

logger = logging.getLogger()
logger: Final = logging.getLogger()
logger.setLevel(logging.INFO)


+6 −4
Original line number Diff line number Diff line
from typing import Final

# Build-time flags
# Use strings to avoid breaking standalone (e.g.: `python -m nixos_rebuild`)
# usage
EXECUTABLE = "@executable@"
EXECUTABLE: Final[str] = "@executable@"
# Use either `== "true"` if the default (e.g.: `python -m nixos_rebuild`) is
# `False` or `!= "false"` if the default is `True`
WITH_NIX_2_18 = "@withNix218@" != "false"  # type: ignore
WITH_REEXEC = "@withReexec@" == "true"  # type: ignore
WITH_SHELL_FILES = "@withShellFiles@" == "true"  # type: ignore
WITH_NIX_2_18: Final[bool] = "@withNix218@" != "false"
WITH_REEXEC: Final[bool] = "@withReexec@" == "true"
WITH_SHELL_FILES: Final[bool] = "@withShellFiles@" == "true"
+1 −1
Original line number Diff line number Diff line
@@ -45,7 +45,7 @@ SWITCH_TO_CONFIGURATION_CMD_PREFIX: Final = [
    "--service-type=exec",
    "--unit=nixos-rebuild-switch-to-configuration",
]
logger = logging.getLogger(__name__)
logger: Final = logging.getLogger(__name__)


def build(
+73 −9
Original line number Diff line number Diff line
import atexit
import logging
import os
import re
import shlex
import subprocess
from collections.abc import Sequence
@@ -10,7 +11,7 @@ from typing import Final, Self, TypedDict, Unpack

from . import tmpdir

logger = logging.getLogger(__name__)
logger: Final = logging.getLogger(__name__)

SSH_DEFAULT_OPTS: Final = [
    "-o",
@@ -21,6 +22,8 @@ SSH_DEFAULT_OPTS: Final = [
    "ControlPersist=60",
]

type Args = Sequence[str | bytes | os.PathLike[str] | os.PathLike[bytes]]


@dataclass(frozen=True)
class Remote:
@@ -82,7 +85,7 @@ atexit.register(cleanup_ssh)


def run_wrapper(
    args: Sequence[str | bytes | os.PathLike[str] | os.PathLike[bytes]],
    args: Args,
    *,
    check: bool = True,
    extra_env: dict[str, str] | None = None,
@@ -93,6 +96,8 @@ def run_wrapper(
    "Wrapper around `subprocess.run` that supports extra functionality."
    env = None
    process_input = None
    run_args = args

    if remote:
        if extra_env:
            extra_env_args = [f"{env}={value}" for env, value in extra_env.items()]
@@ -103,7 +108,7 @@ def run_wrapper(
                process_input = remote.sudo_password + "\n"
            else:
                args = ["sudo", *args]
        args = [
        run_args = [
            "ssh",
            *remote.opts,
            *SSH_DEFAULT_OPTS,
@@ -119,32 +124,39 @@ def run_wrapper(
        if extra_env:
            env = os.environ | extra_env
        if sudo:
            args = ["sudo", *args]
            run_args = ["sudo", *run_args]

    logger.debug(
        "calling run with args=%r, kwargs=%r, extra_env=%r",
        args,
        run_args,
        kwargs,
        extra_env,
    )

    try:
        r = subprocess.run(
            args,
            run_args,
            check=check,
            env=env,
            input=process_input,
            # Hope nobody is using NixOS with non-UTF8 encodings, but "surrogateescape"
            # should still work in those systems.
            # Hope nobody is using NixOS with non-UTF8 encodings, but
            # "surrogateescape" should still work in those systems.
            text=True,
            errors="surrogateescape",
            **kwargs,
        )

        if kwargs.get("capture_output") or kwargs.get("stderr") or kwargs.get("stdout"):
            logger.debug("captured output stdout=%r, stderr=%r", r.stdout, r.stderr)
            logger.debug(
                "captured output with stdout=%r, stderr=%r", r.stdout, r.stderr
            )

        return r
    except KeyboardInterrupt:
        # sudo commands are activation only and unlikely to be long running
        if remote and not sudo:
            _kill_long_running_ssh_process(args, remote)
        raise
    except subprocess.CalledProcessError:
        if sudo and remote and remote.sudo_password is None:
            logger.error(
@@ -152,3 +164,55 @@ def run_wrapper(
                + "--ask-sudo-password?"
            )
        raise


# SSH does not send the signals to the process when running without usage of
# pseudo-TTY (that causes a whole other can of worms), so if the process is
# long running (e.g.: a build) this will result in the underlying process
# staying alive.
# See: https://stackoverflow.com/a/44354466
# Issue: https://github.com/NixOS/nixpkgs/issues/403269
def _kill_long_running_ssh_process(args: Args, remote: Remote) -> None:
    logger.info("cleaning-up remote process, please wait...")

    # We need to escape both the shell and regex here (since pkill interprets
    # its arguments as regex)
    quoted_args = re.escape(shlex.join(str(a) for a in args))
    logger.debug("killing remote process using pkill with args=%r", quoted_args)
    cleanup_interrupted = False

    try:
        r = subprocess.run(
            [
                "ssh",
                *remote.opts,
                *SSH_DEFAULT_OPTS,
                remote.host,
                "--",
                "pkill",
                "--signal",
                "SIGINT",
                "--full",
                "--",
                quoted_args,
            ],
            check=False,
            capture_output=True,
            text=True,
        )
        logger.debug(
            "remote pkill captured output with stdout=%r, stderr=%r, returncode=%s",
            r.stdout,
            r.stderr,
            r.returncode,
        )
    except KeyboardInterrupt:
        cleanup_interrupted = True
        raise
    finally:
        if cleanup_interrupted or r.returncode:
            logger.warning(
                "could not clean-up remote process, the command %s may still be running in host '%s'",
                args,
                remote.host,
            )
+31 −0
Original line number Diff line number Diff line
@@ -96,6 +96,37 @@ def test_run(mock_run: Any) -> None:
    )


@patch(get_qualified_name(p.subprocess.run), autospec=True)
def test__kill_long_running_ssh_process(mock_run: Any) -> None:
    p._kill_long_running_ssh_process(
        [
            "nix",
            "--extra-experimental-features",
            "nix-command flakes",
            "build",
            "/nix/store/la0c8nmpr9xfclla0n4f3qq9iwgdrq4g-nixos-system-sankyuu-nixos-25.05.20250424.f771eb4.drv^*",
        ],
        m.Remote("user@localhost", opts=[], sudo_password=None),
    )
    mock_run.assert_called_with(
        [
            "ssh",
            *p.SSH_DEFAULT_OPTS,
            "user@localhost",
            "--",
            "pkill",
            "--signal",
            "SIGINT",
            "--full",
            "--",
            r"nix\ \-\-extra\-experimental\-features\ 'nix\-command\ flakes'\ build\ '/nix/store/la0c8nmpr9xfclla0n4f3qq9iwgdrq4g\-nixos\-system\-sankyuu\-nixos\-25\.05\.20250424\.f771eb4\.drv\^\*'",
        ],
        check=False,
        capture_output=True,
        text=True,
    )


def test_remote_from_name(monkeypatch: MonkeyPatch) -> None:
    monkeypatch.setenv("NIX_SSHOPTS", "")
    assert m.Remote.from_arg("user@localhost", None, False) == m.Remote(