Commit 7a9bf277 authored by Jacek Galowicz's avatar Jacek Galowicz
Browse files

nix-required-mounts: Mount symlink children's targets and add tests



Breaking change: instead of expanding a single symlink chain,
unsafeFollowSymlinks now walks directories recursively and
mounts all symlink descendants' targets.
This accommodates usecases like:
/sys/dev/char/226:128 -> ../../devices/pci0000:00/0000:00:02.0/drm/renderD128
/sys/dev/char/226:128/subsystem -> ../../../../../class/drm

Fixes issue #497824

Co-authored-by: default avatarSomeoneSerge <else@someonex.net>
parent dfd9566f
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -43,6 +43,11 @@ let
          the `paths` contain symlinks. This may not work correctly with glob
          patterns.
        '';
        options.safePrefixes = lib.mkOption {
          default = [ builtins.storeDir ];
          type = listOf path;
          description = "A list of path prefixes that do not need and shall not be searched recursively for further symlink targets. Everything in the nix store does not need to be searched as the derivation already calculcated the full closure of all nix store paths for the drivers package.";
        };
      }
    );

+228 −115
Original line number Diff line number Diff line
#!/usr/bin/env python3

import glob
import json
import os
import subprocess
import textwrap
from argparse import ArgumentParser
from collections import deque
from itertools import chain
from pathlib import Path
from typing import Deque, Dict, List, Set, Tuple, TypeAlias, TypedDict
from pathlib import Path, PurePath
from typing import (
    TypeAlias,
    TypedDict,
    Iterable,
)
import logging

Glob: TypeAlias = str
@@ -21,19 +24,20 @@ class Mount(TypedDict):


class Pattern(TypedDict):
    onFeatures: List[str]
    paths: List[Glob | Mount]
    onFeatures: list[str]
    paths: list[Glob | Mount]
    unsafeFollowSymlinks: bool
    safePrefixes: list[str]


AllowedPatterns: TypeAlias = Dict[str, Pattern]
AllowedPatterns: TypeAlias = dict[str, Pattern]


parser = ArgumentParser("pre-build-hook")
parser.add_argument("derivation_path")
parser.add_argument("sandbox_path", nargs="?")
parser.add_argument("--patterns", type=Path, required=True)
parser.add_argument("--nix-exe", type=Path, required=True)
parser.add_argument("--nix-exe", type=Path)
parser.add_argument(
    "--issue-command",
    choices=("always", "conditional", "never"),
@@ -49,153 +53,262 @@ parser.add_argument(
parser.add_argument("-v", "--verbose", action="count", default=0)


def symlink_parents(p: Path) -> List[Path]:
    out = []
    while p.is_symlink() and p not in out:
        parent = p.readlink()
        if parent.is_relative_to("."):
            p = p / parent
        else:
            p = parent
        out.append(p)
    return out


def get_required_system_features(parsed_drv: dict) -> List[str]:
    # Newer versions of Nix (since https://github.com/NixOS/nix/pull/13263) store structuredAttrs
    # in the derivation JSON output.
    if "structuredAttrs" in parsed_drv:
        return parsed_drv["structuredAttrs"].get("requiredSystemFeatures", [])

    # Older versions of Nix store structuredAttrs in the env as a JSON string.
    drv_env = parsed_drv.get("env", {})
    if "__json" in drv_env:
        return list(json.loads(drv_env["__json"]).get("requiredSystemFeatures", []))

    # Without structuredAttrs, requiredSystemFeatures is a space-separated string in env.
    return drv_env.get("requiredSystemFeatures", "").split()


def validate_mounts(pattern: Pattern) -> List[Tuple[PathString, PathString, bool]]:
    roots = []
    for mount in pattern["paths"]:
        if isinstance(mount, PathString):
            matches = glob.glob(mount)
            assert matches, f"Specified host paths do not exist: {mount}"

            roots.extend((m, m, pattern["unsafeFollowSymlinks"]) for m in matches)
        else:
            assert isinstance(mount, dict) and "host" in mount, mount
            assert Path(
                mount["host"]
            ).exists(), f"Specified host paths do not exist: {mount['host']}"
            roots.append(
                (
                    mount["guest"],
                    mount["host"],
                    pattern["unsafeFollowSymlinks"],
                )
            )

    return roots


def entrypoint():
    args = parser.parse_args()

    VERBOSITY_LEVELS = [logging.ERROR, logging.INFO, logging.DEBUG]

    level_index = min(args.verbose, len(VERBOSITY_LEVELS) - 1)
    logging.basicConfig(level=VERBOSITY_LEVELS[level_index])

    drv_path = args.derivation_path

    with open(args.patterns, "r") as f:
        allowed_patterns = json.load(f)

    if not Path(drv_path).exists():
def parse_derivation(
    derivation_path: PathString, nix_exe: PathString | None
) -> dict:
    """Extract the content of a .drv file into a JSON dict"""
    if not Path(derivation_path).exists():
        logging.error(
            f"{drv_path} doesn't exist."
            f"{derivation_path} doesn't exist."
            " Cf. https://github.com/NixOS/nix/issues/9272"
            " Exiting the hook",
        )

    proc = subprocess.run(
        [
            args.nix_exe,
            nix_exe if nix_exe else "nix",
            "show-derivation",
            drv_path,
            derivation_path,
        ],
        capture_output=True,
    )
    try:
        parsed_drv = json.loads(proc.stdout)

        # compabitility: https://github.com/NixOS/nix/pull/14770
        if "derivations" in parsed_drv:
            parsed_drv = parsed_drv["derivations"]
    except json.JSONDecodeError:
        output_str: str = proc.stdout.decode("utf-8")
        logging.error(
            "Couldn't parse the output of"
            "`nix show-derivation`"
            f". Expected JSON, observed: {proc.stdout}",
            f". Expected JSON, observed: {output_str}",
        )
        logging.error(textwrap.indent(proc.stdout.decode("utf8"), prefix=" " * 4))
        logging.error(textwrap.indent(output_str, prefix=" " * 4))
        logging.info("Exiting the nix-required-binds hook")
        return

    [canon_drv_path] = parsed_drv.keys()

    known_features = set(
        chain.from_iterable(
            pattern["onFeatures"] for pattern in allowed_patterns.values()
    return parsed_drv[canon_drv_path]


def get_required_system_features(parsed_drv: dict) -> set[str]:
    # Newer versions of Nix (since https://github.com/NixOS/nix/pull/13263)
    # store structuredAttrs in the derivation JSON output.
    if "structuredAttrs" in parsed_drv:
        return set(
            parsed_drv["structuredAttrs"].get("requiredSystemFeatures", [])
        )

    # Older versions of Nix store structuredAttrs in the env as a JSON string.
    drv_env = parsed_drv.get("env", {})
    if "__json" in drv_env:
        return set(
            json.loads(drv_env["__json"]).get("requiredSystemFeatures", [])
        )

    parsed_drv = parsed_drv[canon_drv_path]
    required_features = get_required_system_features(parsed_drv)
    required_features = list(filter(known_features.__contains__, required_features))
    # Without structuredAttrs, requiredSystemFeatures is a space-separated string in env.
    return set(drv_env.get("requiredSystemFeatures", "").split())

    patterns: List[Pattern] = list(
        pattern
        for pattern in allowed_patterns.values()
        for path in pattern["paths"]
        if any(feature in required_features for feature in pattern["onFeatures"])
    )  # noqa: E501

    queue: Deque[Tuple[PathString, PathString, bool]] = deque(
        (mnt for pattern in patterns for mnt in validate_mounts(pattern))
    )
def expand_globs(paths: list[PathString]) -> list[PathString]:
    """Expand all existing paths from globbed paths like bash would"""
    return sum(map(glob.glob, paths), [])

    unique_mounts: Set[Tuple[PathString, PathString]] = set()
    mounts: List[Tuple[PathString, PathString]] = []

def symlink_targets(p: Path) -> list[Path]:
    """Traverse a chain of symlinks to collect every intermediate path up to the final destination."""

    out = []
    while p.is_symlink():
        target = p.readlink()
        if target.is_absolute():
            p = target
        else:
            # we need to resolve paths before concatenation because of things like
            # $ ls -l /sys/dev/char/226:128/subsystem
            # ... /sys/dev/char/226:128/subsystem
            # -> ../../../../../../class/drm
            #
            # Path(normpath(...)) needed to normalize `foo/../bar` to `bar`
            p = Path(os.path.normpath(p.parent.resolve() / target))

        if p in out:
            break
        out.append(p)

    return out


def symlink_targets_deep(
    inputs: Iterable[PathString], follow_symlinks: bool
) -> list[PathString]:
    """Walk the file system tree and discover all possible symlink targets"""
    queue: deque[PathString] = deque(inputs)
    unique_paths: set[PathString] = set()
    reachable_paths: list[PathString] = []

    while queue:
        guest_path_str, host_path_str, follow_symlinks = queue.popleft()
        if (guest_path_str, host_path_str) not in unique_mounts:
            mounts.append((guest_path_str, host_path_str))
            unique_mounts.add((guest_path_str, host_path_str))
        path_str = str(queue.popleft())
        if path_str not in unique_paths:
            reachable_paths.append(path_str)
            unique_paths.add(path_str)

        if not follow_symlinks:
            continue

        host_path = Path(host_path_str)
        if not (host_path.is_dir() or host_path.is_symlink()):
        path = Path(path_str)
        if not (path.is_dir() or path.is_symlink()):
            continue

        # assert host_path_str == guest_path_str, (host_path_str, guest_path_str)
        paths: Iterable[Path] = [path]
        if path.is_dir():
            paths = chain(paths, path.iterdir())

        for child in host_path.iterdir() if host_path.is_dir() else [host_path]:
            for parent in symlink_parents(child):
                parent_str = parent.absolute().as_posix()
                queue.append((parent_str, parent_str, follow_symlinks))
        for child in paths:
            for parent in symlink_targets(child):
                path = parent.absolute()
                if all(
                    not path.is_relative_to(existing_path)
                    for existing_path in unique_paths
                ):
                    queue.append(path.as_posix())
    return reachable_paths


def prune_paths(inputs: Iterable[PathString]) -> list[PathString]:
    """
    From a list of paths prune all paths that are subdirectories of others

    >>> prune_paths(["/a/b", "/a"])
    ['/a']
    >>> prune_paths(["/a/b/c", "/a/b"])
    ['/a/b']
    """
    sorted_inputs = sorted(inputs)
    pruned = [sorted_inputs[0]]

    last_kept: PathString = pruned[0]
    for current in sorted_inputs[1:]:
        if not Path(current).is_relative_to(last_kept):
            pruned.append(current)
            last_kept = current

    return pruned


def mount_closure(pattern: Pattern) -> list[tuple[PathString, PathString]]:
    """
    This function extracts all paths from a pattern into the following:
    - list of nix store paths
    - host/hardware specific paths (anything outside the nix store)
    - translations from host to guest (necessary for some non-NixOS hosts)

    As the host paths are often multiple levels of symlinks, these can be
    followed to be able to provide them all in the sandbox as they would
    otherwise be broken (see `unsafeFollowSymlinks`).

    The finally returned list contains tuples with guest-host mappings between
    those paths. Most of them are 1:1.
    """

    def safe_prefix(p):
        safe_prefixes = pattern.get("safePrefixes", [])
        return any(p.startswith(safe_prefix) for safe_prefix in safe_prefixes)

    # All nix store paths have been statically calculated before.
    # There is no need to look into them or add anything
    store_paths = [
        p
        for p in pattern["paths"]
        if isinstance(p, PathString) and safe_prefix(p)
    ]
    # Paths that e.g. point to /dev/... or /run/... paths etc. might further
    # point to other paths and these need to be added to the sandbox, too.
    host_paths = [
        p
        for p in pattern["paths"]
        if isinstance(p, PathString) and not safe_prefix(p)
    ]
    # Translations on the non-NixOS hosts like e.g. /usr/lib to /run/opengl-driver
    # need to be applied on the final path list
    translations: dict[PathString, PathString] = {
        p["host"]: p["guest"]
        for p in pattern["paths"]
        if not isinstance(p, PathString)
    }

    host_paths.extend(translations.keys())

    all_paths = prune_paths(
        chain(
            store_paths,
            symlink_targets_deep(
                expand_globs(host_paths), pattern["unsafeFollowSymlinks"]
            ),
        )
    )

    return [(translations.get(x, x), x) for x in all_paths]


def patterns_for_features(
    patterns: AllowedPatterns, features: set[str]
) -> AllowedPatterns:
    """
    Return the list of patterns that are required for a given set of features.
    >>> patterns = {
    ...   "a": {
    ...     "onFeatures": ["a"],
    ...     "unsafeFollowSymlinks": True,
    ...     "paths": []
    ...   },
    ...   "b": {
    ...     "onFeatures": ["b"],
    ...     "unsafeFollowSymlinks": True,
    ...     "paths": []
    ...   }
    ... }
    >>> list(patterns_for_features(patterns, {"a"}).keys())
    ['a']
    >>> list(patterns_for_features(patterns, {"a", "b"}).keys())
    ['a', 'b']
    """
    return {
        k: v for k, v in patterns.items() if set(v["onFeatures"]) & features
    }


def entrypoint() -> None:
    args = parser.parse_args()

    VERBOSITY_LEVELS = [logging.ERROR, logging.INFO, logging.DEBUG]

    level_index = min(args.verbose, len(VERBOSITY_LEVELS) - 1)
    logging.basicConfig(level=VERBOSITY_LEVELS[level_index])

    with open(args.patterns, "r") as f:
        patterns = json.load(f)

    parsed_drv: dict = parse_derivation(args.derivation_path, args.nix_exe)
    features: set[str] = get_required_system_features(parsed_drv)
    required_patterns: AllowedPatterns = patterns_for_features(
        patterns, features
    )

    mounts = list(
        chain.from_iterable(
            (mount_closure(pattern) for pattern in required_patterns.values())
        )
    )

    # the pre-build-hook command
    if args.issue_command == "always" or (
        args.issue_command == "conditional" and mounts
    ):
        print("extra-sandbox-paths")
        print_paths = True
    else:
        print_paths = False

    # arguments, one per line
    for guest_path_str, host_path_str in mounts if print_paths else []:
        for guest_path_str, host_path_str in mounts:
            print(f"{guest_path_str}={host_path_str}")

    # terminated by an empty line
+12 −1
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@
      "/dev/nvidia*"
    ];
    nvidia-gpu.unsafeFollowSymlinks = true;
    nvidia-gpu.safePrefixes = [ builtins.storeDir ];
  },
  callPackage,
  extraWrapperArgs ? [ ],
@@ -37,13 +38,23 @@ python3Packages.buildPythonApplication {
  inherit pname version;
  pyproject = true;

  src = lib.cleanSource ./.;
  src = lib.sourceByRegex ./. [
    "^pyproject.toml$"
    "^.*nix_required_mounts.py$" # app and unit test file
  ];

  nativeBuildInputs = [
    makeWrapper
    python3Packages.setuptools
  ];

  checkInputs = [
    python3Packages.pytestCheckHook
  ];
  pythonImportsCheck = [
    "nix_required_mounts"
  ];

  postFixup = ''
    wrapProgram $out/bin/${pname} \
      --add-flags "--patterns ${allowedPatternsPath}" \
+6 −0
Original line number Diff line number Diff line
@@ -16,5 +16,11 @@ Homepage = "https://github.com/NixOS/nixpkgs/tree/master/pkgs/by-name/ni/nix-req
[project.scripts]
nix-required-mounts = "nix_required_mounts:entrypoint"

[tool.setuptools]
py-modules = [ "nix_required_mounts" ]

[tool.black]
line-length = 79

[tool.pytest.ini_options]
addopts = ["--doctest-modules"]
+205 −0
Original line number Diff line number Diff line
import unittest
import tempfile
import shutil
from pathlib import Path
from nix_required_mounts import (
    PathString,
    Pattern,
    expand_globs,
    mount_closure,
    prune_paths,
    symlink_targets,
    symlink_targets_deep,
)
import os
import pytest
from pathlib import Path


class TreeBuilder:
    """Helper to create files and symlinks from a simple dict."""

    def __init__(self, root: Path):
        self.root = root

    def build(self, structure: dict):
        for path_str, target in structure.items():
            path = self.root / path_str
            path.parent.mkdir(parents=True, exist_ok=True)

            if target == "file":
                path.touch()
            elif target == "dir":
                path.mkdir(parents=True, exist_ok=True)
            elif target.startswith("->"):
                link_to = target.replace("->", "").strip()
                path.symlink_to(link_to)
        return self.root


@pytest.fixture
def tree(tmp_path):
    # https://docs.pytest.org/en/stable/how-to/tmp_path.html
    # the paths are kept on error for inspection
    return TreeBuilder(tmp_path)


def test_symlink_chain(tree):
    # fmt: off
    root = tree.build({
        "a": "file",
        "b": "-> a",
        "c": "-> b",
    })
    # fmt: on

    assert symlink_targets(root / "a") == []
    assert symlink_targets(root / "b") == [root / "a"]
    assert symlink_targets(root / "c") == [root / "b", root / "a"]


def test_far_up_relative_links(tree):
    depth = 15
    path = "x/" * depth
    root = tree.build(
        {
            "o/p": "file",
            "a/b": f"-> ../{path}",
            f"{path}/n": f"-> ../{'../' * depth}o/p",
        }
    )

    assert symlink_targets(root / "a/b") == [root / Path(path)]


def test_jump_outside_folder(tree):
    # fmt: off
    root = tree.build({
        "c/d": "file",
        "a/b": "-> ../c/d",
    })
    # fmt: on
    assert symlink_targets(root / "a/b") == [root / "c/d"]


def test_path_discovery_resolve_relative_links(tree):
    depth = 15
    path = "x/" * depth
    root = tree.build(
        {
            "o/p": "file",
            "a/b": f"-> ../{path}",
            f"{path}/n": f"-> {'../' * depth}o/p",
        }
    )

    assert symlink_targets_deep([root / "a"], follow_symlinks=True) == [
        str(x) for x in [root / "a", root / path, root / "o/p"]
    ]


def test_pattern_extraction(tree):
    root = tree.build(
        {
            "a": "file",
            "b": "-> a",
            "c": "-> b",
            "d/e": "file",
            "f/g": "-> ../d/e",
        }
    )

    def strs(paths: list[Path]) -> list[PathString]:
        return [str(x) if isinstance(x, Path) else x for x in paths]

    def pairs(paths: list[Path]) -> list[tuple[str, str]]:
        return [(x, x) for x in strs(paths)]

    a = {
        "onFeatures": ["feature_a", "feature_a1"],
        "paths": [str(root / "c")],
        "unsafeFollowSymlinks": True,
    }

    assert mount_closure(a) == pairs([root / "a", root / "b", root / "c"])

    assert mount_closure(a | {"unsafeFollowSymlinks": False}) == pairs(
        [root / "c"]
    )

    b = {
        "onFeatures": ["feature_b", "feature_b2"],
        "paths": strs([root / "d", root / "f"]),
        "unsafeFollowSymlinks": True,
    }

    assert mount_closure(b) == pairs(
        [
            root / "d",
            root / "f",
        ]
    )

    with_mounts = {
        "onFeatures": ["feature_b", "feature_b2"],
        "paths": strs(
            [
                root / "d",  # prevent formatter
                root / "f",
                {"host": str(root), "guest": "/foo/bar"},
            ]
        ),
        "unsafeFollowSymlinks": True,
    }

    assert mount_closure(with_mounts) == [
        ("/foo/bar", str(root)),
    ]


def test_glob_expansion(tree):
    # fmt: off
    root = tree.build({
        "a": "file",
        "b": "-> a",
        "c": "-> b",
    })
    # fmt: on

    assert sorted(expand_globs([str(root / "*")])) == [
        str(x) for x in [root / "a", root / "b", root / "c"]
    ]


def test_path_discovery(tree):
    root = tree.build(
        {
            "a": "file",
            "b": "-> a",
            "c": "-> b",
            "d/e": "file",
            "f/g": "-> ../d/e",
        }
    )

    ss = lambda unsorted_paths: sorted(map(str, unsorted_paths))

    assert ss(
        symlink_targets_deep(
            [root / "c", root / "f"], follow_symlinks=True
        )
    ) == ss(
        # fmt: off
        [
            root / "a",
            root / "b",
            root / "c",
            root / "d/e",
            root / "f",
        ]
        # fmt: on
    )


if __name__ == "__main__":
    unittest.main()