Commit 0e7561d2 authored by Brewer, Wes
Browse files

Add federation dashboard with Rich Live TUI for multi-site view



- Restructure raps/ui.py into raps/ui/ package (backward-compat preserved)
- Add federation dashboard (raps/ui/federation.py) with per-site job tables,
  system stats, metascheduler aggregate panel, and job queue display
- Expand site_worker metrics: richer METRICS events every 10 ticks with
  system_util, power, PFLOPS, GFL/W, top running jobs, down nodes
- Configurable sim duration (--sim-time, default 24h) threaded through
  MetaScheduler -> site_worker -> Engine
- Add all three sites (including perlmutter) to IAM-eligible targets
- run_federation.py supports --noui, --num-jobs, --listen-seconds, --sim-time

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent dc5a4b62
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -5,11 +5,14 @@ from .site_worker import site_worker_main
from .iam import AccessToken, IAMPolicyEngine

class MetaScheduler:
    def __init__(self, sites: Dict[str, str], iam_policy: Optional[IAMPolicyEngine] = None):
    def __init__(self, sites: Dict[str, str], iam_policy: Optional[IAMPolicyEngine] = None,
                 sim_time: str = "24h"):
        """
        sites: {site_name: sim_config_path}
        sim_time: simulation duration string passed to each site worker (e.g. "1h", "24h", "7d")
        """
        self.sites = sites
        self.sim_time = sim_time
        self.job_queues: Dict[str, Queue] = {s: Queue() for s in sites}
        self.stop_queues: Dict[str, Queue] = {s: Queue() for s in sites}
        self.status_q: Queue = Queue()
@@ -23,7 +26,8 @@ class MetaScheduler:
        for site, cfg_path in self.sites.items():
            p = Process(
                target=site_worker_main,
                args=(site, cfg_path, self.job_queues[site], self.status_q, self.stop_queues[site]),
                args=(site, cfg_path, self.job_queues[site], self.status_q,
                      self.stop_queues[site], self.sim_time),
                daemon=True,
            )
            p.start()
+46 −8
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional
import time


def _initialize_raps_sim(site_name: str, sim_config_path: str):
def _initialize_raps_sim(site_name: str, sim_config_path: str, sim_time: str = "24h"):
    """
    Initialize a RAPS Engine directly in this process.
    Returns the engine instance for direct manipulation.
@@ -11,6 +11,7 @@ def _initialize_raps_sim(site_name: str, sim_config_path: str):
    sim_config_path can be:
    - A path to a SystemConfig YAML file (e.g., config/frontier.yaml)
    - A built-in system name (e.g., "frontier")
    sim_time: simulation duration string (e.g. "1h", "24h", "7d")
    """
    from raps.sim_config import SingleSimConfig
    from raps.engine import Engine
@@ -19,7 +20,7 @@ def _initialize_raps_sim(site_name: str, sim_config_path: str):
    # Using sensible defaults for federation simulation
    sim_config = SingleSimConfig(
        system=sim_config_path,  # Path to SystemConfig or system name
        time="1h",               # Simulate 1 hour by default
        time=sim_time,           # Simulation duration
        numjobs=1,               # Minimal jobs to initialize (more will be injected)
        noui=True,               # No UI in worker process
        output="none",           # No file output
@@ -83,20 +84,56 @@ def _get_sim_time(engine) -> Optional[int]:
    return getattr(engine, 'current_timestep', None)


def _poll_site_metrics(engine, tick: int) -> Optional[Dict[str, Any]]:
def _poll_site_metrics(engine, tick: int, tick_data=None) -> Optional[Dict[str, Any]]:
    """
    Return current site metrics every N ticks.
    Returns None if not time to report yet.
    """
    # Report metrics every 100 ticks
    if tick % 100 != 0:
    # Report metrics every 10 ticks
    if tick % 10 != 0:
        return None

    # Build top running jobs list (up to 10)
    top_running_jobs = []
    for job in engine.running[:10]:
        top_running_jobs.append({
            'id': job.id,
            'name': str(job.name),
            'nodes_required': job.nodes_required,
            'wall_time': getattr(job, 'time_limit', getattr(job, 'wall_time', 0)),
            'run_time': getattr(job, 'current_run_time', 0),
            'state': job.current_state.value if hasattr(job.current_state, 'value') else str(job.current_state),
            'account': getattr(job, 'account', ''),
        })

    # Extract richer metrics from tick_data if available
    system_util = 0.0
    total_power_kw = 0.0
    p_flops = 0.0
    g_flops_w = 0.0
    down_nodes_count = 0

    if tick_data is not None:
        system_util = tick_data.system_util or 0.0
        p_flops = tick_data.p_flops or 0.0
        g_flops_w = tick_data.g_flops_w or 0.0
        down_nodes_count = len(tick_data.down_nodes) if tick_data.down_nodes else 0
        # total_power_kw from engine's recorded value
        total_power_kw = getattr(engine, 'sys_power', 0.0) or 0.0

    return {
        'running_jobs': len(engine.running),
        'queued_jobs': len(engine.queue),
        'active_nodes': engine.num_active_nodes,
        'free_nodes': engine.num_free_nodes,
        'down_nodes': down_nodes_count,
        'system_util': system_util,
        'total_power_kw': total_power_kw,
        'p_flops': p_flops,
        'g_flops_w': g_flops_w,
        'top_running_jobs': top_running_jobs,
        'sim_time': _get_sim_time(engine),
        'tick': tick,
    }


@@ -109,13 +146,14 @@ def site_worker_main(site_name: str,
                     sim_config_path: str,
                     job_in_q: "Queue[Dict[str, Any]]",
                     status_out_q: "Queue[Dict[str, Any]]",
                     stop_q: "Queue[bool]"):
                     stop_q: "Queue[bool]",
                     sim_time: str = "24h"):
    """
    One process per site. Owns the RAPS Engine and all mutable state.
    Communicates via queues only.
    """
    # Initialize the RAPS engine directly in this process
    engine = _initialize_raps_sim(site_name, sim_config_path)
    engine = _initialize_raps_sim(site_name, sim_config_path, sim_time=sim_time)

    # Create the simulation generator
    sim_gen = engine.run_simulation()
@@ -163,7 +201,7 @@ def site_worker_main(site_name: str,
            })

        # Periodic metrics
        metrics = _poll_site_metrics(engine, tick)
        metrics = _poll_site_metrics(engine, tick, tick_data)
        if metrics is not None:
            status_out_q.put({
                "site": site_name,

raps/ui/__init__.py

0 → 100644
+3 −0
Original line number Diff line number Diff line
from raps.ui.default import LayoutManager

__all__ = ["LayoutManager"]
+0 −0

File moved.

raps/ui/federation.py

0 → 100644
+373 −0
Original line number Diff line number Diff line
"""
Federation Dashboard — Rich Live TUI for multi-site RAPS simulations.

Renders a MetaScheduler status panel with aggregate stats plus
per-site columns showing running jobs and system metrics.
"""

import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from rich.align import Align
from rich.layout import Layout
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from raps.utils import convert_seconds_to_hhmm


# ---------------------------------------------------------------------------
# State tracking
# ---------------------------------------------------------------------------

@dataclass
class SiteState:
    """Latest known state for one site.

    The dashboard keeps one instance per site and overwrites its fields
    from that site's status messages.  Every field except ``name`` has a
    placeholder default that is shown until the first message arrives.
    """
    # Site identifier; the dashboard also uses it as the lookup key.
    name: str
    # Lifecycle label shown in the site panel title (coloured via _STATUS_STYLES).
    status: str = "STARTING"
    # Job and node counters as reported in the site's METRICS messages.
    running_jobs: int = 0
    queued_jobs: int = 0
    active_nodes: int = 0
    free_nodes: int = 0
    down_nodes: int = 0
    # Rendered with a "%" suffix and compared against 80 / 95 for colouring.
    system_util: float = 0.0
    # Kilowatts; the dashboard divides by 1000 and labels the result "MW".
    total_power_kw: float = 0.0
    # Shown under the "PFLOPS" label.
    p_flops: float = 0.0
    # Shown under the "GFL/W" label.
    g_flops_w: float = 0.0
    # Per-job dicts for the running-jobs table; the renderer reads the keys
    # 'id', 'name', 'nodes_required', 'wall_time', 'run_time' and 'state'.
    top_running_jobs: List[dict] = field(default_factory=list)
    # Simulation time reported by the site; None until a message carries one.
    sim_time: Optional[int] = None
    # Latest tick number reported by the site.
    tick: int = 0
    # Wall-clock timestamp (time.time()) of the last METRICS/HEARTBEAT message.
    last_update: float = field(default_factory=time.time)


@dataclass
class PendingJob:
    """A job submitted to the metascheduler awaiting site confirmation.

    Instances are created when a submission is recorded and are listed
    in the dashboard's "Job Queue" panel.
    """
    # Stored as a string so it can be compared with the id in ENQUEUED events.
    job_id: str
    # Display name; the queue panel truncates it to 12 characters.
    name: str
    # Node count shown in the NODES column.
    nodes: int
    # Name of the site the job was submitted to.
    target_site: str
    # Wall-clock time (time.time()) of submission; used to compute the AGE column.
    submit_wall_time: float  # wall-clock time of submission
    # Intended lifecycle: PENDING -> ENQUEUED -> DISPATCHED.
    # NOTE(review): only the PENDING -> ENQUEUED transition is performed by the
    # dashboard code in this file; confirm where DISPATCHED is set, if anywhere.
    status: str = "PENDING"  # PENDING -> ENQUEUED -> DISPATCHED


@dataclass
class MetaSchedulerState:
    """Aggregate statistics across the federation.

    Counters are maintained by the dashboard: submissions are recorded
    explicitly by the caller, while enqueue and IAM-denial counts are
    also driven by incoming status events.
    """
    # Number of submissions recorded via the dashboard.
    jobs_submitted: int = 0
    # Number of ENQUEUED events seen.
    jobs_enqueued: int = 0
    # Number of IAM denials (IAM_DENY events plus explicitly recorded denials).
    iam_denials: int = 0
    # Submission count keyed by target site name.
    per_site_submissions: Dict[str, int] = field(default_factory=dict)
    # Every recorded submission in arrival order; the queue panel shows the tail.
    pending_jobs: List["PendingJob"] = field(default_factory=list)
    # Wall-clock creation time; basis for the uptime and throughput figures.
    start_time: float = field(default_factory=time.time)


# ---------------------------------------------------------------------------
# Status color helpers
# ---------------------------------------------------------------------------

# Rich style name for each known site status; anything not listed here is
# rendered in plain white by _status_tag().
_STATUS_STYLES = {
    "STARTING": "yellow",
    "READY": "green",
    "RUNNING": "green",
    "SIMULATION_COMPLETE": "cyan",
    "STOPPED": "dim",
    "ERROR": "bold red",
}


def _status_tag(status: str) -> str:
    """Return *status* wrapped in Rich markup tags for its display style."""
    try:
        colour = _STATUS_STYLES[status]
    except KeyError:
        # Unknown statuses are still shown, just without a distinctive colour.
        colour = "white"
    return f"[{colour}]{status}[/{colour}]"


# ---------------------------------------------------------------------------
# Dashboard
# ---------------------------------------------------------------------------

class FederationDashboard:
    """Rich Layout dashboard for the federated MetaScheduler.

    Holds one ``SiteState`` per site plus a ``MetaSchedulerState`` and
    renders them into a fixed layout: a header, a metascheduler row
    (aggregate status next to the job queue), and one column per site
    (running jobs above system stats).  Feed every status message to
    ``process_event`` and call ``update_layout`` to refresh the panels.
    """

    # METRICS payload keys that are copied onto the matching SiteState
    # attribute of the same name.
    _METRIC_FIELDS = (
        "running_jobs", "queued_jobs", "active_nodes", "free_nodes",
        "down_nodes", "system_util", "total_power_kw", "p_flops",
        "g_flops_w", "top_running_jobs", "tick",
    )

    # Events whose only effect is to set the site's status to the event name.
    # A tuple (not a set) so that membership never hashes the event value.
    _STATUS_EVENTS = ("READY", "SIMULATION_COMPLETE", "STOPPED", "ERROR")

    def __init__(self, site_names: List[str]):
        """Create per-site state and the layout for *site_names*.

        Repeated names are dropped (first occurrence wins): layout regions
        are addressed by name, so a duplicate would produce two regions
        sharing one name.
        """
        self.site_names = list(dict.fromkeys(site_names))
        self.sites: Dict[str, SiteState] = {
            name: SiteState(name=name) for name in self.site_names
        }
        self.meta = MetaSchedulerState()
        self.layout = self._build_layout()

    # ------------------------------------------------------------------
    # Layout construction
    # ------------------------------------------------------------------

    def _build_layout(self) -> Layout:
        """Build the static layout tree; panels are filled in by update_layout()."""
        layout = Layout(name="root")
        layout.split_column(
            Layout(name="header", size=3),
            Layout(name="meta_row", size=12),
            Layout(name="sites", ratio=1),
        )
        layout["meta_row"].split_row(
            Layout(name="meta", ratio=2),
            Layout(name="meta_queue", ratio=3),
        )
        # Split sites into columns (one per site), each with jobs above stats
        site_layouts = []
        for name in self.site_names:
            col = Layout(name=f"site_{name}")
            col.split_column(
                Layout(name=f"jobs_{name}", ratio=3),
                Layout(name=f"stats_{name}", ratio=2),
            )
            site_layouts.append(col)
        layout["sites"].split_row(*site_layouts)
        return layout

    # ------------------------------------------------------------------
    # Render helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _format_job_id(job_id) -> str:
        """Return a compact display form of *job_id*.

        Long (UUID-style) ids are cut to their first 8 characters; short
        ids are zero-padded to at least 5 characters.
        """
        raw_id = str(job_id)
        return raw_id[:8] if len(raw_id) > 8 else raw_id.zfill(5)

    def _render_header(self) -> Panel:
        """Render the title bar with wall-clock uptime since dashboard creation."""
        elapsed = time.time() - self.meta.start_time
        h, rem = divmod(int(elapsed), 3600)
        m, s = divmod(rem, 60)
        uptime = f"{h}:{m:02d}:{s:02d}"
        title = Text.assemble(
            ("RAPS FEDERATION DASHBOARD", "bold cyan"),
            "    ",
            ("Uptime: ", "dim"),
            (uptime, "bold white"),
        )
        return Panel(Align.center(title), style="bright_blue")

    def _render_meta_panel(self) -> Panel:
        """Render federation-wide totals plus the per-site submission counts."""
        sites = list(self.sites.values())

        # Compute federation totals from site states
        fed_pflops = sum(s.p_flops for s in sites)
        fed_power_mw = sum(s.total_power_kw for s in sites) / 1000.0
        fed_running = sum(s.running_jobs for s in sites)
        fed_queued = sum(s.queued_jobs for s in sites)
        # Average utilisation only over sites that report a non-zero value.
        active_sites = [s for s in sites if s.system_util > 0]
        avg_util = (sum(s.system_util for s in active_sites) / len(active_sites)) if active_sites else 0.0

        # Throughput is suppressed for the first few seconds, where dividing
        # by a tiny elapsed time would produce a meaningless spike.
        elapsed_min = (time.time() - self.meta.start_time) / 60.0
        throughput = self.meta.jobs_submitted / elapsed_min if elapsed_min > 0.1 else 0.0

        table = Table(expand=True, header_style="bold magenta", show_edge=False)
        for col in ["Submitted", "Enqueued", "IAM Deny", "Fed PFLOPS",
                    "Fed Power", "Running", "Queued", "Avg Util", "Throughput"]:
            table.add_column(col, justify="center")

        table.add_row(
            str(self.meta.jobs_submitted),
            str(self.meta.jobs_enqueued),
            str(self.meta.iam_denials),
            f"{fed_pflops:.2f}",
            f"{fed_power_mw:.2f} MW",
            str(fed_running),
            str(fed_queued),
            f"{avg_util:.1f}%",
            f"{throughput:.1f} j/min",
            style="white",
        )

        # Per-site breakdown subtitle
        parts = [
            f"{name}={self.meta.per_site_submissions.get(name, 0)}"
            for name in self.site_names
        ]
        subtitle = "Per-site: " + " | ".join(parts)

        return Panel(
            Align.center(table),
            title="[bold]MetaScheduler Status[/bold]",
            subtitle=subtitle,
            style="bright_blue",
        )

    def _render_meta_queue(self) -> Panel:
        """Render the 15 most recently submitted jobs and the pending count."""
        table = Table(expand=True, header_style="bold magenta", show_edge=False, pad_edge=False)
        for col in ["JOBID", "NAME", "NODES", "SITE", "STATUS", "AGE"]:
            table.add_column(col, justify="center", no_wrap=True)

        now = time.time()
        for pj in self.meta.pending_jobs[-15:]:
            age_s = int(now - pj.submit_wall_time)
            if pj.status == "PENDING":
                status_str = "[yellow]PENDING[/yellow]"
            elif pj.status == "ENQUEUED":
                status_str = "[green]ENQUEUED[/green]"
            else:
                status_str = f"[dim]{pj.status}[/dim]"
            # Externally supplied strings go in as Text so that square brackets
            # in them are shown literally instead of being parsed as markup.
            table.add_row(
                Text(self._format_job_id(pj.job_id)),
                Text(str(pj.name)[:12]),
                str(pj.nodes),
                Text(str(pj.target_site)),
                status_str,
                f"{age_s}s",
                style="white",
            )

        count = sum(1 for pj in self.meta.pending_jobs if pj.status == "PENDING")
        return Panel(table, title="[bold]Job Queue[/bold]",
                     subtitle=f"Pending: {count}", border_style="bright_blue")

    def _render_site_jobs(self, site_name: str) -> Panel:
        """Render up to 10 of the site's reported running jobs."""
        site = self.sites[site_name]
        status_str = _status_tag(site.status)
        title = f"[bold]{site_name.upper()}[/bold] ({status_str})"

        table = Table(expand=True, header_style="bold green", show_edge=False, pad_edge=False)
        for col in ["JOBID", "NAME", "NODES", "WALL", "RUN", "ST"]:
            table.add_column(col, justify="center", no_wrap=True)

        for job in site.top_running_jobs[:10]:
            # A key that is present but None is treated like a missing one.
            wall_str = convert_seconds_to_hhmm(job.get('wall_time') or 0)
            run_str = convert_seconds_to_hhmm(job.get('run_time') or 0)
            # Job-supplied strings go in as Text so that square brackets in
            # them are shown literally instead of being parsed as markup.
            table.add_row(
                Text(self._format_job_id(job.get('id', ''))),
                Text(str(job.get('name', ''))[:12]),
                str(job.get('nodes_required', '')),
                wall_str,
                run_str,
                Text(str(job.get('state', ''))),
                style="white",
            )

        # NOTE(review): this counts the jobs listed in the last METRICS
        # message, not site.running_jobs — confirm which one is intended.
        count = len(site.top_running_jobs)
        return Panel(table, title=title, subtitle=f"Jobs: {count}", border_style="cyan")

    def _render_site_stats(self, site_name: str) -> Panel:
        """Render the site's utilisation, power, FLOPS and node/job counters."""
        site = self.sites[site_name]

        table = Table(show_header=False, expand=True, show_edge=False, pad_edge=False)
        table.add_column("label", justify="right", style="dim", no_wrap=True)
        table.add_column("value", justify="left", no_wrap=True)
        table.add_column("label2", justify="right", style="dim", no_wrap=True)
        table.add_column("value2", justify="left", no_wrap=True)

        # Row 1: Util | Queued
        util_style = "green" if site.system_util < 80 else "yellow" if site.system_util < 95 else "red"
        table.add_row(
            "Util", f"[{util_style}]{site.system_util:.1f}%[/{util_style}]",
            "Queued", f"[white]{site.queued_jobs}[/white]",
        )
        # Row 2: Power | Free
        power_mw = site.total_power_kw / 1000.0
        table.add_row(
            "Power", f"[yellow]{power_mw:.2f} MW[/yellow]",
            "Free", f"[white]{site.free_nodes}[/white]",
        )
        # Row 3: PFLOPS | Down (red as soon as any node is down)
        down_style = "red" if site.down_nodes > 0 else "green"
        table.add_row(
            "PFLOPS", f"[cyan]{site.p_flops:.2f}[/cyan]",
            "Down", f"[{down_style}]{site.down_nodes}[/{down_style}]",
        )
        # Row 4: Running | GFL/W
        table.add_row(
            "Running", f"[white]{site.running_jobs}[/white]",
            "GFL/W", f"[cyan]{site.g_flops_w:.1f}[/cyan]",
        )

        return Panel(table, title="[bold]System Stats[/bold]", border_style="blue")

    # ------------------------------------------------------------------
    # Event processing
    # ------------------------------------------------------------------

    def _known_site(self, site_name):
        """Return the SiteState for *site_name*, or None if absent/unknown."""
        if site_name and site_name in self.sites:
            return self.sites[site_name]
        return None

    @staticmethod
    def _apply_fields(site, msg: dict, fields) -> None:
        """Copy each of *fields* from *msg* onto *site*, skipping missing or None values.

        A None would otherwise replace a number and break the numeric
        formatting and comparisons done while rendering.
        """
        for attr in fields:
            value = msg.get(attr)
            if value is not None:
                setattr(site, attr, value)

    def process_event(self, msg: dict):
        """Translate a status_q message into state updates.

        Recognised ``msg["event"]`` values: READY, METRICS, HEARTBEAT,
        ENQUEUED, IAM_DENY, SIMULATION_COMPLETE, STOPPED and ERROR.
        Site-scoped events are ignored when ``msg["site"]`` is missing or
        not a known site; unrecognised events are ignored entirely.
        """
        site_name = msg.get("site")
        event = msg.get("event")

        if event in self._STATUS_EVENTS:
            site = self._known_site(site_name)
            if site is not None:
                site.status = event

        elif event == "METRICS":
            site = self._known_site(site_name)
            if site is not None:
                site.status = "RUNNING"
                self._apply_fields(site, msg, self._METRIC_FIELDS)
                # sim_time is Optional, so an explicit None is stored as-is.
                site.sim_time = msg.get('sim_time', site.sim_time)
                site.last_update = time.time()

        elif event == "HEARTBEAT":
            site = self._known_site(site_name)
            if site is not None:
                self._apply_fields(site, msg, ("tick",))
                site.sim_time = msg.get('sim_time', site.sim_time)
                site.last_update = time.time()

        elif event == "ENQUEUED":
            self.meta.jobs_enqueued += 1
            enq_id = str(msg.get("job_id", ""))
            # Flip only the first still-pending entry with this id.
            for pj in self.meta.pending_jobs:
                if pj.job_id == enq_id and pj.status == "PENDING":
                    pj.status = "ENQUEUED"
                    break

        elif event == "IAM_DENY":
            self.meta.iam_denials += 1

    def record_submission(self, site_name: str, job_id, job_name: str = "",
                          nodes: int = 0):
        """Called by the main loop after a successful ms.submit().

        Bumps the total and per-site submission counters and appends a
        PENDING entry for the job to the metascheduler queue.
        """
        self.meta.jobs_submitted += 1
        self.meta.per_site_submissions[site_name] = (
            self.meta.per_site_submissions.get(site_name, 0) + 1
        )
        self.meta.pending_jobs.append(PendingJob(
            job_id=str(job_id),
            name=job_name,
            nodes=nodes,
            target_site=site_name,
            submit_wall_time=time.time(),
        ))

    def record_iam_denial(self, job_id, reason: str):
        """Called by the main loop when ms.submit() raises PermissionError.

        Only the denial counter is updated; *job_id* and *reason* are
        accepted but not stored.
        """
        self.meta.iam_denials += 1

    # ------------------------------------------------------------------
    # Layout update
    # ------------------------------------------------------------------

    def update_layout(self):
        """Rebuild all panels from current state."""
        self.layout["header"].update(self._render_header())
        self.layout["meta"].update(self._render_meta_panel())
        self.layout["meta_queue"].update(self._render_meta_queue())
        for name in self.site_names:
            self.layout[f"jobs_{name}"].update(self._render_site_jobs(name))
            self.layout[f"stats_{name}"].update(self._render_site_stats(name))
Loading