Unverified commit ecc4b478, authored by Nate Coraor, committed by GitHub

Merge pull request #19824 from mvdbeek/fix_limit_bypass

[24.2] Fix various job concurrency limit issues
parents b8c902a0 c088f9c0
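
This merge makes several related changes: the job limits Bunch becomes a JobConfigurationLimits dataclass; MinimalJobWrapper gains queue_with_limit(), which folds the registered-user, anonymous-session, per-destination user, destination total, and destination tag limits into the WHERE clause of the single UPDATE that marks the job queued and records its destination, so the check and the state change happen in one statement rather than in separate steps; the handler queue refreshes its job-count caches once per monitor iteration instead of inside each getter; the Pulsar runner sets the destination and flushes in one call instead of a separate change_state; the Slurm runner resubmits node-failed jobs via mark_as_resubmitted; and unit tests cover the new limit checks.

The core pattern is the conditional UPDATE. Below is a minimal, self-contained sketch of that technique, not part of the PR: the toy table, column names, and helper function are illustrative assumptions, not Galaxy's model or API.

from sqlalchemy import Column, Integer, String, create_engine, func, select, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class ToyJob(Base):
    __tablename__ = "toy_job"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer)
    state = Column(String, default="new")


def queue_if_under_limit(session: Session, job_id: int, user_id: int, limit: int) -> bool:
    # Number of this user's jobs that already occupy a slot.
    active = (
        select(func.count(ToyJob.id))
        .where(ToyJob.user_id == user_id, ToyJob.state.in_(["queued", "running"]))
        .scalar_subquery()
    )
    # The count comparison lives in the UPDATE's WHERE clause, so there is no gap
    # between checking the limit and marking the job queued.
    result = session.execute(
        update(ToyJob).where(ToyJob.id == job_id, active < limit).values(state="queued")
    )
    session.commit()
    return result.rowcount > 0


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add_all([ToyJob(id=1, user_id=7, state="running"), ToyJob(id=2, user_id=7)])
        session.commit()
        print(queue_if_under_limit(session, job_id=2, user_id=7, limit=2))  # True: 1 active < 2
        print(queue_if_under_limit(session, job_id=2, user_id=7, limit=1))  # False: 2 active, limit 1

Because the comparison is evaluated inside the UPDATE itself, the long window between reading a count in Python and later flushing the queued state disappears; whether the job was actually queued is read back from the statement's rowcount.
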
+156 −14
@@ -14,6 +14,10 @@ import shutil
import sys
import time
import traceback
from dataclasses import (
    dataclass,
    field,
)
from json import loads
from typing import (
    Any,
@@ -27,7 +31,12 @@ from typing import (
import yaml
from packaging.version import Version
from pulsar.client.staging import COMMAND_VERSION_FILENAME
from sqlalchemy import select
from sqlalchemy import (
    and_,
    func,
    select,
    update,
)

from galaxy import (
    model,
@@ -296,6 +305,18 @@ def job_config_xml_to_dict(config, root):
    return config_dict


@dataclass
class JobConfigurationLimits:
    registered_user_concurrent_jobs: Optional[int] = None
    anonymous_user_concurrent_jobs: Optional[int] = None
    walltime: Optional[str] = None
    walltime_delta: Optional[datetime.timedelta] = None
    total_walltime: Dict[str, Any] = field(default_factory=dict)
    output_size: Optional[int] = None
    destination_user_concurrent_jobs: Dict[str, int] = field(default_factory=dict)
    destination_total_concurrent_jobs: Dict[str, int] = field(default_factory=dict)


class JobConfiguration(ConfiguresHandlers):
    """A parser and interface to advanced job management features.

@@ -344,16 +365,7 @@ class JobConfiguration(ConfiguresHandlers):
        self.resource_groups = {}
        self.default_resource_group = None
        self.resource_parameters = {}
        self.limits = Bunch(
            registered_user_concurrent_jobs=None,
            anonymous_user_concurrent_jobs=None,
            walltime=None,
            walltime_delta=None,
            total_walltime={},
            output_size=None,
            destination_user_concurrent_jobs={},
            destination_total_concurrent_jobs={},
        )
        self.limits = JobConfigurationLimits()

        default_resubmits = []
        default_resubmit_condition = self.app.config.default_job_resubmission_condition
@@ -1610,12 +1622,142 @@ class MinimalJobWrapper(HasResourceParameters):
        dest_params = self.job_destination.params
        return self.get_job().get_destination_configuration(dest_params, self.app.config, key, default)

    def queue_with_limit(self, job: Job, job_destination: JobDestination):
        anonymous_user_concurrent_jobs = self.app.job_config.limits.anonymous_user_concurrent_jobs
        registered_user_concurrent_jobs = self.app.job_config.limits.registered_user_concurrent_jobs
        destination_total_concurrent_jobs = self.app.job_config.limits.destination_total_concurrent_jobs
        destination_total_limit = self.app.job_config.limits.destination_total_concurrent_jobs.get(job_destination.id)
        destination_user_limit = self.app.job_config.limits.destination_user_concurrent_jobs.get(job_destination.id)
        destination_tag_limits = {}
        if job_destination.tags:
            for tag in job_destination.tags:
                if tag_limit := destination_total_concurrent_jobs.get(tag):
                    destination_tag_limits[tag] = tag_limit

        conditions = [model.Job.table.c.id == job.id]

        if job.user_id:
            user_job_count = (
                select(func.count(model.Job.table.c.id))
                .where(
                    and_(
                        model.Job.table.c.state.in_(
                            [
                                model.Job.states.QUEUED,
                                model.Job.states.RUNNING,
                                model.Job.states.RESUBMITTED,
                            ]
                        ),
                        model.Job.table.c.user_id == job.user_id,
                    )
                )
                .scalar_subquery()
            )

            if registered_user_concurrent_jobs is not None:
                conditions.append(user_job_count < registered_user_concurrent_jobs)
            if destination_user_limit is not None:
                destination_job_count = (
                    select(func.count(model.Job.table.c.id))
                    .where(
                        and_(
                            model.Job.table.c.state.in_(
                                [
                                    model.Job.states.QUEUED,
                                    model.Job.states.RUNNING,
                                    model.Job.states.RESUBMITTED,
                                ]
                            ),
                            model.Job.table.c.destination_id == job_destination.id,
                            model.Job.table.c.user_id == job.user_id,
                        )
                    )
                    .scalar_subquery()
                )
                conditions.append(destination_job_count < destination_user_limit)

        elif anonymous_user_concurrent_jobs and job.galaxy_session and job.galaxy_session.id:
            anon_job_count = (
                select(func.count(model.Job.table.c.id))
                .where(
                    and_(
                        model.Job.table.c.state.in_(
                            [
                                model.Job.states.QUEUED,
                                model.Job.states.RUNNING,
                                model.Job.states.RESUBMITTED,
                            ]
                        ),
                        model.Job.table.c.session_id == job.galaxy_session.id,
                    )
                )
                .scalar_subquery()
            )
            conditions.append(anon_job_count < anonymous_user_concurrent_jobs)

        if destination_total_limit is not None:
            destination_total_count = (
                select(func.count(model.Job.table.c.id))
                .where(
                    and_(
                        model.Job.table.c.state.in_(
                            [
                                model.Job.states.QUEUED,
                                model.Job.states.RUNNING,
                                model.Job.states.RESUBMITTED,
                            ]
                        ),
                        model.Job.table.c.destination_id == job_destination.id,
                    )
                )
                .scalar_subquery()
            )
            conditions.append(destination_total_count < destination_total_limit)

        if destination_tag_limits:
            for tag, limit in destination_tag_limits.items():
                destination_ids = {destination.id for destination in self.app.job_config.get_destinations(tag)}
                tag_count = (
                    select(func.count(model.Job.table.c.id))
                    .where(
                        and_(
                            model.Job.table.c.state.in_(
                                [
                                    model.Job.states.QUEUED,
                                    model.Job.states.RUNNING,
                                    model.Job.states.RESUBMITTED,
                                ]
                            ),
                            model.Job.table.c.destination_id.in_(destination_ids),
                        )
                    )
                    .scalar_subquery()
                )
                conditions.append(tag_count < limit)

        update_stmt = (
            update(model.Job)
            .where(*conditions)
            .values(
                state=model.Job.states.QUEUED,
                destination_id=job_destination.id,
                destination_params=job_destination.params,
                job_runner_name=job_destination.runner,
            )
        )

        result = self.sa_session.execute(update_stmt)
        self.sa_session.commit()

        return result.rowcount > 0

    def enqueue(self):
        job = self.get_job()
        # Change to queued state before handing to worker thread so the runner won't pick it up again
        if self.is_task:
            self.change_state(model.Job.states.QUEUED, flush=False, job=job)
            # Persist the destination so that the job will be included in counts if using concurrency limits
            self.set_job_destination(self.job_destination, None, flush=False, job=job)
        elif not self.queue_with_limit(job, self.job_destination):
            return False
        # Set object store after job destination so can leverage parameters...
        self._set_object_store_ids(job)
        # Now that we have the object store id, check if we are over the limit
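
With this change, enqueue() only marks a non-task job as queued when queue_with_limit's UPDATE actually matched a row; if a limit is exhausted, the method returns False and the job keeps its current state. A rough sketch of the contract a caller relies on follows; the names are assumptions for illustration, not Galaxy's dispatcher code.

# Illustrative only: how a dispatcher might react to enqueue() returning False.
# `job_wrapper` and `retry_later` are hypothetical names, not Galaxy APIs.
if not job_wrapper.enqueue():
    # A concurrency limit is currently exhausted; do not hand the job to a runner now,
    # leave it for a later monitor iteration to try again.
    retry_later(job_wrapper.job_id)
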
+3 −3
@@ -518,6 +518,9 @@ class JobHandlerQueue(BaseJobHandlerQueue):
                pass
        # Ensure that we get new job counts on each iteration
        self.__clear_job_count()
        self.__cache_total_job_count_per_destination()
        self.__cache_user_job_count_per_destination()
        self.__cache_user_job_count()
        # Check resubmit jobs first so that limits of new jobs will still be enforced
        for job in resubmit_jobs:
            log.debug("(%s) Job was resubmitted and is being dispatched immediately", job.id)
@@ -824,7 +827,6 @@ class JobHandlerQueue(BaseJobHandlerQueue):
        self.total_job_count_per_destination = None

    def get_user_job_count(self, user_id):
        self.__cache_user_job_count()
        # This could have been incremented by a previous job dispatched on this iteration, even if we're not caching
        rval = self.user_job_count.get(user_id, 0)
        if not self.app.config.cache_user_job_count:
@@ -865,7 +867,6 @@ class JobHandlerQueue(BaseJobHandlerQueue):
            self.user_job_count = {}

    def get_user_job_count_per_destination(self, user_id):
        self.__cache_user_job_count_per_destination()
        cached = self.user_job_count_per_destination.get(user_id, {})
        if self.app.config.cache_user_job_count:
            rval = cached
@@ -1006,7 +1007,6 @@ class JobHandlerQueue(BaseJobHandlerQueue):
                self.total_job_count_per_destination[row["destination_id"]] = row["job_count"]

    def get_total_job_count_per_destination(self):
        self.__cache_total_job_count_per_destination()
        # Always use caching (at worst a job will have to wait one iteration,
        # and this would be more fair anyway as it ensures FIFO scheduling,
        # insofar as FIFO would be fair...)
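
The handler-queue hunks move the refreshes of the total-per-destination, user-per-destination, and per-user counts to the start of each monitor iteration and drop them from the getters, so every job examined in one pass is checked against the same snapshot. A simplified illustration of that pattern is below; the class and method names are assumptions, not the actual JobHandlerQueue API.

from collections import Counter
from typing import Dict, Iterable, Tuple


class IterationCounts:
    """Refresh counts once per monitor iteration; getters only read the snapshot."""

    def __init__(self) -> None:
        self._user_job_count: Dict[int, int] = {}

    def refresh(self, active_jobs: Iterable[Tuple[int, str]]) -> None:
        # Called once at the top of each iteration with (user_id, state) rows.
        self._user_job_count = dict(
            Counter(user_id for user_id, state in active_jobs if state in ("queued", "running"))
        )

    def get_user_job_count(self, user_id: int) -> int:
        # Reads the iteration snapshot and never rebuilds it mid-iteration, so all jobs
        # considered in the same pass see consistent counts.
        return self._user_job_count.get(user_id, 0)
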
+1 −3
@@ -451,9 +451,7 @@ class PulsarJobRunner(AsynchronousJobRunner):
            job = job_wrapper.get_job()
            # Set the job destination here (unlike other runners) because there are likely additional job destination
            # params from the Pulsar client.
            # Flush with change_state.
            job_wrapper.set_job_destination(job_destination, external_id=external_job_id, flush=False, job=job)
            job_wrapper.change_state(model.Job.states.QUEUED, job=job)
            job_wrapper.set_job_destination(job_destination, external_id=external_job_id, flush=True, job=job)
        except Exception:
            job_wrapper.fail("failure running job", exception=True)
            log.exception("failure running job %d", job_wrapper.job_id)
+2 −10
@@ -138,16 +138,8 @@ class SlurmJobRunner(DRMAAJobRunner):
                        ajs.job_wrapper.get_id_tag(),
                        ajs.job_id,
                    )
                    ajs.job_wrapper.change_state(
                        model.Job.states.QUEUED, info="Job was resubmitted due to node failure"
                    )
                    try:
                        self.queue_job(ajs.job_wrapper)
                        return
                    except Exception:
                        ajs.fail_message = (
                            "This job failed due to a cluster node failure, and an attempt to resubmit the job failed."
                        )
                    self.mark_as_resubmitted(ajs, info="Job was resubmitted due to node failure")
                    return
                elif slurm_state == "OUT_OF_MEMORY":
                    log.info(
                        "(%s/%s) Job hit memory limit (SLURM state: OUT_OF_MEMORY)",
+133 −0
from typing import Optional
from unittest.mock import Mock

from galaxy.jobs import (
    JobConfigurationLimits,
    MinimalJobWrapper,
)
from galaxy.model import (
    GalaxySession,
    Job,
)
from galaxy.model.unittest_utils import GalaxyDataTestApp
from galaxy.model.unittest_utils.data_app import GalaxyDataTestConfig


class MockJobConfig:

    def __init__(self) -> None:
        self.limits = JobConfigurationLimits()

    def get_destinations(self, tag):
        return [create_mock_destination()]


class GalaxyJobConfigApp(GalaxyDataTestApp):

    def __init__(self, config: Optional[GalaxyDataTestConfig] = None, **kwd):
        super().__init__(config, **kwd)
        self.job_config = MockJobConfig()


def create_mock_app():
    return GalaxyJobConfigApp()


def create_job_wrapper(app: GalaxyJobConfigApp, user_id=None, session_id=None):
    if session_id:
        session = GalaxySession(id=session_id)
        app.model.session.add(session)
        app.model.session.commit()
    job = create_mock_job(app, user_id, session_id)
    job_wrapper = MinimalJobWrapper(job, app)  # type: ignore[arg-type]
    return job_wrapper


def create_mock_job(app: GalaxyJobConfigApp, user_id=None, session_id=None, state="new"):
    """Creates a mock job object."""
    job = Job()
    job.user_id = user_id
    job.session_id = session_id
    job.state = state
    app.model.session.add(job)
    app.model.session.commit()
    return job


def create_mock_destination():
    """Creates a mock job destination."""
    job_destination_mock = Mock()
    job_destination_mock.id = "one"
    job_destination_mock.params = {}
    job_destination_mock.runner = "test_runner"
    job_destination_mock.tags = ["one", "two"]
    return job_destination_mock


def test_registered_user_limit():
    app = create_mock_app()
    job_wrapper = create_job_wrapper(app, user_id=1)
    job = job_wrapper.get_job()
    job_destination_mock = create_mock_destination()

    for _ in range(2):
        create_mock_job(app, user_id=1, state="running")

    # Test below limit
    job_wrapper.app.job_config.limits.registered_user_concurrent_jobs = 3
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is True

    # Test at limit
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is False


def test_anonymous_user_limit():
    app = create_mock_app()
    job_wrapper = create_job_wrapper(app, session_id=1)
    job = job_wrapper.get_job()
    job_destination_mock = create_mock_destination()

    create_mock_job(app, session_id=1, state="running")

    # Test below limit
    app.job_config.limits.anonymous_user_concurrent_jobs = 2
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is True

    # Test at limit
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is False


def test_destination_total_limit():
    app = create_mock_app()
    job_wrapper = create_job_wrapper(app, user_id=1)
    job = job_wrapper.get_job()
    job_destination_mock = create_mock_destination()

    # Test below limit
    app.job_config.limits.destination_total_concurrent_jobs["one"] = 1
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is True

    # Test at limit
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is False


def test_destination_tag_limit():
    app = create_mock_app()
    job_wrapper = create_job_wrapper(app, user_id=1)
    job = job_wrapper.get_job()
    job_destination_mock = create_mock_destination()

    # Test below limit
    app.job_config.limits.destination_total_concurrent_jobs["two"] = 1
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is True

    # Test at limit
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is False
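

One limit from JobConfigurationLimits, destination_user_concurrent_jobs, is not exercised above. A sketch of a test for it in the same style follows; it is an illustrative addition, not part of the PR, and setting destination_id directly on the background job is an assumption about how the mocks can be arranged.

def test_destination_user_limit():
    app = create_mock_app()
    job_wrapper = create_job_wrapper(app, user_id=1)
    job = job_wrapper.get_job()
    job_destination_mock = create_mock_destination()

    # One job for the same user already running on this destination.
    background_job = create_mock_job(app, user_id=1, state="running")
    background_job.destination_id = job_destination_mock.id
    app.model.session.commit()

    # Test below limit
    app.job_config.limits.destination_user_concurrent_jobs["one"] = 2
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is True

    # Test at limit
    result = job_wrapper.queue_with_limit(job, job_destination_mock)
    assert result is False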