Unverified Commit e84e7156 authored by mvdbeek's avatar mvdbeek
Browse files

Guard state update with limit queries

These are essentially the same queries that are done in
`JobHandler.__check_user_jobs`, `JobHandler.__check_destination_jobs`
etc., but now it's all in a single update statement.

I suppose performance might be a concern, however we still run through
the (cached) checks before we decide to queue the job, so I think
the cost is likely minimal. By integrating the limit check in the query
I think it should become very unlikely that jobs can bypass limits in a
multi-handler scenario.
parent f4692d3c
Loading
Loading
Loading
Loading
+127 −14
Original line number Diff line number Diff line
@@ -14,6 +14,10 @@ import shutil
import sys
import time
import traceback
from dataclasses import (
    dataclass,
    field,
)
from json import loads
from typing import (
    Any,
@@ -27,7 +31,12 @@ from typing import (
import yaml
from packaging.version import Version
from pulsar.client.staging import COMMAND_VERSION_FILENAME
from sqlalchemy import select
from sqlalchemy import (
    and_,
    func,
    select,
    update,
)

from galaxy import (
    model,
@@ -296,6 +305,18 @@ def job_config_xml_to_dict(config, root):
    return config_dict


@dataclass
class JobConfigurationLimits:
    """Typed container for the job limits stored on a job configuration's ``limits`` attribute.

    A freshly constructed instance has no limits configured: every scalar
    limit defaults to ``None`` and every mapping limit defaults to its own
    empty dict (built per instance through ``default_factory``).
    """

    # Ceiling on a registered user's active jobs; ``queue_with_limit`` compares
    # it against the user's count of queued/running/resubmitted jobs.
    registered_user_concurrent_jobs: Optional[int] = None
    # Ceiling on an anonymous session's active jobs; ``queue_with_limit``
    # compares it against the session's count of queued/running/resubmitted jobs.
    anonymous_user_concurrent_jobs: Optional[int] = None
    # NOTE(review): presumably the wall-time limit in its configured string
    # form -- usage is not visible here, confirm against the config parser.
    walltime: Optional[str] = None
    # NOTE(review): presumably the parsed ``timedelta`` counterpart of
    # ``walltime`` -- TODO confirm.
    walltime_delta: Optional[datetime.timedelta] = None
    # NOTE(review): schema of this mapping is not visible here -- TODO confirm.
    total_walltime: Dict[str, Any] = field(default_factory=dict)
    # NOTE(review): presumably a maximum job output size; units not visible
    # here -- TODO confirm.
    output_size: Optional[int] = None
    # Per-user active-job ceiling for a destination; ``queue_with_limit``
    # looks entries up by destination id.
    destination_user_concurrent_jobs: Dict[str, int] = field(default_factory=dict)
    # Overall active-job ceiling for a destination; ``queue_with_limit``
    # looks entries up by destination id.
    destination_total_concurrent_jobs: Dict[str, int] = field(default_factory=dict)


class JobConfiguration(ConfiguresHandlers):
    """A parser and interface to advanced job management features.

@@ -344,16 +365,7 @@ class JobConfiguration(ConfiguresHandlers):
        self.resource_groups = {}
        self.default_resource_group = None
        self.resource_parameters = {}
        self.limits = Bunch(
            registered_user_concurrent_jobs=None,
            anonymous_user_concurrent_jobs=None,
            walltime=None,
            walltime_delta=None,
            total_walltime={},
            output_size=None,
            destination_user_concurrent_jobs={},
            destination_total_concurrent_jobs={},
        )
        self.limits = JobConfigurationLimits()

        default_resubmits = []
        default_resubmit_condition = self.app.config.default_job_resubmission_condition
@@ -1610,12 +1622,113 @@ class MinimalJobWrapper(HasResourceParameters):
        dest_params = self.job_destination.params
        return self.get_job().get_destination_configuration(dest_params, self.app.config, key, default)

    def queue_with_limit(self, job: Job, job_destination: JobDestination) -> bool:
        """Move ``job`` to the queued state only while the concurrency limits allow it.

        The state change and the limit checks are issued as one ``UPDATE``
        statement: the row for ``job`` is updated only if every limit that
        applies to it still has room, i.e. the number of jobs currently in the
        queued, running or resubmitted state is strictly below that limit.

        Limits taken from ``self.app.job_config.limits``:

        * registered user (``job.user_id`` is set): the per-user total and the
          per-user limit of ``job_destination``;
        * anonymous (no user, but a galaxy session with an id): the
          per-session total;
        * the total limit of ``job_destination``, whoever owns the job.

        The session is committed whether or not a row was updated.

        :param job: the job to queue.
        :param job_destination: destination whose id, params and runner are
            written to the job row; its id also selects the destination limits.
        :returns: ``True`` if the job row was updated, ``False`` if the
            statement matched no row.
        """
        limits = self.app.job_config.limits
        anonymous_user_concurrent_jobs = limits.anonymous_user_concurrent_jobs
        registered_user_concurrent_jobs = limits.registered_user_concurrent_jobs
        destination_total_limit = limits.destination_total_concurrent_jobs.get(job_destination.id)
        destination_user_limit = limits.destination_user_concurrent_jobs.get(job_destination.id)

        job_table = model.Job.table
        # Jobs in these states are the ones counted against every limit.
        active_states = [
            model.Job.states.QUEUED,
            model.Job.states.RUNNING,
            model.Job.states.RESUBMITTED,
        ]

        def active_job_count(*criteria):
            # Scalar subquery: number of active jobs matching all of ``criteria``.
            return (
                select(func.count(job_table.c.id))
                .where(and_(job_table.c.state.in_(active_states), *criteria))
                .scalar_subquery()
            )

        # The update always targets exactly this job; each applicable limit
        # adds one more condition that must hold for the row to be touched.
        conditions = [job_table.c.id == job.id]

        # NOTE(review): a registered-user limit of 0 is enforced (``is not None``),
        # while an anonymous limit of 0 is skipped by the truthiness test below --
        # confirm this asymmetry is intended.
        if job.user_id:
            if registered_user_concurrent_jobs is not None:
                user_job_count = active_job_count(job_table.c.user_id == job.user_id)
                conditions.append(user_job_count < registered_user_concurrent_jobs)
            if destination_user_limit is not None:
                destination_job_count = active_job_count(
                    job_table.c.destination_id == job_destination.id,
                    job_table.c.user_id == job.user_id,
                )
                conditions.append(destination_job_count < destination_user_limit)
        elif anonymous_user_concurrent_jobs and job.galaxy_session and job.galaxy_session.id:
            anon_job_count = active_job_count(job_table.c.session_id == job.galaxy_session.id)
            conditions.append(anon_job_count < anonymous_user_concurrent_jobs)

        if destination_total_limit is not None:
            destination_total_count = active_job_count(job_table.c.destination_id == job_destination.id)
            conditions.append(destination_total_count < destination_total_limit)

        update_stmt = (
            update(model.Job)
            .where(*conditions)
            .values(
                state=model.Job.states.QUEUED,
                destination_id=job_destination.id,
                destination_params=job_destination.params,
                job_runner_name=job_destination.runner,
            )
        )

        result = self.sa_session.execute(update_stmt)
        self.sa_session.commit()

        # No matched row means at least one limit condition failed.
        return result.rowcount > 0

    def enqueue(self):
        job = self.get_job()
        # Change to queued state before handing to worker thread so the runner won't pick it up again
        self.change_state(model.Job.states.QUEUED, flush=False, job=job)
        # Persist the destination so that the job will be included in counts if using concurrency limits
        self.set_job_destination(self.job_destination, None, flush=False, job=job)
        if not self.queue_with_limit(job, self.job_destination):
            return False
        # Set object store after job destination so can leverage parameters...
        self._set_object_store_ids(job)
        # Now that we have the object store id, check if we are over the limit