Commit 1cb53125 authored by Hines, Jesse

Clean up stuck sims on local

parent a6f7fed5
+1 −0
@@ -31,6 +31,7 @@ dependencies = [
    "orjson==3.11.3",
    "confluent_kafka==2.11.1",
    "pyjson5==2.0.0",
    "psutil==7.1.0",
    "raps@{root:uri}/raps",
]

+1 −1
@@ -20,7 +20,7 @@ class Sim(BaseModel):
    user: Optional[str] = None
    """ User who launched the simulation """

    system: str
    system: Optional[str] = None

    state: Optional[Literal['running', 'success', 'fail']] = None

+11 −12
@@ -22,12 +22,12 @@ settings = AppSettings()


def repeat_task(func, seconds):
    if not asyncio.iscoroutinefunction(func):
        func = functools.partial(run_in_threadpool, func)

    async def loop() -> None:
        while True:
            try:
                await func()
            except Exception as e:
                logger.exception(f"Background task failed: {e}")
            await asyncio.sleep(seconds)

    return asyncio.create_task(loop())
@@ -40,13 +40,13 @@ async def lifespan(api: FastAPI):
    for dep in deps:
        api.dependency_overrides.get(dep, dep)()

    # TODO: Should add cleanup handler for local as well
    background_task_loop = None
    if settings.env == 'prod' and 'KUBERNETES_SERVICE_HOST' in os.environ:
        background_task_loop = repeat_task(
            lambda: cleanup_jobs(druid_engine = get_druid_engine(), kafka_producer = get_kafka_producer()),
            seconds = 5 * 60,
    async def background_task():
        cleanup_jobs(
            druid_engine = get_druid_engine(),
            kafka_producer = get_kafka_producer(),
            settings = get_app_settings(),
        )
    background_task_loop = repeat_task(background_task, seconds = 5)

    if settings.env == 'dev':
        kafka_admin = get_kafka_admin()
@@ -80,8 +80,7 @@ async def lifespan(api: FastAPI):

    yield

    # if background_task_loop:
    #     background_task_loop.cancel()
    background_task_loop.cancel()


app = FastAPI(
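
A minimal standalone sketch (not part of the commit) of how the repeat_task helper above is used, assuming repeat_task is importable from this module: it must be called from a running event loop, sync callables are wrapped with run_in_threadpool, and the returned task is cancelled on shutdown just as the lifespan now does unconditionally. The toy callable and intervals here are illustrative.

import asyncio

async def main():
    # Schedule a sync callable; repeat_task wraps it with run_in_threadpool and
    # re-runs it forever, logging (rather than raising) any exception it throws.
    task = repeat_task(lambda: print("tick"), seconds = 2)
    await asyncio.sleep(6)  # let it fire a few times
    task.cancel()           # mirrors background_task_loop.cancel() in the lifespan

asyncio.run(main())
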
+36 −40
from typing import Optional, Any
from datetime import datetime, timedelta, timezone
import psutil
import functools
import uuid, time, json, base64, os, sys, subprocess
import sqlalchemy as sqla
@@ -15,14 +16,14 @@ from ..models.output import (
    SCHEDULER_SIM_JOB_POWER_HISTORY_API_FIELDS, SCHEDULER_SIM_JOB_POWER_HISTORY_FIELD_SELECTORS,
)
from ..util.misc import pick, omit
from ..util.k8s import submit_job, get_job, get_job_state, get_job_end_time
from ..util.k8s import submit_job, get_k8s_jobs, get_k8s_job_state, get_k8s_job_end_time
from ..util.druid import to_timestamp, any_value, latest, execute_ignore_missing
from ..util.api_queries import (
    Filters, Sort, QuerySpan, Granularity, expand_field_selectors, DatetimeValidator,
    DEFAULT_FIELD_TYPES,
)
from . import orm
from .config import AppDeps, AppSettings
from .config import AppDeps


def wait_until_exists(stmt: sqla.Select, *, timeout: timedelta = timedelta(minutes=1), druid_engine: sqla.Engine):
@@ -121,62 +122,57 @@ def run_simulation(sim_config: ServerSimConfig, deps: AppDeps):
    return sim


_sim_jobs_cache: dict[str, tuple[Any, datetime]] = {}
_sim_job_cache_expire = timedelta(minutes=5)
def get_sim_job(sim_id: str):
    now = datetime.now()
    # Expire old entries
    for cid in list(_sim_jobs_cache.keys()):
        if (now - _sim_jobs_cache[cid][1]) > _sim_job_cache_expire:
            del _sim_jobs_cache[cid]

    if sim_id not in _sim_jobs_cache:
        _sim_jobs_cache[sim_id] = (get_job(f"exadigit-simulation-server-{sim_id}"), now)

    return _sim_jobs_cache[sim_id][0]


def cleanup_jobs(druid_engine, kafka_producer):
def cleanup_jobs(druid_engine, kafka_producer, settings):
    """
    If a simulation job dies unexpectedly (e.g. OOM error), it won't be able to send the kafka
    message marking the sim as complete, leaving the sim stuck as running. This task checks all
    running sim jobs and cleans them up if their job is dead.
    """
    if 'KUBERNETES_SERVICE_HOST' in os.environ and settings.env != 'prod':
        # Skip job cleanup on stage/dev k8s deployments to avoid multiple instances of the server
        # trying to cancel jobs
        return
    logger.info(f"Checking for stuck jobs")

    now = datetime.now(timezone.utc)
    threshold = timedelta(minutes=5)
    # threshold after the job has ended before sending a cancel (in case the job did send its own
    # cancel message and it just hasn't shown up in Druid yet)
    threshold = timedelta(minutes=1)

    running_jobs = set()
    if 'KUBERNETES_SERVICE_HOST' in os.environ:
        for job in get_k8s_jobs():
            if job.metadata.name.startswith('exadigit-simulation-server-'):
                sim_id = job.metadata.name.removeprefix('exadigit-simulation-server-')
                # Add a little bit of threshold to avoid potentially sending duplicate fail messages
                if get_k8s_job_state(job) == "running" or get_k8s_job_end_time(job) > now - threshold:
                    running_jobs.add(sim_id)
    else:
        for proc in psutil.Process().children():
            try:
                if 'simulation_server.simulation.main' in proc.cmdline():
                    sim_id = json.loads(proc.environ()["SIM"])['id']
                    if proc.is_running():
                        running_jobs.add(sim_id)
            except (psutil.Error):
                pass

    sims, _ = query_sims(
    running_sims, _ = query_sims(
        filters=SIM_FILTERS(state = ["eq:running"]),
        fields = ["id"],
        fields = ["all"],
        limit = 1000, # If somehow there's more than that we'll just get them next trigger
        druid_engine = druid_engine,
    )

    stuck_ids = []
    for sim in sims:
        job = get_sim_job(sim.id)
        job_state = get_job_state(job)
        if job_state != 'running' and (not job or get_job_end_time(job) < now - threshold):
            stuck_ids.append(sim.id)

    if stuck_ids:
        stuck_sims, _ = query_sims(
            filters = SIM_FILTERS(id = [f'one_of:{",".join(stuck_ids)}']),
            fields = ['all'],
            limit = len(stuck_ids),
            druid_engine = druid_engine,
        )
    
        for sim in stuck_sims:
    for sim in running_sims:
        if sim.id not in running_jobs and now - sim.execution_start > threshold:
            sim.state = 'fail'
            sim.execution_end = now
            sim.error_messages = "Simulation crashed"
            logger.warning(f"Marking stuck sim {sim.id} as failed")
            kafka_producer.produce("svc-event-exadigit-sim", sim.serialize_for_druid())

        for sim in stuck_sims:
            # Block until saved to make sure we don't double-send
            stmt = (
                sqla.select(orm.sim.c.id)
                    .where(orm.sim.c.id == sim.id, orm.sim.c.state == 'fail')
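
For the local (non-Kubernetes) path, a self-contained sketch of the psutil-based detection used above; the simulation_server.simulation.main module path and the SIM environment variable come from the diff, while the function name is illustrative.

import json
import psutil

def find_running_local_sims() -> set[str]:
    """Return the ids of simulation subprocesses still alive under this server."""
    running: set[str] = set()
    for proc in psutil.Process().children():
        try:
            # Each local sim runs the simulation module as a child process,
            # with its config passed as JSON in the SIM environment variable.
            if 'simulation_server.simulation.main' in proc.cmdline():
                sim_id = json.loads(proc.environ()["SIM"])['id']
                if proc.is_running():
                    running.add(sim_id)
        except psutil.Error:
            # The process may exit between listing and inspection.
            pass
    return running
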
+4 −4
@@ -17,9 +17,9 @@ def submit_job(job: dict):
    return get_batch_api().create_namespaced_job(namespace = get_namespace(), body = job)


def get_job(name: str):
def get_k8s_jobs():
    try:
        return get_batch_api().read_namespaced_job(namespace = get_namespace(), name = name)
        return get_batch_api().list_namespaced_job(namespace = get_namespace())
    except k8s.client.ApiException as e:
        if e.status == 404:
            return None
@@ -27,7 +27,7 @@ def get_job(name: str):
            raise e


def get_job_state(job):
def get_k8s_job_state(job):
    if job:
        if job.status.succeeded:
            return 'success'
@@ -39,6 +39,6 @@ def get_job_state(job):
        return 'deleted'


def get_job_end_time(job):
def get_k8s_job_end_time(job):
    # completion_time for failed jobs is null
    return job.status.completion_time or job.status.conditions[-1].last_transition_time