Commit dcc59ba0 authored by Brewer, Wes's avatar Brewer, Wes
Browse files

Add capacity-aware dispatch and continuous job submission



- Decouple job arrival from dispatch: jobs enter a WAITING queue in the
  metascheduler and only dispatch to a site when free_nodes >= required
- Dispatch directly to site worker queues, bypassing IAM round-robin
- Random power-of-2 node counts (32–1024) and wall times for fed jobs
- Bump seed jobs to 10 per site so machines start busy
- Continuous job submission (--num-jobs 0 = unlimited, 2s interval)
- Dashboard runs until simulations complete or Ctrl+C
- Job Queue panel shows WAITING/DISPATCHED lifecycle with age

Co-Authored-By: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 0e7561d2
Loading
Loading
Loading
Loading
+6 −1
Original line number Diff line number Diff line
@@ -21,13 +21,18 @@ def _initialize_raps_sim(site_name: str, sim_config_path: str, sim_time: str = "
    sim_config = SingleSimConfig(
        system=sim_config_path,  # Path to SystemConfig or system name
        time=sim_time,           # Simulation duration
        numjobs=1,               # Minimal jobs to initialize (more will be injected)
        numjobs=10,              # Seed jobs to pre-fill the site
        noui=True,               # No UI in worker process
        output="none",           # No file output
        workload="random",       # Generates initial random jobs
        policy="fcfs",           # First-come-first-served for dynamic jobs
    )
    engine = Engine(sim_config)

    # Rename seed jobs to match the fed-N convention
    for i, job in enumerate(engine.jobs):
        job.name = f"seed-{i}"

    return engine


+91 −38
Original line number Diff line number Diff line
@@ -44,13 +44,13 @@ class SiteState:

@dataclass
class PendingJob:
    """A job submitted to the metascheduler awaiting site confirmation."""
    """A job in the metascheduler queue awaiting dispatch to a site."""
    job_id: str
    name: str
    nodes: int
    target_site: str
    submit_wall_time: float  # wall-clock time of submission
    status: str = "PENDING"  # PENDING -> ENQUEUED -> DISPATCHED
    target_site: str  # empty while WAITING, set once dispatched
    submit_wall_time: float  # wall-clock time of arrival
    status: str = "WAITING"  # WAITING -> DISPATCHED -> (removed on ENQUEUED)


@dataclass
@@ -154,22 +154,33 @@ class FederationDashboard:
        elapsed_min = (time.time() - self.meta.start_time) / 60.0
        throughput = self.meta.jobs_submitted / elapsed_min if elapsed_min > 0.1 else 0.0

        table = Table(expand=True, header_style="bold magenta", show_edge=False)
        for col in ["Submitted", "Enqueued", "IAM Deny", "Fed PFLOPS",
                     "Fed Power", "Running", "Queued", "Avg Util", "Throughput"]:
            table.add_column(col, justify="center")
        # 2×5 key-value grid (two label/value pairs per row)
        table = Table(show_header=False, expand=True, show_edge=False, pad_edge=False)
        table.add_column("l1", justify="right", style="dim", no_wrap=True)
        table.add_column("v1", justify="left", no_wrap=True)
        table.add_column("l2", justify="right", style="dim", no_wrap=True)
        table.add_column("v2", justify="left", no_wrap=True)

        waiting_count = sum(1 for pj in self.meta.pending_jobs if pj.status == "WAITING")
        table.add_row(
            str(self.meta.jobs_submitted),
            str(self.meta.jobs_enqueued),
            str(self.meta.iam_denials),
            f"{fed_pflops:.2f}",
            f"{fed_power_mw:.2f} MW",
            str(fed_running),
            str(fed_queued),
            f"{avg_util:.1f}%",
            f"{throughput:.1f} j/min",
            style="white",
            "Submitted", f"[white]{self.meta.jobs_submitted}[/white]",
            "Waiting", f"[yellow]{waiting_count}[/yellow]" if waiting_count else "[white]0[/white]",
        )
        table.add_row(
            "IAM Deny", f"[red]{self.meta.iam_denials}[/red]" if self.meta.iam_denials else "[white]0[/white]",
            "Throughput", f"[white]{throughput:.1f} j/min[/white]",
        )
        table.add_row(
            "Fed PFLOPS", f"[cyan]{fed_pflops:.2f}[/cyan]",
            "Fed Power", f"[yellow]{fed_power_mw:.2f} MW[/yellow]",
        )
        table.add_row(
            "Running", f"[white]{fed_running}[/white]",
            "Queued", f"[white]{fed_queued}[/white]",
        )
        table.add_row(
            "Avg Util", f"[green]{avg_util:.1f}%[/green]",
            "", "",
        )

        # Per-site breakdown subtitle
@@ -177,10 +188,10 @@ class FederationDashboard:
        for name in self.site_names:
            count = self.meta.per_site_submissions.get(name, 0)
            parts.append(f"{name}={count}")
        subtitle = "Per-site: " + " | ".join(parts)
        subtitle = " | ".join(parts)

        return Panel(
            Align.center(table),
            table,
            title="[bold]MetaScheduler Status[/bold]",
            subtitle=subtitle,
            style="bright_blue",
@@ -197,25 +208,28 @@ class FederationDashboard:
            job_id_str = raw_id[:8] if len(raw_id) > 8 else raw_id.zfill(5)
            age_s = int(now - pj.submit_wall_time)
            age_str = f"{age_s}s"
            if pj.status == "PENDING":
                status_str = "[yellow]PENDING[/yellow]"
            elif pj.status == "ENQUEUED":
                status_str = "[green]ENQUEUED[/green]"
            if pj.status == "WAITING":
                status_str = "[bold yellow]WAITING[/bold yellow]"
            elif pj.status == "DISPATCHED":
                status_str = "[green]DISPATCHED[/green]"
            elif pj.status == "DENIED":
                status_str = "[red]DENIED[/red]"
            else:
                status_str = f"[dim]{pj.status}[/dim]"
            site_str = pj.target_site if pj.target_site else "[dim]---[/dim]"
            table.add_row(
                job_id_str,
                str(pj.name)[:12],
                str(pj.nodes),
                pj.target_site,
                site_str,
                status_str,
                age_str,
                style="white",
            )

        count = sum(1 for pj in self.meta.pending_jobs if pj.status == "PENDING")
        waiting = sum(1 for pj in self.meta.pending_jobs if pj.status == "WAITING")
        return Panel(table, title="[bold]Job Queue[/bold]",
                     subtitle=f"Pending: {count}", border_style="bright_blue")
                     subtitle=f"Waiting: {waiting}", border_style="bright_blue")

    def _render_site_jobs(self, site_name: str) -> Panel:
        site = self.sites[site_name]
@@ -320,10 +334,10 @@ class FederationDashboard:
        elif event == "ENQUEUED":
            self.meta.jobs_enqueued += 1
            enq_id = str(msg.get("job_id", ""))
            for pj in self.meta.pending_jobs:
                if pj.job_id == enq_id and pj.status == "PENDING":
                    pj.status = "ENQUEUED"
                    break
            self.meta.pending_jobs = [
                pj for pj in self.meta.pending_jobs
                if not (pj.job_id == enq_id and pj.status == "DISPATCHED")
            ]

        elif event == "IAM_DENY":
            self.meta.iam_denials += 1
@@ -340,21 +354,60 @@ class FederationDashboard:
            if site_name and site_name in self.sites:
                self.sites[site_name].status = "ERROR"

    def record_submission(self, site_name: str, job_id, job_name: str = "",
                          nodes: int = 0):
        """Called by the main loop after a successful ms.submit()."""
    def enqueue_job(self, job_id: str, job_name: str, nodes: int):
        """Add a job to the metascheduler waiting queue."""
        self.meta.jobs_submitted += 1
        self.meta.per_site_submissions[site_name] = (
            self.meta.per_site_submissions.get(site_name, 0) + 1
        )
        self.meta.pending_jobs.append(PendingJob(
            job_id=str(job_id),
            name=job_name,
            nodes=nodes,
            target_site=site_name,
            target_site="",
            submit_wall_time=time.time(),
        ))

    def try_dispatch(self, ms) -> int:
        """
        Try to dispatch WAITING jobs to sites with enough free nodes.
        Returns the number of jobs dispatched this cycle.
        """
        dispatched = 0
        still_waiting = []
        for pj in self.meta.pending_jobs:
            if pj.status != "WAITING":
                still_waiting.append(pj)
                continue
            # Find a site with enough free nodes
            best_site = None
            best_free = -1
            for name in self.site_names:
                site = self.sites[name]
                if site.status not in ("READY", "RUNNING"):
                    continue
                if site.free_nodes >= pj.nodes and site.free_nodes > best_free:
                    best_site = name
                    best_free = site.free_nodes
            if best_site is not None:
                # Dispatch directly to the chosen site's queue
                job_dict = {
                    "job_id": pj.job_id,
                    "nodes_required": pj.nodes,
                    "wall_time_s": 3600,
                    "name": pj.name,
                    "meta": {"account": "federated"},
                }
                ms.job_queues[best_site].put(job_dict)
                pj.status = "DISPATCHED"
                pj.target_site = best_site
                self.meta.per_site_submissions[best_site] = (
                    self.meta.per_site_submissions.get(best_site, 0) + 1
                )
                # Optimistically reduce free count so we don't over-dispatch
                self.sites[best_site].free_nodes -= pj.nodes
                dispatched += 1
            still_waiting.append(pj)
        self.meta.pending_jobs = still_waiting
        return dispatched

    def record_iam_denial(self, job_id, reason: str):
        """Called by the main loop when ms.submit() raises PermissionError."""
        self.meta.iam_denials += 1
+22 −23
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ and renders a Rich Live dashboard showing aggregate + per-site metrics.
"""

import argparse
import random
import time

from raps.metasched.iam import AccessToken, IAMPolicyEngine
@@ -27,10 +28,13 @@ def _build_token() -> AccessToken:
    )


_NODE_CHOICES = [32, 64, 128, 256, 512, 1024]


def _make_job(index: int) -> FedJob:
    return FedJob(
        nodes_required=64,
        wall_time_s=3600,
        nodes_required=random.choice(_NODE_CHOICES),
        wall_time_s=random.choice([1800, 3600, 7200]),
        name=f"fed-{index}",
        meta={
            "account": "federated",
@@ -68,44 +72,40 @@ def _run_text_mode(ms: MetaScheduler, token: AccessToken, num_jobs: int,
# --------------------------------------------------------------------------

def _run_dashboard(ms: MetaScheduler, token: AccessToken, num_jobs: int,
                   listen_seconds: int, site_names: list):
                   site_names: list):
    from rich.live import Live
    from raps.ui.federation import FederationDashboard

    dashboard = FederationDashboard(site_names)
    dashboard.update_layout()

    submit_interval = 0.3  # seconds between job submissions
    submit_interval = 2.0  # seconds between job submissions
    last_submit = 0.0
    next_job = 0
    t0 = time.time()

    with Live(dashboard.layout, refresh_per_second=4, screen=True):
        while True:
            now = time.time()
            elapsed = now - t0

            # Submit next job if interval has elapsed and jobs remain
            if next_job < num_jobs and (now - last_submit) >= submit_interval:
            # Add a new job to the metascheduler queue at regular intervals
            if (num_jobs == 0 or next_job < num_jobs) and (now - last_submit) >= submit_interval:
                job = _make_job(next_job)
                try:
                    site = ms.submit(job, token=token)
                    dashboard.record_submission(site, job.job_id,
                                               job_name=job.name,
                                               nodes=job.nodes_required)
                except PermissionError:
                    dashboard.record_iam_denial(job.job_id, "denied")
                dashboard.enqueue_job(job.job_id, job.name, job.nodes_required)
                next_job += 1
                last_submit = now

            # Try to dispatch waiting jobs to sites with capacity
            dashboard.try_dispatch(ms)

            # Drain the status queue
            for msg in ms.poll_status(200):
                dashboard.process_event(msg)

            dashboard.update_layout()

            # Stop after all jobs submitted + listen period
            if next_job >= num_jobs and elapsed >= (next_job * submit_interval + listen_seconds):
            # Exit when all site simulations have completed
            if all(dashboard.sites[s].status in ("SIMULATION_COMPLETE", "STOPPED", "ERROR")
                   for s in site_names):
                break

            time.sleep(0.05)
@@ -119,10 +119,10 @@ def main():
    parser = argparse.ArgumentParser(description="RAPS Federation Demo")
    parser.add_argument("--noui", action="store_true",
                        help="Disable Rich dashboard; print events to stdout")
    parser.add_argument("--num-jobs", type=int, default=10,
                        help="Number of jobs to submit (default: 10)")
    parser.add_argument("--listen-seconds", type=int, default=10,
                        help="Seconds to listen after last submission (default: 10)")
    parser.add_argument("--num-jobs", type=int, default=0,
                        help="Number of jobs to submit (default: 0 = continuous)")
    parser.add_argument("--listen-seconds", type=int, default=60,
                        help="Seconds to listen after last submission in --noui mode (default: 60)")
    parser.add_argument("--sim-time", type=str, default="24h",
                        help="Simulated time per site (default: 24h)")
    args = parser.parse_args()
@@ -149,8 +149,7 @@ def main():
        if args.noui:
            _run_text_mode(ms, token, args.num_jobs, args.listen_seconds)
        else:
            _run_dashboard(ms, token, args.num_jobs, args.listen_seconds,
                           list(sites.keys()))
            _run_dashboard(ms, token, args.num_jobs, list(sites.keys()))
    finally:
        ms.stop()