Unverified Commit cba4d1dd authored by Marius van den Beek, committed by GitHub

Merge pull request #20871 from nsoranzo/type_annot_job_runners

Improve type annotation for job runners and ``InteractiveToolManager``
parents 76b714d2 aa13a633
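
The diffs below mostly tighten annotations (Optional attributes, typed containers, explicit return types) rather than change behaviour. As a rough, standalone illustration of the pattern — not code from this PR, all names below are made up — consider:

```python
# Minimal sketch (not code from this PR): the kind of bug an Optional attribute
# annotation lets mypy catch.  ToolSketch/WrapperSketch are hypothetical names.
from typing import Optional


class ToolSketch:
    produces_entry_points: bool = False


class WrapperSketch:
    def __init__(self, tool: Optional[ToolSketch] = None) -> None:
        self.tool = tool

    def check_for_entry_points(self) -> bool:
        # Without the `self.tool and` guard, mypy reports that None has no
        # attribute `produces_entry_points`; with it, the access is safe.
        if self.tool and not self.tool.produces_entry_points:
            return True
        return False
```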
+1 −2
@@ -619,6 +619,7 @@ class GalaxyManagerApplication(MinimalManagerApp, MinimalGalaxyApplication):
        self.role_manager = self._register_singleton(RoleManager)
        self.job_manager = self._register_singleton(JobManager)
        self.notification_manager = self._register_singleton(NotificationManager)
+        self.interactivetool_manager = InteractiveToolManager(self)

        self.task_manager = self._register_abstract_singleton(
            AsyncTasksManager, CeleryAsyncTasksManager  # type: ignore[type-abstract]  # https://github.com/python/mypy/issues/4717
@@ -839,8 +840,6 @@ class UniverseApplication(StructuredApp, GalaxyManagerApplication, InstallationT
        # Must be initialized after job_config.
        self.workflow_scheduling_manager = scheduling_manager.WorkflowSchedulingManager(self)

-        # We need InteractiveToolManager before the job handler starts
-        self.interactivetool_manager = InteractiveToolManager(self)
        # Start the job manager
        self.application_stack.register_postfork_function(self.job_manager.start)
        # Must be initialized after any component that might make use of stack messaging is configured. Alternatively if
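
The `InteractiveToolManager` assignment moves out of `UniverseApplication` and into `GalaxyManagerApplication`, so the manager is created alongside the other singletons rather than just before the job handler's post-fork start hook. A minimal sketch of that ordering constraint, with hypothetical names (not Galaxy code):

```python
# Minimal sketch (hypothetical names): a post-fork hook runs later, so anything
# it reads off the app must already be assigned by the time the hook fires.
class AppSketch:
    def __init__(self) -> None:
        self._postfork_hooks = []
        self.interactivetool_manager = object()  # created with the other managers
        self.register_postfork_function(self.start_job_handler)

    def register_postfork_function(self, fn) -> None:
        self._postfork_hooks.append(fn)

    def run_postfork(self) -> None:
        for fn in self._postfork_hooks:
            fn()

    def start_job_handler(self) -> None:
        # The job handler can rely on the manager existing by the time it starts.
        assert self.interactivetool_manager is not None


AppSketch().run_postfork()
```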
+3 −2
@@ -2582,8 +2582,8 @@ class MinimalJobWrapper(HasResourceParameters):
                    command = f"{dependency_shell_commands}; {command}"
        return command

-    def check_for_entry_points(self, check_already_configured=True):
-        if not self.tool.produces_entry_points:
+    def check_for_entry_points(self, check_already_configured: bool = True) -> bool:
+        if self.tool and not self.tool.produces_entry_points:
            return True

        job = self.get_job()
@@ -2611,6 +2611,7 @@ class MinimalJobWrapper(HasResourceParameters):
            self.fail(error_message)
            # local job runner uses return value to determine if we're done polling
            return True
+        return False

    def container_monitor_command(self, container, **kwds):
        if (
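
The new `-> bool` annotation and the explicit `return False` make the polling contract visible: per the comment above, the local job runner treats `True` as "done polling". A standalone sketch of that contract (not the actual runner code; `wrapper`, `interval` and `timeout` are illustrative assumptions):

```python
# Minimal sketch: how a `-> bool` result from check_for_entry_points()
# can drive a "done polling?" loop.
import time


def wait_for_entry_points(wrapper, interval: float = 1.0, timeout: float = 60.0) -> bool:
    """Poll until the wrapper says we are done (True) or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        # True means: entry points are configured, or the job failed -- stop polling.
        if wrapper.check_for_entry_points():
            return True
        time.sleep(interval)
    return False
```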
+9 −8
@@ -67,6 +67,7 @@ if TYPE_CHECKING:
        JobWrapper,
        MinimalJobWrapper,
    )
+    from galaxy.schema.schema import JobState as JobStateEnum

log = get_logger(__name__)

@@ -100,7 +101,7 @@ class BaseJobRunner:
    start_methods = ["_init_monitor_thread", "_init_worker_threads"]
    DEFAULT_SPECS = dict(recheck_missing_job_retries=dict(map=int, valid=lambda x: int(x) >= 0, default=0))

-    def __init__(self, app: "GalaxyManagerApplication", nworkers: int, **kwargs):
+    def __init__(self, app: "GalaxyManagerApplication", nworkers: int, **kwargs) -> None:
        """Start the job runner"""
        self.app = app
        self.redact_email_in_job_name = self.app.config.redact_email_in_job_name
@@ -760,7 +761,7 @@ class AsynchronousJobState(JobState):
        self,
        files_dir=None,
        job_wrapper=None,
-        job_id=None,
+        job_id: Union[str, None] = None,
        job_file=None,
        output_file=None,
        error_file=None,
@@ -769,7 +770,7 @@ class AsynchronousJobState(JobState):
        job_destination=None,
    ):
        super().__init__(job_wrapper, job_destination)
-        self.old_state = None
+        self.old_state: Union[JobStateEnum, None] = None
        self._running = False
        self.check_count = 0
        self.start_time = None
@@ -825,15 +826,15 @@ class AsynchronousJobRunner(BaseJobRunner, Monitors):
    to the correct methods (queue, finish, cleanup) at appropriate times..
    """

-    def __init__(self, app, nworkers, **kwargs):
+    def __init__(self, app: "GalaxyManagerApplication", nworkers: int, **kwargs) -> None:
        super().__init__(app, nworkers, **kwargs)
        # 'watched' and 'queue' are both used to keep track of jobs to watch.
        # 'queue' is used to add new watched jobs, and can be called from
        # any thread (usually by the 'queue_job' method). 'watched' must only
        # be modified by the monitor thread, which will move items from 'queue'
        # to 'watched' and then manage the watched jobs.
-        self.watched = []
-        self.monitor_queue = Queue()
+        self.watched: list[AsynchronousJobState] = []
+        self.monitor_queue: Queue[AsynchronousJobState] = Queue()

    def _init_monitor_thread(self):
        name = f"{self.runner_name}.monitor_thread"
@@ -876,7 +877,7 @@ class AsynchronousJobRunner(BaseJobRunner, Monitors):
            # Sleep a bit before the next state check
            time.sleep(self.app.config.job_runner_monitor_sleep)

-    def monitor_job(self, job_state):
+    def monitor_job(self, job_state: AsynchronousJobState) -> None:
        self.monitor_queue.put(job_state)

    def shutdown(self):
@@ -903,7 +904,7 @@ class AsynchronousJobRunner(BaseJobRunner, Monitors):
        self.watched = new_watched

    # Subclasses should implement this unless they override check_watched_items all together.
-    def check_watched_item(self, job_state):
+    def check_watched_item(self, job_state: AsynchronousJobState) -> Union[AsynchronousJobState, None]:
        raise NotImplementedError()

    def finish_job(self, job_state: AsynchronousJobState):
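
The comment in `AsynchronousJobRunner.__init__` describes the queue/watched hand-off; the new `list[...]` and `Queue[...]` annotations make that hand-off checkable by mypy. A self-contained sketch of the same protocol, using a hypothetical `JobStateSketch` rather than Galaxy classes:

```python
# Minimal sketch (not Galaxy code) of the queue/watched protocol: any thread may
# put new jobs on the queue, only the monitor thread drains it into `watched`.
from queue import Empty, Queue


class JobStateSketch:
    def __init__(self, job_id: str) -> None:
        self.job_id = job_id


class MonitorSketch:
    def __init__(self) -> None:
        # Generic annotations let mypy verify what flows through both containers.
        self.watched: list[JobStateSketch] = []
        self.monitor_queue: Queue[JobStateSketch] = Queue()

    def monitor_job(self, job_state: JobStateSketch) -> None:
        # May be called from any thread (the 'queue' side of the protocol).
        self.monitor_queue.put(job_state)

    def check_watched_items(self) -> None:
        # Only the monitor thread moves items from the queue into `watched`.
        while True:
            try:
                self.watched.append(self.monitor_queue.get_nowait())
            except Empty:
                break
```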
+11 −11
@@ -449,7 +449,7 @@ class AWSBatchJobRunner(AsynchronousJobRunner):
                        # TODO: This is where any cleanup would occur
                        self.handle_stop()
                        return
-                    self.watched.append((async_job_state.job_id, async_job_state))
+                    self.watched.append(async_job_state)
            except Empty:
                pass
            # Iterate over the list of watched jobs and check state
@@ -463,14 +463,15 @@ class AWSBatchJobRunner(AsynchronousJobRunner):
    def check_watched_items(self):
        done: set[str] = set()
        self.check_watched_items_by_batch(0, len(self.watched), done)
-        self.watched = [x for x in self.watched if x[0] not in done]
+        self.watched = [ajs for ajs in self.watched if ajs.job_id not in done]

-    def check_watched_items_by_batch(self, start: int, end: int, done: set[str]):
-        jobs = self.watched[start : start + self.MAX_JOBS_PER_QUERY]
-        if not jobs:
+    def check_watched_items_by_batch(self, start: int, end: int, done: set[str]) -> None:
+        async_job_states = self.watched[start : start + self.MAX_JOBS_PER_QUERY]
+        if not async_job_states:
            return

-        jobs_dict = dict(jobs)
+        jobs_dict = {ajs.job_id: ajs for ajs in async_job_states if ajs.job_id is not None}

        resp = self._batch_client.describe_jobs(jobs=list(jobs_dict.keys()))

        gotten = set()
@@ -492,27 +493,26 @@ class AWSBatchJobRunner(AsynchronousJobRunner):
            # remain queued for "SUBMITTED", "PENDING" and "RUNNABLE"
            # TODO else?

-        for job_id in jobs_dict:
+        for job_id, job_state in jobs_dict.items():
            if job_id in gotten:
                continue
-            job_state = jobs_dict[job_id]
            reason = f"The track of Job {job_state} was lost for unknown reason!"
            self._mark_as_failed(job_state, reason)
            done.add(job_id)

        self.check_watched_items_by_batch(start + self.MAX_JOBS_PER_QUERY, end, done)

-    def _mark_as_successful(self, job_state):
+    def _mark_as_successful(self, job_state: AsynchronousJobState) -> None:
        _write_logfile(job_state.output_file, "")
        _write_logfile(job_state.error_file, "")
        job_state.running = False
        self.mark_as_finished(job_state)

-    def _mark_as_active(self, job_state):
+    def _mark_as_active(self, job_state: AsynchronousJobState) -> None:
        job_state.running = True
        job_state.job_wrapper.change_state(model.Job.states.RUNNING)

-    def _mark_as_failed(self, job_state, reason):
+    def _mark_as_failed(self, job_state: AsynchronousJobState, reason: str) -> None:
        _write_logfile(job_state.error_file, reason)
        job_state.running = False
        job_state.stop_job = False
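
In this file `self.watched` switches from `(job_id, state)` tuples to plain `AsynchronousJobState` objects, with per-batch lookups going through a dict keyed on `job_id`. A toy sketch of the same reshaping (made-up ids, hypothetical class):

```python
# Minimal sketch (not Galaxy code): state objects in `watched`, dict lookups by job_id.
class JobStateSketch:
    def __init__(self, job_id):
        self.job_id = job_id


watched = [JobStateSketch("aws-1"), JobStateSketch("aws-2"), JobStateSketch(None)]

# id -> state, skipping entries whose id has not been assigned yet
jobs_dict = {ajs.job_id: ajs for ajs in watched if ajs.job_id is not None}

# after the batch query, drop finished jobs by id
done = {"aws-1"}
watched = [ajs for ajs in watched if ajs.job_id not in done]

print(sorted(jobs_dict))                # ['aws-1', 'aws-2']
print([ajs.job_id for ajs in watched])  # ['aws-2', None]
```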
+11 −7
import functools
import logging
import os
+from typing import Union

from galaxy import model
from galaxy.jobs.runners import (
@@ -199,14 +200,15 @@ class ChronosJobRunner(AsynchronousJobRunner):
            self.monitor_queue.put(ajs)

    @handle_exception_call
-    def check_watched_item(self, job_state):
+    def check_watched_item(self, job_state: AsynchronousJobState) -> Union[AsynchronousJobState, None]:
        job_name = job_state.job_id
        # TODO: how can stopped GxIT jobs be handled here?
        if job := self._retrieve_job(job_name):
            succeeded = job["successCount"]
            errors = job["errorCount"]
            if succeeded > 0:
-                return self._mark_as_successful(job_state)
+                self._mark_as_successful(job_state)
+                return None
            elif not succeeded and not errors:
                return self._mark_as_active(job_state)
            elif errors:
@@ -216,11 +218,13 @@ class ChronosJobRunner(AsynchronousJobRunner):
                else:
                    msg = "Job {name!r} failed more than {retries!s} times."
                reason = msg.format(name=job_name, retries=str(max_retries))
-                return self._mark_as_failed(job_state, reason)
+                self._mark_as_failed(job_state, reason)
+                return None
        reason = f"Job {job_name!r} not found"
-        return self._mark_as_failed(job_state, reason)
+        self._mark_as_failed(job_state, reason)
+        return None

-    def _mark_as_successful(self, job_state):
+    def _mark_as_successful(self, job_state: AsynchronousJobState) -> None:
        msg = "Job {name!r} finished successfully"
        _write_logfile(job_state.output_file, msg.format(name=job_state.job_id))
        _write_logfile(job_state.error_file, "")
@@ -229,12 +233,12 @@ class ChronosJobRunner(AsynchronousJobRunner):
        self.mark_as_finished(job_state)
        return None

-    def _mark_as_active(self, job_state):
+    def _mark_as_active(self, job_state: AsynchronousJobState) -> AsynchronousJobState:
        job_state.running = True
        job_state.job_wrapper.change_state(model.Job.states.RUNNING)
        return job_state

-    def _mark_as_failed(self, job_state, reason):
+    def _mark_as_failed(self, job_state: AsynchronousJobState, reason: str) -> None:
        _write_logfile(job_state.error_file, reason)
        job_state.running = False
        job_state.stop_job = True
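
Since `_mark_as_successful` and `_mark_as_failed` are now annotated `-> None`, `check_watched_item` calls them and then returns `None` itself instead of returning their (None) result. A minimal sketch of the resulting signatures, with hypothetical names (not the Chronos runner itself):

```python
# Minimal sketch: with a `-> None` helper, `return self._mark_as_failed(...)`
# would hide the intent, so the call and the `return None` become two statements.
from typing import Optional


class StateSketch:
    running = False


class RunnerSketch:
    def _mark_as_failed(self, job_state: StateSketch, reason: str) -> None:
        job_state.running = False

    def check_watched_item(self, job_state: StateSketch, healthy: bool) -> Optional[StateSketch]:
        if healthy:
            return job_state  # keep watching this job
        self._mark_as_failed(job_state, "job disappeared")
        return None  # drop it from the watch list
```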