Commit 61e61a59 authored by Thiago Kenji Okada's avatar Thiago Kenji Okada
Browse files

nixos-rebuild-ng: kill underlying remote process

`nixos-rebuild-ng` explicitly doesn't allocate a pseudo-TTY for SSH
because this causes lots of issues depending on the use case (for
example, multiplexing multiple SSH sessions).

Sadly, not using a pseudo-TTY also causes other issues, like the fact that
using Ctrl+C (SIGINT) doesn't kill the underlying process, because SSH
doesn't support it.

We can't really start using a pseudo-TTY unless we want to overcomplicate
the code for parsing results (a pseudo-TTY mangles stdout and stderr
together), so we need to handle killing the underlying remote process
manually.

This is what this commit does: when we receive a `KeyboardInterrupt`
exception while calling `run_wrapper`, we check if it is a remote
process and send a `pkill --full` with the arguments (this should ensure
that we don't kill other processes, but we can't guarantee it). This
assumes the user has `procps` installed, but I think it is a safe
assumption since this seems to be a core package.

Sadly, there is nothing we can do if the user doesn't have `procps`
installed; the good thing is that the worst that can happen is that we
silently fail and the process stays in the background until it finishes.

Fix #403269.
parent 6358207c
Loading
Loading
Loading
Loading
+59 −8
Original line number Diff line number Diff line
import atexit
import logging
import os
import re
import shlex
import subprocess
from collections.abc import Sequence
@@ -21,6 +22,8 @@ SSH_DEFAULT_OPTS: Final = [
    "ControlPersist=60",
]

type Args = Sequence[str | bytes | os.PathLike[str] | os.PathLike[bytes]]


@dataclass(frozen=True)
class Remote:
@@ -82,7 +85,7 @@ atexit.register(cleanup_ssh)


def run_wrapper(
    args: Sequence[str | bytes | os.PathLike[str] | os.PathLike[bytes]],
    args: Args,
    *,
    check: bool = True,
    extra_env: dict[str, str] | None = None,
@@ -93,6 +96,8 @@ def run_wrapper(
    "Wrapper around `subprocess.run` that supports extra functionality."
    env = None
    process_input = None
    run_args = args

    if remote:
        if extra_env:
            extra_env_args = [f"{env}={value}" for env, value in extra_env.items()]
@@ -103,7 +108,7 @@ def run_wrapper(
                process_input = remote.sudo_password + "\n"
            else:
                args = ["sudo", *args]
        args = [
        run_args = [
            "ssh",
            *remote.opts,
            *SSH_DEFAULT_OPTS,
@@ -119,32 +124,39 @@ def run_wrapper(
        if extra_env:
            env = os.environ | extra_env
        if sudo:
            args = ["sudo", *args]
            run_args = ["sudo", *run_args]

    logger.debug(
        "calling run with args=%r, kwargs=%r, extra_env=%r",
        args,
        run_args,
        kwargs,
        extra_env,
    )

    try:
        r = subprocess.run(
            args,
            run_args,
            check=check,
            env=env,
            input=process_input,
            # Hope nobody is using NixOS with non-UTF8 encodings, but "surrogateescape"
            # should still work in those systems.
            # Hope nobody is using NixOS with non-UTF8 encodings, but
            # "surrogateescape" should still work in those systems.
            text=True,
            errors="surrogateescape",
            **kwargs,
        )

        if kwargs.get("capture_output") or kwargs.get("stderr") or kwargs.get("stdout"):
            logger.debug("captured output stdout=%r, stderr=%r", r.stdout, r.stderr)
            logger.debug(
                "captured output with stdout=%r, stderr=%r", r.stdout, r.stderr
            )

        return r
    except KeyboardInterrupt:
        # sudo commands are activation only and unlikely to be long running
        if remote and not sudo:
            _kill_long_running_ssh_process(args, remote)
        raise
    except subprocess.CalledProcessError:
        if sudo and remote and remote.sudo_password is None:
            logger.error(
@@ -152,3 +164,42 @@ def run_wrapper(
                + "--ask-sudo-password?"
            )
        raise


# Without a pseudo-TTY (which opens a whole other can of worms) SSH does not
# send signals to the remote process, so a long running remote command
# (e.g.: a build) would stay alive after the local side is interrupted.
# See: https://stackoverflow.com/a/44354466
# Issue: https://github.com/NixOS/nixpkgs/issues/403269
def _kill_long_running_ssh_process(args: Args, remote: Remote) -> None:
    """Ask `remote` to SIGINT the process whose command line matches `args`.

    Runs `pkill --signal SIGINT --full` on the remote host over SSH. This is
    best-effort: a non-zero exit status does not raise (`check=False`), the
    captured output and return code are only logged at debug level.
    """
    logger.info("cleaning-up remote process, please wait...")

    # pkill interprets its argument as a regex, so the shell-joined command
    # line is regex-escaped on top of the shell quoting.
    # NOTE(review): the remote shell presumably consumes backslashes that sit
    # outside single quotes, so regex escapes on words shlex leaves unquoted
    # may never reach pkill (e.g. an argument containing `+`) -- confirm.
    pattern = re.escape(shlex.join(map(str, args)))
    logger.debug("killing remote process using pkill with args=%r", pattern)

    pkill_command = [
        "ssh",
        *remote.opts,
        *SSH_DEFAULT_OPTS,
        remote.host,
        "--",
        "pkill",
        "--signal",
        "SIGINT",
        "--full",
        "--",
        pattern,
    ]
    result = subprocess.run(
        pkill_command,
        check=False,
        capture_output=True,
        text=True,
    )

    logger.debug(
        "remote pkill captured output with stdout=%r, stderr=%r, returncode=%s",
        result.stdout,
        result.stderr,
        result.returncode,
    )
+31 −0
Original line number Diff line number Diff line
@@ -96,6 +96,37 @@ def test_run(mock_run: Any) -> None:
    )


@patch(get_qualified_name(p.subprocess.run), autospec=True)
def test__kill_long_running_ssh_process(mock_run: Any) -> None:
    """The remote cleanup issues one `ssh ... pkill --full` call with an escaped pattern."""
    # Arrange: a remote `nix build` command line containing spaces and regex
    # metacharacters, so both shell quoting and regex escaping are exercised.
    nix_build_args = [
        "nix",
        "--extra-experimental-features",
        "nix-command flakes",
        "build",
        "/nix/store/la0c8nmpr9xfclla0n4f3qq9iwgdrq4g-nixos-system-sankyuu-nixos-25.05.20250424.f771eb4.drv^*",
    ]
    remote = m.Remote("user@localhost", opts=[], sudo_password=None)
    expected_pattern = r"nix\ \-\-extra\-experimental\-features\ 'nix\-command\ flakes'\ build\ '/nix/store/la0c8nmpr9xfclla0n4f3qq9iwgdrq4g\-nixos\-system\-sankyuu\-nixos\-25\.05\.20250424\.f771eb4\.drv\^\*'"

    # Act
    p._kill_long_running_ssh_process(nix_build_args, remote)

    # Assert: `opts` is empty, so only the default SSH options precede the host.
    expected_command = [
        "ssh",
        *p.SSH_DEFAULT_OPTS,
        "user@localhost",
        "--",
        "pkill",
        "--signal",
        "SIGINT",
        "--full",
        "--",
        expected_pattern,
    ]
    mock_run.assert_called_with(
        expected_command,
        check=False,
        capture_output=True,
        text=True,
    )


def test_remote_from_name(monkeypatch: MonkeyPatch) -> None:
    monkeypatch.setenv("NIX_SSHOPTS", "")
    assert m.Remote.from_arg("user@localhost", None, False) == m.Remote(