Commit 14c01b5a authored by Jacek Galowicz's avatar Jacek Galowicz
Browse files

test-driver: Parallelize OCR

parent 9f10c9bc
Loading
Loading
Loading
Loading
+44 −33
Original line number Diff line number Diff line
import itertools
import multiprocessing
import os
import shutil
import subprocess

@@ -9,7 +12,10 @@ def perform_ocr_on_screenshot(screenshot_path: str) -> str:
    Perform OCR on a screenshot that contains text.
    Returns a string with all words that could be found.
    """
    return perform_ocr_variants_on_screenshot(screenshot_path, False)[0]
    variants = perform_ocr_variants_on_screenshot(screenshot_path, False)[0]
    if len(variants) != 1:
        raise MachineError(f"Received wrong number of OCR results: {len(variants)}")
    return variants[0]


def perform_ocr_variants_on_screenshot(
@@ -29,15 +35,26 @@ def perform_ocr_variants_on_screenshot(
    #  1|lstm_only               Neural nets LSTM engine only.
    #  2|tesseract_lstm_combined Legacy + LSTM engines.
    #  3|default                 Default, based on what is available.
    model_ids: list[int] = [0, 1, 2] if variants else [3]
    model_ids: list[int] = [0, 1] if variants else [2]

    # Tesseract runs parallel on up to 4 cores.
    # Docs suggest to run it with OMP_THREAD_LIMIT=1 for hundreds of parallel
    # runs. Our average test run is somewhere inbetween.
    # https://github.com/tesseract-ocr/tesseract/issues/3109
    processes = max(1, int(os.process_cpu_count() / 4))
    with multiprocessing.Pool(processes=processes) as pool:
        image_paths: list[str] = [screenshot_path]
        if variants:
            image_paths.extend(
                pool.starmap(
                    _preprocess_screenshot,
                    [(screenshot_path, False), (screenshot_path, True)],
                )
            )
        return pool.starmap(_run_tesseract, itertools.product(image_paths, model_ids))

    image_paths = [
        screenshot_path,
        _preprocess_screenshot(screenshot_path, negate=False),
        _preprocess_screenshot(screenshot_path, negate=True),
    ]

    def run_tesseract(image: str, model_id: int) -> str:
def _run_tesseract(image: str, model_id: int) -> str:
    ret = subprocess.run(
        [
            "tesseract",
@@ -56,12 +73,6 @@ def perform_ocr_variants_on_screenshot(
        raise MachineError(f"OCR failed with exit code {ret.returncode}")
    return ret.stdout.decode("utf-8")

    return [
        run_tesseract(image, model_id)
        for image in image_paths
        for model_id in model_ids
    ]


def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
    if shutil.which("magick") is None: