Commit 9f10c9bc authored by Jacek Galowicz's avatar Jacek Galowicz
Browse files

test-driver: Factor out OCR related code to machine/ocr.py

parent 2c8500b9
Loading
Loading
Loading
Loading
+26 −94
Original line number Diff line number Diff line
@@ -13,8 +13,8 @@ import sys
import tempfile
import threading
import time
from collections.abc import Callable, Iterable
from contextlib import _GeneratorContextManager, nullcontext
from collections.abc import Callable, Generator
from contextlib import _GeneratorContextManager, contextmanager, nullcontext
from pathlib import Path
from queue import Queue
from typing import Any
@@ -22,6 +22,7 @@ from typing import Any
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger

from .ocr import perform_ocr_on_screenshot, perform_ocr_variants_on_screenshot
from .qmp import QMPSession

CHAR_TO_KEY = {
@@ -92,84 +93,6 @@ def make_command(args: list) -> str:
    return " ".join(map(shlex.quote, (map(str, args))))


def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
    magick_args = [
        "-filter",
        "Catrom",
        "-density",
        "72",
        "-resample",
        "300",
        "-contrast",
        "-normalize",
        "-despeckle",
        "-type",
        "grayscale",
        "-sharpen",
        "1",
        "-posterize",
        "3",
    ]
    out_file = screenshot_path

    if negate:
        magick_args.append("-negate")
        out_file += ".negative"

    magick_args += [
        "-gamma",
        "100",
        "-blur",
        "1x65535",
    ]
    out_file += ".png"

    ret = subprocess.run(
        ["magick", "convert"] + magick_args + [screenshot_path, out_file],
        capture_output=True,
    )

    if ret.returncode != 0:
        raise MachineError(
            f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
        )

    return out_file


def _perform_ocr_on_screenshot(
    screenshot_path: str, model_ids: Iterable[int]
) -> list[str]:
    if shutil.which("tesseract") is None:
        raise MachineError("OCR requested but enableOCR is false")

    processed_image = _preprocess_screenshot(screenshot_path, negate=False)
    processed_negative = _preprocess_screenshot(screenshot_path, negate=True)

    model_results = []
    for image in [screenshot_path, processed_image, processed_negative]:
        for model_id in model_ids:
            ret = subprocess.run(
                [
                    "tesseract",
                    image,
                    "-",
                    "--oem",
                    str(model_id),
                    "-c",
                    "debug_file=/dev/null",
                    "--psm",
                    "11",
                ],
                capture_output=True,
            )
            if ret.returncode != 0:
                raise MachineError(f"OCR failed with exit code {ret.returncode}")
            model_results.append(ret.stdout.decode("utf-8"))

    return model_results


def retry(fn: Callable, timeout: int = 900) -> None:
    """Call the given function repeatedly, with 1 second intervals,
    until it returns True or a timeout is reached.
@@ -910,6 +833,17 @@ class Machine:
            self.log(f"(connecting took {toc - tic:.2f} seconds)")
            self.connected = True

    @contextmanager
    def _managed_screenshot(self) -> Generator[str]:
        """
        Take a screenshot and yield the screenshot filepath.
        The file will be deleted when leaving the generator.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            screenshot_path: str = os.path.join(tmpdir, "ppm")
            self.send_monitor_command(f"screendump {screenshot_path}")
            yield screenshot_path

    def screenshot(self, filename: str) -> None:
        """
        Take a picture of the display of the virtual machine, in PNG format.
@@ -919,17 +853,19 @@ class Machine:
            filename += ".png"
        if "/" not in filename:
            filename = os.path.join(self.out_dir, filename)
        tmp = f"{filename}.ppm"

        with self.nested(
            f"making screenshot {filename}",
            {"image": os.path.basename(filename)},
        ):
            self.send_monitor_command(f"screendump {tmp}")
            ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True)
            os.unlink(tmp)
            with self._managed_screenshot() as screenshot_path:
                ret = subprocess.run(
                    f"pnmtopng '{screenshot_path}' > '{filename}'", shell=True
                )
                if ret.returncode != 0:
                raise MachineError("Cannot convert screenshot")
                    raise MachineError(
                        f"Cannot convert screenshot (pnmtopng returned code {ret.returncode})"
                    )

    def copy_from_host_via_shell(self, source: str, target: str) -> None:
        """Copy a file from the host into the guest by piping it over the
@@ -1003,12 +939,6 @@ class Machine:
        """Debugging: Dump the contents of the TTY<n>"""
        self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat")

    def _get_screen_text_variants(self, model_ids: Iterable[int]) -> list[str]:
        with tempfile.TemporaryDirectory() as tmpdir:
            screenshot_path = os.path.join(tmpdir, "ppm")
            self.send_monitor_command(f"screendump {screenshot_path}")
            return _perform_ocr_on_screenshot(screenshot_path, model_ids)

    def get_screen_text_variants(self) -> list[str]:
        """
        Return a list of different interpretations of what is currently
@@ -1021,7 +951,8 @@ class Machine:
        This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`.
        :::
        """
        return self._get_screen_text_variants([0, 1, 2])
        with self._managed_screenshot() as screenshot_path:
            return perform_ocr_variants_on_screenshot(screenshot_path)

    def get_screen_text(self) -> str:
        """
@@ -1032,7 +963,8 @@ class Machine:
        This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`.
        :::
        """
        return self._get_screen_text_variants([2])[0]
        with self._managed_screenshot() as screenshot_path:
            return perform_ocr_on_screenshot(screenshot_path)

    def wait_for_text(self, regex: str, timeout: int = 900) -> None:
        """
+111 −0
Original line number Diff line number Diff line
import shutil
import subprocess

from test_driver.errors import MachineError


def perform_ocr_on_screenshot(screenshot_path: str) -> str:
    """
    Perform OCR on a screenshot that contains text.
    Returns a string with all words that could be found.
    """
    return perform_ocr_variants_on_screenshot(screenshot_path, False)[0]


def perform_ocr_variants_on_screenshot(
    screenshot_path: str, variants: bool = True
) -> list[str]:
    """
    Same as perform_ocr_on_screenshot but will create variants of the images
    that can lead to more words being detected.
    Returns a string with words for each variant.
    """
    if shutil.which("tesseract") is None:
        raise MachineError("OCR requested but `tesseract` is not available")

    # tesseract --help-oem
    # OCR Engine modes (OEM):
    #  0|tesseract_only          Legacy engine only.
    #  1|lstm_only               Neural nets LSTM engine only.
    #  2|tesseract_lstm_combined Legacy + LSTM engines.
    #  3|default                 Default, based on what is available.
    model_ids: list[int] = [0, 1, 2] if variants else [3]

    image_paths = [
        screenshot_path,
        _preprocess_screenshot(screenshot_path, negate=False),
        _preprocess_screenshot(screenshot_path, negate=True),
    ]

    def run_tesseract(image: str, model_id: int) -> str:
        ret = subprocess.run(
            [
                "tesseract",
                image,
                "-",
                "--oem",
                str(model_id),
                "-c",
                "debug_file=/dev/null",
                "--psm",
                "11",
            ],
            capture_output=True,
        )
        if ret.returncode != 0:
            raise MachineError(f"OCR failed with exit code {ret.returncode}")
        return ret.stdout.decode("utf-8")

    return [
        run_tesseract(image, model_id)
        for image in image_paths
        for model_id in model_ids
    ]


def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
    if shutil.which("magick") is None:
        raise MachineError("OCR requested but `magick` is not available")

    magick_args = [
        "-filter",
        "Catrom",
        "-density",
        "72",
        "-resample",
        "300",
        "-contrast",
        "-normalize",
        "-despeckle",
        "-type",
        "grayscale",
        "-sharpen",
        "1",
        "-posterize",
        "3",
    ]
    out_file = screenshot_path

    if negate:
        magick_args.append("-negate")
        out_file += ".negative"

    magick_args += [
        "-gamma",
        "100",
        "-blur",
        "1x65535",
    ]
    out_file += ".png"

    ret = subprocess.run(
        ["magick", "convert"] + magick_args + [screenshot_path, out_file],
        capture_output=True,
    )

    if ret.returncode != 0:
        raise MachineError(
            f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
        )

    return out_file