Commit 7994b0eb authored by Brewer, Wes's avatar Brewer, Wes
Browse files

Add JSONL output, demo controls, and queued job visibility

parent dcc59ba0
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ def _initialize_raps_sim(site_name: str, sim_config_path: str, sim_time: str = "
        output="none",           # No file output
        workload="random",       # Generates initial random jobs
        policy="fcfs",           # First-come-first-served for dynamic jobs
        backfill="easy",         # Allow smaller jobs to bypass blocking jobs
    )
    engine = Engine(sim_config)

@@ -111,6 +112,18 @@ def _poll_site_metrics(engine, tick: int, tick_data=None) -> Optional[Dict[str,
            'account': getattr(job, 'account', ''),
        })

    # Build top queued jobs list (up to 5)
    top_queued_jobs = []
    for job in engine.queue[:5]:
        top_queued_jobs.append({
            'id': job.id,
            'name': str(job.name),
            'nodes_required': job.nodes_required,
            'wall_time': getattr(job, 'time_limit', getattr(job, 'wall_time', 0)),
            'state': 'PD',
            'account': getattr(job, 'account', ''),
        })

    # Extract richer metrics from tick_data if available
    system_util = 0.0
    total_power_kw = 0.0
@@ -137,6 +150,7 @@ def _poll_site_metrics(engine, tick: int, tick_data=None) -> Optional[Dict[str,
        'p_flops': p_flops,
        'g_flops_w': g_flops_w,
        'top_running_jobs': top_running_jobs,
        'top_queued_jobs': top_queued_jobs,
        'sim_time': _get_sim_time(engine),
        'tick': tick,
    }
+84 −8
Original line number Diff line number Diff line
@@ -6,10 +6,11 @@ per-site columns showing running jobs and system metrics.
"""

import time
from dataclasses import dataclass, field
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Optional

from rich.align import Align
from rich.console import Group
from rich.layout import Layout
from rich.panel import Panel
from rich.table import Table
@@ -37,6 +38,7 @@ class SiteState:
    p_flops: float = 0.0
    g_flops_w: float = 0.0
    top_running_jobs: List[dict] = field(default_factory=list)
    top_queued_jobs: List[dict] = field(default_factory=list)
    sim_time: Optional[int] = None
    tick: int = 0
    last_update: float = field(default_factory=time.time)
@@ -97,6 +99,7 @@ class FederationDashboard:
        }
        self.meta = MetaSchedulerState()
        self.layout = self._build_layout()
        self._rr_index = 0

    # ------------------------------------------------------------------
    # Layout construction
@@ -240,7 +243,21 @@ class FederationDashboard:
        for col in ["JOBID", "NAME", "NODES", "WALL", "RUN", "ST"]:
            table.add_column(col, justify="center", no_wrap=True)

        combined = []
        for job in site.top_running_jobs[:10]:
            combined.append({
                **job,
                "state": job.get("state", "R"),
                "run_time": job.get("run_time", 0),
            })
        for job in site.top_queued_jobs[:5]:
            combined.append({
                **job,
                "state": "PD",
                "run_time": 0,
            })

        for job in combined:
            wall_str = convert_seconds_to_hhmm(job.get('wall_time', 0))
            run_str = convert_seconds_to_hhmm(job.get('run_time', 0))
            # Truncate UUID-style IDs to 8 chars; zero-pad short integer IDs
@@ -257,7 +274,8 @@ class FederationDashboard:
            )

        count = len(site.top_running_jobs)
        return Panel(table, title=title, subtitle=f"Jobs: {count}", border_style="cyan")
        qcount = len(site.top_queued_jobs)
        return Panel(table, title=title, subtitle=f"Run: {count} | PD: {qcount}", border_style="cyan")

    def _render_site_stats(self, site_name: str) -> Panel:
        site = self.sites[site_name]
@@ -320,6 +338,7 @@ class FederationDashboard:
                site.p_flops = msg.get('p_flops', site.p_flops)
                site.g_flops_w = msg.get('g_flops_w', site.g_flops_w)
                site.top_running_jobs = msg.get('top_running_jobs', site.top_running_jobs)
                site.top_queued_jobs = msg.get('top_queued_jobs', site.top_queued_jobs)
                site.sim_time = msg.get('sim_time', site.sim_time)
                site.tick = msg.get('tick', site.tick)
                site.last_update = time.time()
@@ -365,7 +384,8 @@ class FederationDashboard:
            submit_wall_time=time.time(),
        ))

    def try_dispatch(self, ms) -> int:
    def try_dispatch(self, ms, max_dispatch: Optional[int] = None,
                     policy: str = "max_free") -> int:
        """
        Try to dispatch WAITING jobs to sites with enough free nodes.
        Returns the number of jobs dispatched this cycle.
@@ -373,17 +393,40 @@ class FederationDashboard:
        dispatched = 0
        still_waiting = []
        for pj in self.meta.pending_jobs:
            if max_dispatch is not None and max_dispatch > 0 and dispatched >= max_dispatch:
                still_waiting.append(pj)
                continue
            if pj.status != "WAITING":
                still_waiting.append(pj)
                continue
            # Find a site with enough free nodes
            best_site = None
            best_free = -1
            eligible = []
            for name in self.site_names:
                site = self.sites[name]
                if site.status not in ("READY", "RUNNING"):
                    continue
                if site.free_nodes >= pj.nodes and site.free_nodes > best_free:
                if site.free_nodes >= pj.nodes:
                    eligible.append(name)

            best_site = None
            if eligible:
                if policy == "round_robin":
                    start = self._rr_index % len(self.site_names)
                    for offset in range(len(self.site_names)):
                        name = self.site_names[(start + offset) % len(self.site_names)]
                        if name in eligible:
                            best_site = name
                            self._rr_index = (start + offset + 1) % len(self.site_names)
                            break
                elif policy == "random":
                    import random
                    best_site = random.choice(eligible)
                else:
                    # default: pick the site with most free nodes
                    best_free = -1
                    for name in eligible:
                        site = self.sites[name]
                        if site.free_nodes > best_free:
                            best_site = name
                            best_free = site.free_nodes
            if best_site is not None:
@@ -412,6 +455,39 @@ class FederationDashboard:
        """Called by the main loop when ms.submit() raises PermissionError."""
        self.meta.iam_denials += 1

    def waiting_count(self) -> int:
        """Return how many metascheduler jobs are still in the WAITING state."""
        waiting = [pj for pj in self.meta.pending_jobs if pj.status == "WAITING"]
        return len(waiting)

    # ------------------------------------------------------------------
    # Structured snapshot
    # ------------------------------------------------------------------

    def snapshot(self) -> dict:
        """Return a JSON-serializable snapshot of current federation state.

        The result contains the ordered site names, the metascheduler state
        (augmented with federation-wide aggregates), and a per-site dump of
        each site's state dataclass.
        """
        meta = asdict(self.meta)
        sites = {name: asdict(site) for name, site in self.sites.items()}

        site_states = list(self.sites.values())
        # Reuse waiting_count() so the WAITING definition lives in one place.
        meta["waiting_jobs"] = self.waiting_count()
        meta["fed_pflops"] = sum(s.p_flops for s in site_states)
        # Per-site power is reported in kW; aggregate to MW at federation scale.
        meta["fed_power_mw"] = sum(s.total_power_kw for s in site_states) / 1000.0
        meta["fed_running_jobs"] = sum(s.running_jobs for s in site_states)
        meta["fed_queued_jobs"] = sum(s.queued_jobs for s in site_states)
        # Average utilization over sites reporting non-zero util so idle or
        # not-yet-started sites don't drag the average down.
        active_sites = [s for s in site_states if s.system_util > 0]
        meta["avg_util"] = (
            sum(s.system_util for s in active_sites) / len(active_sites)
            if active_sites else 0.0
        )

        return {
            "site_names": list(self.site_names),
            "meta": meta,
            "sites": sites,
        }

    # ------------------------------------------------------------------
    # Layout update
    # ------------------------------------------------------------------
+84 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
"""
Quick inspector for federation JSONL output.

Usage:
  python scripts/inspect_federation_json.py /tmp/fed.jsonl
"""

import argparse
import json
from collections import Counter


def _fmt_ts(ts):
    if ts is None:
        return "n/a"
    try:
        return f"{ts:.3f}"
    except Exception:
        return str(ts)


def main() -> int:
    """Scan a federation JSONL file and print a human-readable summary.

    Tracks per-type record counts, first/last timestamps, the most recent
    snapshot, and a short tail of recent events. Malformed lines are counted
    and skipped instead of aborting — a live writer may leave a truncated
    final line in the file.

    Returns:
        Process exit code (0 on success).
    """
    parser = argparse.ArgumentParser(description="Inspect federation JSONL output")
    parser.add_argument("path", help="Path to federation JSONL file")
    parser.add_argument("--max-events", type=int, default=3,
                        help="Show up to N most recent events (default: 3)")
    args = parser.parse_args()

    counts = Counter()
    last_snapshot = None
    last_events = []
    first_ts = None
    last_ts = None
    bad_lines = 0

    # JSONL is produced with line buffering, so decode explicitly as UTF-8.
    with open(args.path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Tolerate partial/corrupt lines (e.g. writer still running).
                bad_lines += 1
                continue
            typ = obj.get("type", "unknown")
            counts[typ] += 1
            ts = obj.get("ts")
            if first_ts is None:
                first_ts = ts
            last_ts = ts
            if typ == "snapshot":
                last_snapshot = obj
            elif typ == "event":
                # Keep only the most recent --max-events entries.
                last_events.append(obj)
                if len(last_events) > args.max_events:
                    last_events.pop(0)

    print("Summary")
    print(f"  path: {args.path}")
    print(f"  types: {dict(counts)}")
    if bad_lines:
        print(f"  malformed_lines: {bad_lines}")
    print(f"  first_ts: {_fmt_ts(first_ts)}")
    print(f"  last_ts: {_fmt_ts(last_ts)}")

    if last_snapshot:
        state = last_snapshot.get("state", {})
        meta = state.get("meta", {})
        sites = state.get("sites", {})
        print("")
        print("Last snapshot")
        print(f"  waiting_jobs: {meta.get('waiting_jobs')}")
        print(f"  fed_running_jobs: {meta.get('fed_running_jobs')}")
        print(f"  fed_queued_jobs: {meta.get('fed_queued_jobs')}")
        print(f"  avg_util: {meta.get('avg_util')}")
        print(f"  sites: {list(sites.keys())}")

    if last_events:
        print("")
        print("Recent events")
        for ev in last_events:
            msg = ev.get("msg", {})
            print(f"  ts={_fmt_ts(ev.get('ts'))} event={msg.get('event')} site={msg.get('site')}")

    return 0


if __name__ == "__main__":
    # Surface main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
+124 −13
Original line number Diff line number Diff line
@@ -8,8 +8,11 @@ and renders a Rich Live dashboard showing aggregate + per-site metrics.
"""

import argparse
import json
import os
import random
import time
from typing import Optional

from raps.metasched.iam import AccessToken, IAMPolicyEngine
from raps.metasched.metascheduler import MetaScheduler
@@ -31,9 +34,9 @@ def _build_token() -> AccessToken:
_NODE_CHOICES = [32, 64, 128, 256, 512, 1024]


def _make_job(index: int) -> FedJob:
def _make_job(index: int, node_choices: list[int]) -> FedJob:
    return FedJob(
        nodes_required=random.choice(_NODE_CHOICES),
        nodes_required=random.choice(node_choices),
        wall_time_s=random.choice([1800, 3600, 7200]),
        name=f"fed-{index}",
        meta={
@@ -48,9 +51,10 @@ def _make_job(index: int) -> FedJob:
# --------------------------------------------------------------------------

def _run_text_mode(ms: MetaScheduler, token: AccessToken, num_jobs: int,
                   listen_seconds: int):
                   listen_seconds: int, json_out: Optional[dict],
                   node_choices: list[int]):
    for i in range(num_jobs):
        job = _make_job(i)
        job = _make_job(i, node_choices)
        try:
            site = ms.submit(job, token=token)
            print(f"submitted {job.job_id} by {token.subject} -> {site}")
@@ -64,6 +68,12 @@ def _run_text_mode(ms: MetaScheduler, token: AccessToken, num_jobs: int,
            ev = msg.get("event")
            if ev in ("HEARTBEAT", "METRICS", "ERROR", "IAM_DENY"):
                print(msg)
            if json_out and json_out["mode"] in ("events", "both"):
                json_out["write"]({
                    "type": "event",
                    "ts": time.time(),
                    "msg": msg,
                })
        time.sleep(0.2)


@@ -72,15 +82,18 @@ def _run_text_mode(ms: MetaScheduler, token: AccessToken, num_jobs: int,
# --------------------------------------------------------------------------

def _run_dashboard(ms: MetaScheduler, token: AccessToken, num_jobs: int,
                   site_names: list):
                   site_names: list, json_out: Optional[dict],
                   submit_interval: float, dispatch_interval: float,
                   max_dispatch: int, dispatch_policy: str,
                   node_choices: list[int], max_waiting: int):
    from rich.live import Live
    from raps.ui.federation import FederationDashboard

    dashboard = FederationDashboard(site_names)
    dashboard.update_layout()

    submit_interval = 2.0  # seconds between job submissions
    last_submit = 0.0
    last_dispatch = 0.0
    next_job = 0

    with Live(dashboard.layout, refresh_per_second=4, screen=True):
@@ -89,20 +102,47 @@ def _run_dashboard(ms: MetaScheduler, token: AccessToken, num_jobs: int,

            # Add a new job to the metascheduler queue at regular intervals
            if (num_jobs == 0 or next_job < num_jobs) and (now - last_submit) >= submit_interval:
                job = _make_job(next_job)
                if max_waiting <= 0 or dashboard.waiting_count() < max_waiting:
                    job = _make_job(next_job, node_choices)
                    dashboard.enqueue_job(job.job_id, job.name, job.nodes_required)
                    if json_out and json_out["mode"] in ("events", "both"):
                        json_out["write"]({
                            "type": "local_enqueue",
                            "ts": now,
                            "job_id": job.job_id,
                            "name": job.name,
                            "nodes_required": job.nodes_required,
                            "wall_time_s": job.wall_time_s,
                        })
                    next_job += 1
                    last_submit = now

            # Try to dispatch waiting jobs to sites with capacity
            dashboard.try_dispatch(ms)
            if dispatch_interval <= 0 or (now - last_dispatch) >= dispatch_interval:
                dashboard.try_dispatch(ms, max_dispatch=max_dispatch, policy=dispatch_policy)
                last_dispatch = now

            # Drain the status queue
            for msg in ms.poll_status(200):
                dashboard.process_event(msg)
                if json_out and json_out["mode"] in ("events", "both"):
                    json_out["write"]({
                        "type": "event",
                        "ts": time.time(),
                        "msg": msg,
                    })

            dashboard.update_layout()

            if json_out and json_out["mode"] in ("snapshot", "both"):
                if now >= json_out["next_snapshot"]:
                    json_out["write"]({
                        "type": "snapshot",
                        "ts": now,
                        "state": dashboard.snapshot(),
                    })
                    json_out["next_snapshot"] = now + json_out["interval"]

            # Exit when all site simulations have completed
            if all(dashboard.sites[s].status in ("SIMULATION_COMPLETE", "STOPPED", "ERROR")
                   for s in site_names):
@@ -125,6 +165,28 @@ def main():
                        help="Seconds to listen after last submission in --noui mode (default: 60)")
    parser.add_argument("--sim-time", type=str, default="24h",
                        help="Simulated time per site (default: 24h)")
    parser.add_argument("--submit-interval", type=float, default=2.0,
                        help="Seconds between job submissions (default: 2.0)")
    parser.add_argument("--dispatch-interval", type=float, default=0.0,
                        help="Seconds between dispatch attempts (default: 0 = every loop)")
    parser.add_argument("--max-dispatch", type=int, default=0,
                        help="Max jobs to dispatch per attempt (default: 0 = unlimited)")
    parser.add_argument("--dispatch-policy", type=str, default="max_free",
                        choices=["max_free", "round_robin", "random"],
                        help="Dispatch site selection policy (default: max_free)")
    parser.add_argument("--max-waiting", type=int, default=0,
                        help="Cap WAITING jobs in metascheduler queue (default: 0 = unlimited)")
    parser.add_argument("--node-choices", type=str, default="",
                        help="Comma-separated nodes choices (default: built-in list)")
    parser.add_argument("--demo", action="store_true",
                        help="Use a preset that builds a visible queue for demos")
    parser.add_argument("--output-json", type=str, default="",
                        help="Write JSONL output to this path (default: disabled)")
    parser.add_argument("--output-mode", type=str, default="both",
                        choices=["events", "snapshot", "both"],
                        help="JSON output mode (default: both)")
    parser.add_argument("--output-interval", type=float, default=1.0,
                        help="Seconds between snapshots (default: 1.0)")
    args = parser.parse_args()

    sites = {
@@ -140,17 +202,66 @@ def main():
            "perlmutter": 2,
        }
    )
    if args.demo:
        args.submit_interval = 0.2
        args.dispatch_interval = 1.0
        args.max_dispatch = 2
        args.dispatch_policy = "round_robin"
        args.max_waiting = 10
        args.node_choices = "8,16,32,64,128"

    if args.node_choices:
        node_choices = [int(x) for x in args.node_choices.split(",") if x.strip()]
    else:
        node_choices = list(_NODE_CHOICES)

    ms = MetaScheduler(sites, iam_policy=iam, sim_time=args.sim_time)
    token = _build_token()

    ms.start()
    json_out = None
    if args.output_json:
        out_dir = os.path.dirname(args.output_json)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        out_f = open(args.output_json, "w", buffering=1)
        def _write(obj: dict):
            out_f.write(json.dumps(obj) + "\n")
        json_out = {
            "file": out_f,
            "write": _write,
            "mode": args.output_mode,
            "interval": max(0.1, args.output_interval),
            "next_snapshot": 0.0,
        }
        json_out["write"]({
            "type": "start",
            "ts": time.time(),
            "sites": list(sites.keys()),
            "sim_time": args.sim_time,
            "num_jobs": args.num_jobs,
            "noui": args.noui,
        })

    try:
        if args.noui:
            _run_text_mode(ms, token, args.num_jobs, args.listen_seconds)
            _run_text_mode(ms, token, args.num_jobs, args.listen_seconds, json_out,
                           node_choices)
        else:
            _run_dashboard(ms, token, args.num_jobs, list(sites.keys()))
            _run_dashboard(ms, token, args.num_jobs, list(sites.keys()), json_out,
                           submit_interval=args.submit_interval,
                           dispatch_interval=args.dispatch_interval,
                           max_dispatch=args.max_dispatch,
                           dispatch_policy=args.dispatch_policy,
                           node_choices=node_choices,
                           max_waiting=args.max_waiting)
    finally:
        if json_out:
            json_out["write"]({
                "type": "end",
                "ts": time.time(),
            })
            json_out["file"].close()
        ms.stop()