Unverified commit 2464f33b, authored by Marius van den Beek and committed by GitHub

Merge pull request #14046 from mvdbeek/fix_exception_handing_in_pulsar_job_runner

[22.05] Fix exception handling in pulsar job runner
parents 522e0c25 b453e360
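
In brief: BaseJobRunner.fail_job gains message and full_status parameters so that stdout/stderr reported by a remote Pulsar server can be forwarded to the job wrapper's fail() instead of being dropped. A minimal sketch of the new call shape (the runner and job_state objects and the dict contents are illustrative, not part of this diff; fail_job reads only the "stdout" and "stderr" keys of full_status):

# Sketch: reporting a failure with the captured remote output attached.
full_status = {"stdout": "", "stderr": "Remote job failed: traceback ..."}
runner.fail_job(job_state, exception=True, message="Pulsar job failed", full_status=full_status)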
+6 −3
@@ -993,6 +993,7 @@ class MinimalJobWrapper(HasResourceParameters):
         self.params = None
         if job.params:
             self.params = loads(job.params)
+        self.runner_command_line = None
 
         # Wrapper holding the info required to restore and clean up from files used for setting metadata externally
         self.__external_output_metadata = None
@@ -1161,7 +1162,7 @@ class MinimalJobWrapper(HasResourceParameters):
     def galaxy_url(self):
         return self.get_destination_configuration("galaxy_infrastructure_url")
 
-    def get_job(self):
+    def get_job(self) -> model.Job:
         return self.sa_session.query(model.Job).get(self.job_id)
 
     def get_id_tag(self):
@@ -1528,7 +1529,7 @@ class MinimalJobWrapper(HasResourceParameters):
         if flush:
             self.sa_session.flush()
 
-    def get_state(self):
+    def get_state(self) -> str:
         job = self.get_job()
         self.sa_session.refresh(job)
         return job.state
@@ -2512,7 +2513,9 @@ class TaskWrapper(JobWrapper):
        self.status = "prepared"
        return self.extra_filenames

    def fail(self, message, exception=False):
    def fail(
        self, message, exception=False, tool_stdout="", tool_stderr="", exit_code=None, job_stdout=None, job_stderr=None
    ):
        log.error(f"TaskWrapper Failure {message}")
        self.status = "error"
        # How do we want to handle task failure?  Fail the job and let it clean up?
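
Since BaseJobRunner.fail_job (changed below) now calls fail() with tool_stdout/tool_stderr keyword arguments, TaskWrapper.fail is widened to accept the same output-capture keywords as the job wrapper's fail. A sketch of a call the new signature permits, with illustrative values (wrapper stands for any JobWrapper or TaskWrapper instance):

# Failing a wrapper with captured output attached; previously TaskWrapper
# accepted only message and exception.
wrapper.fail(
    "Tool execution failed",
    exception=True,
    tool_stdout="",                    # captured tool stdout, if any
    tool_stderr="Segmentation fault",  # captured tool stderr, if any
    exit_code=139,                     # exit code, when known
)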
+55 −35
@@ -21,6 +21,7 @@ from galaxy.job_execution.output_collect import (
     default_exit_code_file,
     read_exit_code_from,
 )
+from galaxy.job_execution.setup import JobIO
 from galaxy.jobs.command_factory import build_command
 from galaxy.jobs.runners.util import runner_states
 from galaxy.jobs.runners.util.env import env_to_statement
@@ -49,6 +50,7 @@ if typing.TYPE_CHECKING:
     from galaxy.jobs import (
         JobDestination,
         JobWrapper,
+        MinimalJobWrapper,
     )
 
 log = get_logger(__name__)
@@ -79,10 +81,12 @@ class RunnerParams(ParamsWithSpecs):

 class BaseJobRunner:
 
+    runner_name = "BaseJobRunner"
+
     start_methods = ["_init_monitor_thread", "_init_worker_threads"]
     DEFAULT_SPECS = dict(recheck_missing_job_retries=dict(map=int, valid=lambda x: int(x) >= 0, default=0))
 
-    def __init__(self, app, nworkers, **kwargs):
+    def __init__(self, app, nworkers: int, **kwargs):
         """Start the job runner"""
         self.app = app
         self.redact_email_in_job_name = self.app.config.redact_email_in_job_name
@@ -163,7 +167,7 @@ class BaseJobRunner:
                     self.work_queue.put((self.fail_job, job_state))
 
     # Causes a runner's `queue_job` method to be called from a worker thread
-    def put(self, job_wrapper):
+    def put(self, job_wrapper: "MinimalJobWrapper"):
         """Add a job to the queue (by job identifier), indicate that the job is ready to run."""
         put_timer = ExecutionTimer()
         queue_job = job_wrapper.enqueue()
@@ -171,7 +175,7 @@ class BaseJobRunner:
             self.mark_as_queued(job_wrapper)
             log.debug(f"Job [{job_wrapper.job_id}] queued {put_timer}")
 
-    def mark_as_queued(self, job_wrapper):
+    def mark_as_queued(self, job_wrapper: "MinimalJobWrapper"):
         self.work_queue.put((self.queue_job, job_wrapper))
 
     def shutdown(self):
@@ -213,7 +217,7 @@ class BaseJobRunner:
                 )
 
     # Most runners should override the legacy URL handler methods and destination param method
-    def url_to_destination(self, url):
+    def url_to_destination(self, url: str):
         """
         Convert a legacy URL to a JobDestination.
 
@@ -223,22 +227,21 @@ class BaseJobRunner:
        """
        return galaxy.jobs.JobDestination(runner=url.split(":")[0])

    def parse_destination_params(self, params):
    def parse_destination_params(self, params: typing.Dict[str, typing.Any]):
        """Parse the JobDestination ``params`` dict and return the runner's native representation of those params."""
        raise NotImplementedError()

    def prepare_job(
        self,
        job_wrapper,
        include_metadata=False,
        include_work_dir_outputs=True,
        modify_command_for_container=True,
        stream_stdout_stderr=False,
        job_wrapper: "MinimalJobWrapper",
        include_metadata: bool = False,
        include_work_dir_outputs: bool = True,
        modify_command_for_container: bool = True,
        stream_stdout_stderr: bool = False,
    ):
        """Some sanity checks that all runners' queue_job() methods are likely to want to do"""
        job_id = job_wrapper.get_id_tag()
        job_state = job_wrapper.get_state()
        job_wrapper.is_ready = False
        job_wrapper.runner_command_line = None

        # Make sure the job hasn't been deleted
@@ -285,10 +288,10 @@ class BaseJobRunner:

     def build_command_line(
         self,
-        job_wrapper,
-        include_metadata=False,
-        include_work_dir_outputs=True,
-        modify_command_for_container=True,
+        job_wrapper: "MinimalJobWrapper",
+        include_metadata: bool = False,
+        include_work_dir_outputs: bool = True,
+        modify_command_for_container: bool = True,
         stream_stdout_stderr=False,
     ):
         container = self._find_container(job_wrapper)
@@ -304,7 +307,12 @@ class BaseJobRunner:
             stream_stdout_stderr=stream_stdout_stderr,
         )
 
-    def get_work_dir_outputs(self, job_wrapper, job_working_directory=None, tool_working_directory=None):
+    def get_work_dir_outputs(
+        self,
+        job_wrapper: "MinimalJobWrapper",
+        job_working_directory: typing.Optional[str] = None,
+        tool_working_directory: typing.Optional[str] = None,
+    ):
         """
         Returns list of pairs (source_file, destination) describing path
         to work_dir output file and ultimate destination.
@@ -317,6 +325,7 @@ class BaseJobRunner:
         if tool_working_directory is None:
             if not job_working_directory:
                 job_working_directory = os.path.abspath(job_wrapper.working_directory)
+                assert job_working_directory
             tool_working_directory = os.path.join(job_working_directory, "working")
 
         # Set up dict of dataset id --> output path; output path can be real or
@@ -351,7 +360,7 @@ class BaseJobRunner:
                         )
         return output_pairs
 
-    def _walk_dataset_outputs(self, job):
+    def _walk_dataset_outputs(self, job: model.Job):
         for dataset_assoc in job.output_datasets + job.output_library_datasets:
             for dataset in (
                 dataset_assoc.dataset.dataset.history_associations + dataset_assoc.dataset.dataset.library_associations
@@ -368,7 +377,7 @@ class BaseJobRunner:
         #      yield (dataset_assoc, dataset_assoc.dataset)
         #  I don't understand the reworking it backwards.  -John
 
-    def _handle_metadata_externally(self, job_wrapper, resolve_requirements=False):
+    def _handle_metadata_externally(self, job_wrapper: "MinimalJobWrapper", resolve_requirements: bool = False):
         """
         Set metadata externally. Used by the Pulsar job runner where this
         shouldn't be attached to command line to execute.
@@ -433,13 +442,13 @@ class BaseJobRunner:
                 external_metadata_proc.wait()
             log.debug("execution of external set_meta for job %d finished" % job_wrapper.job_id)
 
-    def get_job_file(self, job_wrapper, **kwds):
+    def get_job_file(self, job_wrapper: "MinimalJobWrapper", **kwds) -> str:
         job_metrics = job_wrapper.app.job_metrics
         job_instrumenter = job_metrics.job_instrumenters[job_wrapper.job_destination.id]
 
         env_setup_commands = kwds.get("env_setup_commands", [])
         env_setup_commands.append(job_wrapper.get_env_setup_clause() or "")
-        destination = job_wrapper.job_destination or {}
+        destination = job_wrapper.job_destination
         envs = destination.get("env", [])
         envs.extend(job_wrapper.environment_variables)
         for env in envs:
@@ -464,16 +473,16 @@ class BaseJobRunner:
         options.update(**kwds)
         return job_script(**options)
 
-    def write_executable_script(self, path, contents, job_io):
+    def write_executable_script(self, path: str, contents: str, job_io: JobIO):
         write_script(path, contents, job_io)
 
     def _find_container(
         self,
-        job_wrapper,
-        compute_working_directory=None,
-        compute_tool_directory=None,
-        compute_job_directory=None,
-        compute_tmp_directory=None,
+        job_wrapper: "MinimalJobWrapper",
+        compute_working_directory: typing.Optional[str] = None,
+        compute_tool_directory: typing.Optional[str] = None,
+        compute_job_directory: typing.Optional[str] = None,
+        compute_tmp_directory: typing.Optional[str] = None,
     ):
         job_directory_type = "galaxy" if compute_working_directory is None else "pulsar"
         if not compute_working_directory:
@@ -515,7 +524,7 @@ class BaseJobRunner:
             job_wrapper.set_container(container)
         return container
 
-    def _handle_runner_state(self, runner_state, job_state):
+    def _handle_runner_state(self, runner_state, job_state: "JobState"):
         try:
             for handler in self.runner_state_handlers.get(runner_state, []):
                 handler(self.app, self, job_state)
@@ -524,7 +533,7 @@ class BaseJobRunner:
         except Exception:
             log.exception("Caught exception in runner state handler")
 
-    def fail_job(self, job_state, exception=False):
+    def fail_job(self, job_state: "JobState", exception=False, message="Job failed", full_status=None):
         if getattr(job_state, "stop_job", True):
             self.stop_job(job_state.job_wrapper)
         job_state.job_wrapper.reclaim_ownership()
@@ -532,11 +541,17 @@ class BaseJobRunner:
         # Not convinced this is the best way to indicate this state, but
         # something necessary
         if not job_state.runner_state_handled:
-            job_state.job_wrapper.fail(getattr(job_state, "fail_message", "Job failed"), exception=exception)
-            if job_state.job_wrapper.cleanup_job == "always":
-                job_state.cleanup()
+            # full_status currently only passed in pulsar runner,
+            # but might be useful for other runners in the future.
+            full_status = full_status or {}
+            tool_stdout = full_status.get("stdout")
+            tool_stderr = full_status.get("stderr")
+            fail_message = getattr(job_state, "fail_message", message)
+            job_state.job_wrapper.fail(
+                fail_message, tool_stdout=tool_stdout, tool_stderr=tool_stderr, exception=exception
+            )
 
-    def mark_as_resubmitted(self, job_state: "JobState", info=None):
+    def mark_as_resubmitted(self, job_state: "JobState", info: typing.Optional[str] = None):
         job_state.job_wrapper.mark_as_resubmitted(info=info)
         if not self.app.config.track_jobs_in_database:
             job_state.job_wrapper.change_state(model.Job.states.QUEUED)
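
As the inline comment notes, full_status is currently only supplied by the Pulsar runner. A hedged sketch of what such a call site might look like (the function and remote_status are hypothetical; the Pulsar runner itself is not shown in this diff):

# Illustrative only: forward the remote status dict so fail_job can
# attach the remote stdout/stderr to the failing job.
def handle_remote_failure(runner, job_state, remote_status):
    runner.fail_job(
        job_state,
        message=remote_status.get("message", "Job failed"),
        full_status=remote_status,  # only "stdout"/"stderr" are read
    )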
@@ -551,6 +566,10 @@ class BaseJobRunner:
         job_wrapper = job_state.job_wrapper
         try:
             job = job_state.job_wrapper.get_job()
+            if job_id is None:
+                job_id = job.get_id_tag()
+            if external_job_id is None:
+                external_job_id = job.get_job_runner_external_id()
             exit_code = job_state.read_exit_code()
 
             outputs_directory = os.path.join(job_wrapper.working_directory, "outputs")
@@ -631,6 +650,7 @@ class JobState:
         self.job_wrapper = job_wrapper
         self.job_destination = job_destination
         self.runner_state = None
+        self.exit_code_file = default_exit_code_file(job_wrapper.working_directory, job_wrapper.get_id_tag())
 
         self.redact_email_in_job_name = True
         if self.job_wrapper:
@@ -645,7 +665,6 @@ class JobState:
                 self.job_file = JobState.default_job_file(files_dir, id_tag)
                 self.output_file = os.path.join(files_dir, f"galaxy_{id_tag}.o")
                 self.error_file = os.path.join(files_dir, f"galaxy_{id_tag}.e")
-                self.exit_code_file = default_exit_code_file(files_dir, id_tag)
             job_name = f"g{id_tag}"
             if self.job_wrapper.tool.old_id:
                 job_name += f"_{self.job_wrapper.tool.old_id}"
@@ -705,6 +724,7 @@ class AsynchronousJobState(JobState):
         self.job_file = job_file
         self.output_file = output_file
         self.error_file = error_file
+        if exit_code_file:
             self.exit_code_file = exit_code_file
         self.job_name = job_name
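
Net effect of the JobState/AsynchronousJobState changes: JobState.__init__ always derives a default exit_code_file from the job's working directory, and an explicitly supplied path merely overrides it. The pattern in isolation (a simplified sketch, not the Galaxy classes; the .ec suffix mirrors the convention suggested by the .o/.e files above):

import os

def default_exit_code_file(files_dir, id_tag):
    # stand-in for galaxy.job_execution.output_collect.default_exit_code_file
    return os.path.join(files_dir, f"galaxy_{id_tag}.ec")

class JobStateSketch:
    def __init__(self, working_directory, id_tag):
        # always set, so read_exit_code() has a path to look for
        self.exit_code_file = default_exit_code_file(working_directory, id_tag)

class AsyncJobStateSketch(JobStateSketch):
    def __init__(self, working_directory, id_tag, exit_code_file=None):
        super().__init__(working_directory, id_tag)
        if exit_code_file:  # an explicit path wins over the default
            self.exit_code_file = exit_code_file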

@@ -822,7 +842,7 @@ class AsynchronousJobRunner(BaseJobRunner, Monitors):
     def check_watched_item(self, job_state):
         raise NotImplementedError()
 
-    def finish_job(self, job_state):
+    def finish_job(self, job_state: AsynchronousJobState):
         """
         Get the output/error for a finished job, pass to `job_wrapper.finish`
         and cleanup all the job's temporary files.
+6 −15
@@ -180,11 +180,13 @@ class ChronosJobRunner(AsynchronousJobRunner):
     def recover(self, job, job_wrapper):
         msg = "(name!r/runner!r) is still in {state!s} state, adding to" " the runner monitor queue"
         job_id = job.get_job_runner_external_id()
-        ajs = AsynchronousJobState(files_dir=job_wrapper.working_directory, job_wrapper=job_wrapper)
-        ajs.job_id = self.JOB_NAME_PREFIX + str(job_id)
+        ajs = AsynchronousJobState(
+            files_dir=job_wrapper.working_directory,
+            job_wrapper=job_wrapper,
+            job_id=self.JOB_NAME_PREFIX + str(job_id),
+            job_destination=job_wrapper.job_destination,
+        )
         ajs.command_line = job.command_line
-        ajs.job_wrapper = job_wrapper
-        ajs.job_destination = job_wrapper.job_destination
         if job.state in (model.Job.states.RUNNING, model.Job.states.STOPPED):
             LOGGER.debug(msg.format(name=job.id, runner=job.job_runner_external_id, state=job.state))
             ajs.old_state = model.Job.states.RUNNING
@@ -196,17 +198,6 @@ class ChronosJobRunner(AsynchronousJobRunner):
             ajs.running = False
             self.monitor_queue.put(ajs)
 
-    def fail_job(self, job_state, exception=False):
-        if getattr(job_state, "stop_job", True):
-            self.stop_job(job_state.job_wrapper)
-        job_state.job_wrapper.reclaim_ownership()
-        self._handle_runner_state("failure", job_state)
-        if not job_state.runner_state_handled:
-            job_state.job_wrapper.fail(getattr(job_state, "fail_message", "Job failed"), exception=exception)
-            self._finish_or_resubmit_job(job_state, "", job_state.fail_message, job_id=job_state.job_id)
-            if job_state.job_wrapper.cleanup_job == "always":
-                job_state.cleanup()
-
     @handle_exception_call
     def check_watched_item(self, job_state):
         job_name = job_state.job_id
+6 −4
@@ -259,11 +259,13 @@ class ShellJobRunner(AsynchronousJobRunner):
         if job_id is None:
             self.put(job_wrapper)
             return
-        ajs = AsynchronousJobState(files_dir=job_wrapper.working_directory, job_wrapper=job_wrapper)
-        ajs.job_id = str(job_id)
+        ajs = AsynchronousJobState(
+            files_dir=job_wrapper.working_directory,
+            job_wrapper=job_wrapper,
+            job_id=job_id,
+            job_destination=job_wrapper.job_destination,
+        )
         ajs.command_line = job.command_line
-        ajs.job_wrapper = job_wrapper
-        ajs.job_destination = job_wrapper.job_destination
         if job.state in (model.Job.states.RUNNING, model.Job.states.STOPPED):
             log.debug(
                 f"({job.id}/{job.job_runner_external_id}) is still in {job.state} state, adding to the runner monitor queue"
+6 −4
@@ -397,11 +397,13 @@ class DRMAAJobRunner(AsynchronousJobRunner):
         if job_id is None:
             self.put(job_wrapper)
             return
-        ajs = AsynchronousJobState(files_dir=job_wrapper.working_directory, job_wrapper=job_wrapper)
-        ajs.job_id = str(job_id)
+        ajs = AsynchronousJobState(
+            files_dir=job_wrapper.working_directory,
+            job_wrapper=job_wrapper,
+            job_id=job_id,
+            job_destination=job_wrapper.job_destination,
+        )
         ajs.command_line = job.get_command_line()
-        ajs.job_wrapper = job_wrapper
-        ajs.job_destination = job_wrapper.job_destination
         if job.state in (model.Job.states.RUNNING, model.Job.states.STOPPED):
             log.debug(
                 f"({job.id}/{job.get_job_runner_external_id()}) is still in {job.state} state, adding to the DRM queue"