Loading pyproject.toml +1 −0 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ dependencies = [ "orjson==3.11.3", "confluent_kafka==2.11.1", "pyjson5==2.0.0", "psutil==7.1.0", "raps@{root:uri}/raps", ] Loading simulation_server/models/sim.py +1 −1 Original line number Diff line number Diff line Loading @@ -20,7 +20,7 @@ class Sim(BaseModel): user: Optional[str] = None """ User who launched the simulation """ system: str system: Optional[str] = None state: Optional[Literal['running', 'success', 'fail']] = None Loading simulation_server/server/main.py +11 −12 Original line number Diff line number Diff line Loading @@ -22,12 +22,12 @@ settings = AppSettings() def repeat_task(func, seconds): if not asyncio.iscoroutinefunction(func): func = functools.partial(run_in_threadpool, func) async def loop() -> None: while True: try: await func() except Exception as e: logger.exception(f"Background task failed: {e}") await asyncio.sleep(seconds) return asyncio.create_task(loop()) Loading @@ -40,13 +40,13 @@ async def lifespan(api: FastAPI): for dep in deps: api.dependency_overrides.get(dep, dep)() # TODO: Should add cleanup handler for local as well background_task_loop = None if settings.env == 'prod' and 'KUBERNETES_SERVICE_HOST' in os.environ: background_task_loop = repeat_task( lambda: cleanup_jobs(druid_engine = get_druid_engine(), kafka_producer = get_kafka_producer()), seconds = 5 * 60, async def background_task(): cleanup_jobs( druid_engine = get_druid_engine(), kafka_producer = get_kafka_producer(), settings = get_app_settings(), ) background_task_loop = repeat_task(background_task, seconds = 5) if settings.env == 'dev': kafka_admin = get_kafka_admin() Loading Loading @@ -80,8 +80,7 @@ async def lifespan(api: FastAPI): yield # if background_task_loop: # background_task_loop.cancel() background_task_loop.cancel() app = FastAPI( Loading simulation_server/server/service.py +36 −40 Original line number Diff line number Diff line from typing import Optional, Any from datetime import datetime, timedelta, timezone import psutil import functools import uuid, time, json, base64, os, sys, subprocess import sqlalchemy as sqla Loading @@ -15,14 +16,14 @@ from ..models.output import ( SCHEDULER_SIM_JOB_POWER_HISTORY_API_FIELDS, SCHEDULER_SIM_JOB_POWER_HISTORY_FIELD_SELECTORS, ) from ..util.misc import pick, omit from ..util.k8s import submit_job, get_job, get_job_state, get_job_end_time from ..util.k8s import submit_job, get_k8s_jobs, get_k8s_job_state, get_k8s_job_end_time from ..util.druid import to_timestamp, any_value, latest, execute_ignore_missing from ..util.api_queries import ( Filters, Sort, QuerySpan, Granularity, expand_field_selectors, DatetimeValidator, DEFAULT_FIELD_TYPES, ) from . import orm from .config import AppDeps, AppSettings from .config import AppDeps def wait_until_exists(stmt: sqla.Select, *, timeout: timedelta = timedelta(minutes=1), druid_engine: sqla.Engine): Loading Loading @@ -121,62 +122,57 @@ def run_simulation(sim_config: ServerSimConfig, deps: AppDeps): return sim _sim_jobs_cache: dict[str, tuple[Any, datetime]] = {} _sim_job_cache_expire = timedelta(minutes=5) def get_sim_job(sim_id: str): now = datetime.now() # Expire old entries for cid in list(_sim_jobs_cache.keys()): if (now - _sim_jobs_cache[cid][1]) > _sim_job_cache_expire: del _sim_jobs_cache[cid] if sim_id not in _sim_jobs_cache: _sim_jobs_cache[sim_id] = (get_job(f"exadigit-simulation-server-{sim_id}"), now) return _sim_jobs_cache[sim_id][0] def cleanup_jobs(druid_engine, kafka_producer): def cleanup_jobs(druid_engine, kafka_producer, settings): """ If a simulation job dies unexpectedly (e.g. OOM error), it won't be able to send the kafka message marking the sim as complete, leaving the sim stuck as running. This task checks all running sim jobs and cleans them up if their job is dead. """ if 'KUBERNETES_SERVICE_HOST' in os.environ and settings.env != 'prod': # Skip job cleanup on stage/dev k8s deployments to avoid multiple instances of the server # trying to cancel jobs return logger.info(f"Checking for stuck jobs") now = datetime.now(timezone.utc) threshold = timedelta(minutes=5) # threshold after job has ended before sending a cancel (incase the job did send its own # cancel message and it just hasn't shown up in Druid yet) threshold = timedelta(minutes=1) running_jobs = set() if 'KUBERNETES_SERVICE_HOST' in os.environ: for job in get_k8s_jobs(): if job.metadata.name.startswith('exadigit-simulation-server-'): sim_id = job.metadata.name.removeprefix('exadigit-simulation-server-') # Add a little bit of threshold to avoid potentially sending duplicate fail messages if get_k8s_job_state(job) == "running" or get_k8s_job_end_time(job) < now - threshold: running_jobs.add(sim_id) else: for proc in psutil.Process().children(): try: if 'simulation_server.simulation.main' in proc.cmdline(): sim_id = json.loads(proc.environ()["SIM"])['id'] if proc.is_running(): running_jobs.add(sim_id) except (psutil.Error): pass sims, _ = query_sims( running_sims, _ = query_sims( filters=SIM_FILTERS(state = ["eq:running"]), fields = ["id"], fields = ["all"], limit = 1000, # If somehow there's more than that we'll just get them next trigger druid_engine = druid_engine, ) stuck_ids = [] for sim in sims: job = get_sim_job(sim.id) job_state = get_job_state(job) if job_state != 'running' and (not job or get_job_end_time(job) < now - threshold): stuck_ids.append(sim.id) if stuck_ids: stuck_sims, _ = query_sims( filters = SIM_FILTERS(id = [f'one_of:{",".join(stuck_ids)}']), fields = ['all'], limit = len(stuck_ids), druid_engine = druid_engine, ) for sim in stuck_sims: for sim in running_sims: if sim.id not in running_jobs and now - sim.execution_start > threshold: sim.state = 'fail' sim.execution_end = now sim.error_messages = "Simulation crashed" logger.warning(f"Marking stuck sim {sim.id} as failed") kafka_producer.produce("svc-event-exadigit-sim", sim.serialize_for_druid()) for sim in stuck_sims: # Block until saved to make sure we don't double-send stmt = ( sqla.select(orm.sim.c.id) .where(orm.sim.c.id == sim.id, orm.sim.c.state == 'fail') Loading simulation_server/util/k8s.py +4 −4 Original line number Diff line number Diff line Loading @@ -17,9 +17,9 @@ def submit_job(job: dict): return get_batch_api().create_namespaced_job(namespace = get_namespace(), body = job) def get_job(name: str): def get_k8s_jobs(): try: return get_batch_api().read_namespaced_job(namespace = get_namespace(), name = name) return get_batch_api().list_namespaced_job(namespace = get_namespace()) except k8s.client.ApiException as e: if e.status == 404: return None Loading @@ -27,7 +27,7 @@ def get_job(name: str): raise e def get_job_state(job): def get_k8s_job_state(job): if job: if job.status.succeeded: return 'success' Loading @@ -39,6 +39,6 @@ def get_job_state(job): return 'deleted' def get_job_end_time(job): def get_k8s_job_end_time(job): # completion_time for failed jobs is null return job.status.completion_time or job.status.conditions[-1].last_transition_time Loading
pyproject.toml +1 −0 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ dependencies = [ "orjson==3.11.3", "confluent_kafka==2.11.1", "pyjson5==2.0.0", "psutil==7.1.0", "raps@{root:uri}/raps", ] Loading
simulation_server/models/sim.py +1 −1 Original line number Diff line number Diff line Loading @@ -20,7 +20,7 @@ class Sim(BaseModel): user: Optional[str] = None """ User who launched the simulation """ system: str system: Optional[str] = None state: Optional[Literal['running', 'success', 'fail']] = None Loading
simulation_server/server/main.py +11 −12 Original line number Diff line number Diff line Loading @@ -22,12 +22,12 @@ settings = AppSettings() def repeat_task(func, seconds): if not asyncio.iscoroutinefunction(func): func = functools.partial(run_in_threadpool, func) async def loop() -> None: while True: try: await func() except Exception as e: logger.exception(f"Background task failed: {e}") await asyncio.sleep(seconds) return asyncio.create_task(loop()) Loading @@ -40,13 +40,13 @@ async def lifespan(api: FastAPI): for dep in deps: api.dependency_overrides.get(dep, dep)() # TODO: Should add cleanup handler for local as well background_task_loop = None if settings.env == 'prod' and 'KUBERNETES_SERVICE_HOST' in os.environ: background_task_loop = repeat_task( lambda: cleanup_jobs(druid_engine = get_druid_engine(), kafka_producer = get_kafka_producer()), seconds = 5 * 60, async def background_task(): cleanup_jobs( druid_engine = get_druid_engine(), kafka_producer = get_kafka_producer(), settings = get_app_settings(), ) background_task_loop = repeat_task(background_task, seconds = 5) if settings.env == 'dev': kafka_admin = get_kafka_admin() Loading Loading @@ -80,8 +80,7 @@ async def lifespan(api: FastAPI): yield # if background_task_loop: # background_task_loop.cancel() background_task_loop.cancel() app = FastAPI( Loading
simulation_server/server/service.py +36 −40 Original line number Diff line number Diff line from typing import Optional, Any from datetime import datetime, timedelta, timezone import psutil import functools import uuid, time, json, base64, os, sys, subprocess import sqlalchemy as sqla Loading @@ -15,14 +16,14 @@ from ..models.output import ( SCHEDULER_SIM_JOB_POWER_HISTORY_API_FIELDS, SCHEDULER_SIM_JOB_POWER_HISTORY_FIELD_SELECTORS, ) from ..util.misc import pick, omit from ..util.k8s import submit_job, get_job, get_job_state, get_job_end_time from ..util.k8s import submit_job, get_k8s_jobs, get_k8s_job_state, get_k8s_job_end_time from ..util.druid import to_timestamp, any_value, latest, execute_ignore_missing from ..util.api_queries import ( Filters, Sort, QuerySpan, Granularity, expand_field_selectors, DatetimeValidator, DEFAULT_FIELD_TYPES, ) from . import orm from .config import AppDeps, AppSettings from .config import AppDeps def wait_until_exists(stmt: sqla.Select, *, timeout: timedelta = timedelta(minutes=1), druid_engine: sqla.Engine): Loading Loading @@ -121,62 +122,57 @@ def run_simulation(sim_config: ServerSimConfig, deps: AppDeps): return sim _sim_jobs_cache: dict[str, tuple[Any, datetime]] = {} _sim_job_cache_expire = timedelta(minutes=5) def get_sim_job(sim_id: str): now = datetime.now() # Expire old entries for cid in list(_sim_jobs_cache.keys()): if (now - _sim_jobs_cache[cid][1]) > _sim_job_cache_expire: del _sim_jobs_cache[cid] if sim_id not in _sim_jobs_cache: _sim_jobs_cache[sim_id] = (get_job(f"exadigit-simulation-server-{sim_id}"), now) return _sim_jobs_cache[sim_id][0] def cleanup_jobs(druid_engine, kafka_producer): def cleanup_jobs(druid_engine, kafka_producer, settings): """ If a simulation job dies unexpectedly (e.g. OOM error), it won't be able to send the kafka message marking the sim as complete, leaving the sim stuck as running. This task checks all running sim jobs and cleans them up if their job is dead. """ if 'KUBERNETES_SERVICE_HOST' in os.environ and settings.env != 'prod': # Skip job cleanup on stage/dev k8s deployments to avoid multiple instances of the server # trying to cancel jobs return logger.info(f"Checking for stuck jobs") now = datetime.now(timezone.utc) threshold = timedelta(minutes=5) # threshold after job has ended before sending a cancel (incase the job did send its own # cancel message and it just hasn't shown up in Druid yet) threshold = timedelta(minutes=1) running_jobs = set() if 'KUBERNETES_SERVICE_HOST' in os.environ: for job in get_k8s_jobs(): if job.metadata.name.startswith('exadigit-simulation-server-'): sim_id = job.metadata.name.removeprefix('exadigit-simulation-server-') # Add a little bit of threshold to avoid potentially sending duplicate fail messages if get_k8s_job_state(job) == "running" or get_k8s_job_end_time(job) < now - threshold: running_jobs.add(sim_id) else: for proc in psutil.Process().children(): try: if 'simulation_server.simulation.main' in proc.cmdline(): sim_id = json.loads(proc.environ()["SIM"])['id'] if proc.is_running(): running_jobs.add(sim_id) except (psutil.Error): pass sims, _ = query_sims( running_sims, _ = query_sims( filters=SIM_FILTERS(state = ["eq:running"]), fields = ["id"], fields = ["all"], limit = 1000, # If somehow there's more than that we'll just get them next trigger druid_engine = druid_engine, ) stuck_ids = [] for sim in sims: job = get_sim_job(sim.id) job_state = get_job_state(job) if job_state != 'running' and (not job or get_job_end_time(job) < now - threshold): stuck_ids.append(sim.id) if stuck_ids: stuck_sims, _ = query_sims( filters = SIM_FILTERS(id = [f'one_of:{",".join(stuck_ids)}']), fields = ['all'], limit = len(stuck_ids), druid_engine = druid_engine, ) for sim in stuck_sims: for sim in running_sims: if sim.id not in running_jobs and now - sim.execution_start > threshold: sim.state = 'fail' sim.execution_end = now sim.error_messages = "Simulation crashed" logger.warning(f"Marking stuck sim {sim.id} as failed") kafka_producer.produce("svc-event-exadigit-sim", sim.serialize_for_druid()) for sim in stuck_sims: # Block until saved to make sure we don't double-send stmt = ( sqla.select(orm.sim.c.id) .where(orm.sim.c.id == sim.id, orm.sim.c.state == 'fail') Loading
simulation_server/util/k8s.py +4 −4 Original line number Diff line number Diff line Loading @@ -17,9 +17,9 @@ def submit_job(job: dict): return get_batch_api().create_namespaced_job(namespace = get_namespace(), body = job) def get_job(name: str): def get_k8s_jobs(): try: return get_batch_api().read_namespaced_job(namespace = get_namespace(), name = name) return get_batch_api().list_namespaced_job(namespace = get_namespace()) except k8s.client.ApiException as e: if e.status == 404: return None Loading @@ -27,7 +27,7 @@ def get_job(name: str): raise e def get_job_state(job): def get_k8s_job_state(job): if job: if job.status.succeeded: return 'success' Loading @@ -39,6 +39,6 @@ def get_job_state(job): return 'deleted' def get_job_end_time(job): def get_k8s_job_end_time(job): # completion_time for failed jobs is null return job.status.completion_time or job.status.conditions[-1].last_transition_time