Unverified Commit 0288c851 authored by nicoo's avatar nicoo Committed by GitHub
Browse files

maintainers/scripts/sha-to-sri: accept directories as arguments, document, minor fixes (#341551)

parents 135b49bb 9259479c
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -56,3 +56,16 @@ The maintainer is designated by a `selector` which must be one of:
  see [`maintainer-list.nix`] for the fields' definition.

[`maintainer-list.nix`]: ../maintainer-list.nix


## Conventions

### `sha-to-sri.py`

`sha-to-sri.py path ...` (atomically) rewrites hash attributes (named `hash` or `sha(1|256|512)`)
into the SRI format: `hash = "{hash name}-{base64 encoded value}"`.

`path` must point to either a nix file, or a directory which will be automatically traversed.

`sha-to-sri.py` automatically skips files whose first non-empty line contains `generated by` or `do not edit`.
Moreover, when walking a directory tree, the script will skip files whose name is `yarn.nix` or contains `generated`.
+67 −49
Original line number Diff line number Diff line
#!/usr/bin/env nix-shell
#! nix-shell -i "python3 -I" -p "python3.withPackages(p: with p; [ rich structlog ])"

from abc import ABC, abstractclassmethod, abstractmethod
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from structlog.contextvars import bound_contextvars as log_context
from typing import ClassVar, List, Tuple

import hashlib, re, structlog
import hashlib, logging, re, structlog


logger = structlog.getLogger("sha-to-SRI")
@@ -26,10 +26,11 @@ class Encoding(ABC):
        assert len(digest) == self.n

        from base64 import b64encode

        return f"{self.hashName}-{b64encode(digest).decode()}"

    @classmethod
    def all(cls, h) -> 'List[Encoding]':
    def all(cls, h) -> "List[Encoding]":
        return [c(h) for c in cls.__subclasses__()]

    def __init__(self, h):
@@ -38,16 +39,14 @@ class Encoding(ABC):

    @property
    @abstractmethod
    def length(self) -> int:
        ...
    def length(self) -> int: ...

    @property
    def regex(self) -> str:
        return f"[{self.alphabet}]{{{self.length}}}"

    @abstractmethod
    def decode(self, s: str) -> bytes:
        ...
    def decode(self, s: str) -> bytes: ...


class Nix32(Encoding):
@@ -57,15 +56,15 @@ class Nix32(Encoding):
    @property
    def length(self):
        return 1 + (8 * self.n) // 5

    def decode(self, s: str):
        assert len(s) == self.length
        out = [ 0 for _ in range(self.n) ]
        # TODO: Do better than a list of byte-sized ints
        out = bytearray(self.n)

        for n, c in enumerate(reversed(s)):
            digit = self.inverted[c]
            i, j = divmod(5 * n, 8)
            out[i] = out[i] | (digit << j) & 0xff
            out[i] = out[i] | (digit << j) & 0xFF
            rem = digit >> (8 - j)
            if rem == 0:
                continue
@@ -76,16 +75,20 @@ class Nix32(Encoding):

        return bytes(out)


class Hex(Encoding):
    alphabet = "0-9A-Fa-f"

    @property
    def length(self):
        return 2 * self.n

    def decode(self, s: str):
        from binascii import unhexlify

        return unhexlify(s)


class Base64(Encoding):
    alphabet = "A-Za-z0-9+/"

@@ -94,36 +97,39 @@ class Base64(Encoding):
        """Number of characters in data and padding."""
        i, k = divmod(self.n, 3)
        return 4 * i + (0 if k == 0 else k + 1), (3 - k) % 3

    @property
    def length(self):
        return sum(self.format)

    @property
    def regex(self):
        data, padding = self.format
        return f"[{self.alphabet}]{{{data}}}={{{padding}}}"

    def decode(self, s):
        from base64 import b64decode

        return b64decode(s, validate = True)


_HASHES = (hashlib.new(n) for n in ('SHA-256', 'SHA-512'))
ENCODINGS = {
    h.name: Encoding.all(h)
    for h in _HASHES
}
_HASHES = (hashlib.new(n) for n in ("SHA-256", "SHA-512"))
ENCODINGS = {h.name: Encoding.all(h) for h in _HASHES}

RE = {
    h: "|".join(
        (f"({h}-)?" if e.name == 'base64' else '') +
        f"(?P<{h}_{e.name}>{e.regex})"
        (f"({h}-)?" if e.name == "base64" else "") + f"(?P<{h}_{e.name}>{e.regex})"
        for e in encodings
    ) for h, encodings in ENCODINGS.items()
    )
    for h, encodings in ENCODINGS.items()
}

_DEF_RE = re.compile("|".join(
_DEF_RE = re.compile(
    "|".join(
        f"(?P<{h}>{h} = (?P<{h}_quote>['\"])({re})(?P={h}_quote);)"
        for h, re in RE.items()
))
    )
)


def defToSRI(s: str) -> str:
@@ -153,7 +159,7 @@ def defToSRI(s: str) -> str:

@contextmanager
def atomicFileUpdate(target: Path):
    '''Atomically replace the contents of a file.
    """Atomically replace the contents of a file.

    Guarantees that no temporary files are left behind, and `target` is either
    left untouched, or overwritten with new content if no exception was raised.
@@ -164,18 +170,20 @@ def atomicFileUpdate(target: Path):

    Upon exiting the context, the files are closed; if no exception was
    raised, `new` (atomically) replaces the `target`, otherwise it is deleted.
    '''
    """
    # That's mostly copied from noto-emoji.py, should DRY it out
    from tempfile import mkstemp
    fd, _p = mkstemp(
        dir = target.parent,
        prefix = target.name,
    )
    tmpPath = Path(_p)
    from tempfile import NamedTemporaryFile

    try:
        with target.open() as original:
            with tmpPath.open('w') as new:
            with NamedTemporaryFile(
                dir = target.parent,
                prefix = target.stem,
                suffix = target.suffix,
                delete = False,
                mode="w",  # otherwise the file would be opened in binary mode by default
            ) as new:
                tmpPath = Path(new.name)
                yield (original, new)

        tmpPath.replace(target)
@@ -192,33 +200,27 @@ def fileToSRI(p: Path):
                new.write(defToSRI(line))


_SKIP_RE = re.compile(
    "(generated by)|(do not edit)",
    re.IGNORECASE
)
_SKIP_RE = re.compile("(generated by)|(do not edit)", re.IGNORECASE)

if __name__ == "__main__":
    from sys import argv, stderr
    from sys import argv

    logger.info("Starting!")

    for arg in argv[1:]:
        p = Path(arg)
        with log_context(path=str(p)):
    def handleFile(p: Path, skipLevel = logging.INFO):
        with log_context(file = str(p)):
            try:
                if p.name == "yarn.nix" or p.name.find("generated") != -1:
                    logger.warning("File looks autogenerated, skipping!")
                    continue

                with p.open() as f:
                    for line in f:
                        if line.strip():
                            break

                    if _SKIP_RE.search(line):
                        logger.warning("File looks autogenerated, skipping!")
                        continue
                        logger.log(skipLevel, "File looks autogenerated, skipping!")
                        return

                fileToSRI(p)

            except Exception as exn:
                logger.error(
                    "Unhandled exception, skipping file!",
@@ -226,3 +228,19 @@ if __name__ == "__main__":
                )
            else:
                logger.info("Finished processing file")

    for arg in argv[1:]:
        p = Path(arg)
        with log_context(arg = arg):
            if p.is_file():
                handleFile(p, skipLevel = logging.WARNING)

            elif p.is_dir():
                logger.info("Recursing into directory")
                for q in p.glob("**/*.nix"):
                    if q.is_file():
                        if q.name == "yarn.nix" or q.name.find("generated") != -1:
                            logger.info("File looks autogenerated, skipping!")
                            continue

                        handleFile(q)