Commit 0cbdae41 authored by Thiago Kenji Okada's avatar Thiago Kenji Okada
Browse files

nixos-rebuild-ng: error if --upgrade/--upgrade-all is called without --sudo/root

parent d75ba2e8
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -699,6 +699,12 @@ def upgrade_channels(all_channels: bool = False, sudo: bool = False) -> None:
    It will either upgrade just the `nixos` channel (including any channel
    that has a `.update-on-nixos-rebuild` file) or all.
    """
    if not sudo and os.geteuid() != 0:
        raise NixOSRebuildError(
            "if you pass the '--upgrade' or '--upgrade-all' flag, you must "
            "also pass '--sudo' or run the command as root (e.g., with sudo)"
        )

    for channel_path in Path("/nix/var/nix/profiles/per-user/root/channels/").glob("*"):
        if channel_path.is_dir() and (
            all_channels
+18 −5
Original line number Diff line number Diff line
@@ -836,14 +836,27 @@ def test_switch_to_configuration_with_systemd_run(
    ],
)
@patch("pathlib.Path.is_dir", autospec=True, return_value=True)
def test_upgrade_channels(mock_is_dir: Mock, mock_glob: Mock) -> None:
    with patch(get_qualified_name(n.run_wrapper, n), autospec=True) as mock_run:
@patch("os.geteuid", autospec=True, return_value=1000)
@patch(get_qualified_name(n.run_wrapper, n), autospec=True)
def test_upgrade_channels(
    mock_run: Mock,
    mock_geteuid: Mock,
    mock_is_dir: Mock,
    mock_glob: Mock,
) -> None:
    with pytest.raises(m.NixOSRebuildError) as e:
        n.upgrade_channels(all_channels=False, sudo=False)
    assert str(e.value) == (
        "error: if you pass the '--upgrade' or '--upgrade-all' flag, you must "
        "also pass '--sudo' or run the command as root (e.g., with sudo)"
    )

    n.upgrade_channels(all_channels=False, sudo=True)
    mock_run.assert_called_once_with(
        ["nix-channel", "--update", "nixos"], check=False, sudo=True
    )

    with patch(get_qualified_name(n.run_wrapper, n), autospec=True) as mock_run:
    mock_geteuid.return_value = 0
    n.upgrade_channels(all_channels=True, sudo=False)
    mock_run.assert_has_calls(
        [