Commit 3d44b0e8 authored by Brewer, Wes's avatar Brewer, Wes
Browse files

Add run-fed CLI command with YAML federation config

parent ebe06dfb
Loading
Loading
Loading
Loading

experiments/amsc.yaml

0 → 100644
+40 −0
Original line number Diff line number Diff line
# Federation experiment: AMSC (Aurora, Frontier, Perlmutter)
sites:
  frontier: ../config/frontier.yaml
  aurora: ../config/aurora.yaml
  perlmutter: ../config/perlmutter.yaml

sim_time: 24h

iam:
  site_trust_tiers:
    frontier: 3
    aurora: 2
    perlmutter: 2

dispatch:
  policy: max_free
  interval: 0.0
  max_per_cycle: 0

submit:
  interval: 2.0
  num_jobs: 0
  max_waiting: 0
  node_choices: [32, 64, 128, 256, 512, 1024]

token:
  subject: alice@federation
  roles: [researcher]
  scopes: [federation:submit]
  allowed_sites: [frontier, aurora, perlmutter]
  max_nodes: 128
  max_wall_time_s: 7200

demo:
  submit_interval: 0.2
  dispatch_interval: 1.0
  max_dispatch: 2
  dispatch_policy: round_robin
  max_waiting: 10
  node_choices: [8, 16, 32, 64, 128]
+2 −0
Original line number Diff line number Diff line
@@ -68,6 +68,7 @@ def main(cli_args: list[str] | None = None):
    check_python_version()

    from raps.run_sim import run_sim_add_parser, run_parts_sim_add_parser, show_add_parser
    from raps.run_fed import run_fed_add_parser
    from raps.workloads import run_workload_add_parser
    from raps.telemetry import run_telemetry_add_parser, run_download_add_parser
    from raps.train_rl import train_rl_add_parser
@@ -83,6 +84,7 @@ def main(cli_args: list[str] | None = None):
    run_sim_add_parser(subparsers)
    run_parts_sim_add_parser(subparsers)
    show_add_parser(subparsers)
    run_fed_add_parser(subparsers)
    run_workload_add_parser(subparsers)
    run_telemetry_add_parser(subparsers)
    run_download_add_parser(subparsers)

raps/fed_config.py

0 → 100644
+55 −0
Original line number Diff line number Diff line
from pathlib import Path

from pydantic import Field

from raps.utils import RAPSBaseModel, ResolvedPath


class IAMConfig(RAPSBaseModel):
    # `iam:` section of the federation YAML.
    # Maps a site name to its integer trust tier; run_fed hands this mapping
    # unchanged to IAMPolicyEngine(site_trust_tiers=...). Defaults to empty.
    site_trust_tiers: dict[str, int] = Field(default_factory=dict)


class DispatchConfig(RAPSBaseModel):
    # `dispatch:` section of the federation YAML (used by the dashboard loop).
    # Site-selection policy name; the run-fed CLI restricts its override to
    # "max_free", "round_robin" or "random".
    policy: str = "max_free"
    # Seconds between dispatch attempts; a value <= 0 dispatches on every loop pass.
    interval: float = 0.0
    # Forwarded as `max_dispatch` to the dashboard's try_dispatch().
    # NOTE(review): presumably 0 means "unlimited" — confirm against try_dispatch.
    max_per_cycle: int = 0


class SubmitConfig(RAPSBaseModel):
    # `submit:` section of the federation YAML: how synthetic jobs are generated.
    # Seconds between consecutive job submissions.
    interval: float = 2.0
    # Number of jobs to generate. In dashboard mode 0 means "keep submitting";
    # in text (--noui) mode 0 submits nothing because it is used as a range bound.
    num_jobs: int = 0
    # Cap on jobs waiting in the dashboard queue; a value <= 0 disables the cap.
    max_waiting: int = 0
    # Candidate node counts; each generated job picks one at random.
    node_choices: list[int] = Field(default_factory=lambda: [32, 64, 128, 256, 512, 1024])


class TokenConfig(RAPSBaseModel):
    # `token:` section of the federation YAML; run_fed turns these fields into
    # an AccessToken (roles/scopes/allowed_sites are converted to sets there).
    subject: str = "alice@federation"
    roles: list[str] = Field(default_factory=lambda: ["researcher"])
    scopes: list[str] = Field(default_factory=lambda: ["federation:submit"])
    # An empty list is passed to AccessToken as None.
    # NOTE(review): presumably None means "no site restriction" — confirm in AccessToken.
    allowed_sites: list[str] = Field(default_factory=list)
    max_nodes: int = 128
    # Wall-time limit carried by the token, in seconds.
    max_wall_time_s: int = 7200


class DemoConfig(RAPSBaseModel):
    # `demo:` section of the federation YAML: preset values that replace the
    # corresponding submit/dispatch settings when `run-fed --demo` is used.
    # Explicit CLI options are applied after this preset and still win.
    submit_interval: float = 0.2
    dispatch_interval: float = 1.0
    max_dispatch: int = 2
    dispatch_policy: str = "round_robin"
    max_waiting: int = 10
    node_choices: list[int] = Field(default_factory=lambda: [8, 16, 32, 64, 128])


class FederationConfig(RAPSBaseModel):
    # Top-level schema of a federation experiment YAML (e.g. experiments/amsc.yaml).
    # Site name -> path of that site's own config file.
    # NOTE(review): ResolvedPath presumably resolves relative to the YAML file — confirm in raps.utils.
    sites: dict[str, ResolvedPath] = Field(default_factory=dict)
    # Simulated duration, passed as a string to MetaScheduler(sim_time=...).
    sim_time: str = "24h"
    iam: IAMConfig = Field(default_factory=IAMConfig)
    dispatch: DispatchConfig = Field(default_factory=DispatchConfig)
    submit: SubmitConfig = Field(default_factory=SubmitConfig)
    token: TokenConfig = Field(default_factory=TokenConfig)
    demo: DemoConfig = Field(default_factory=DemoConfig)
    # True selects the plain-text mode instead of the Rich dashboard.
    noui: bool = False
    # How long text mode keeps polling site status after submitting, in seconds.
    listen_seconds: int = 60
    # Destination of the JSONL log; None disables JSON output.
    output_json: Path | None = None
    # Which records are logged: "events", "snapshot" or "both".
    output_mode: str = "both"
    # Seconds between snapshot records; run_fed clamps this to at least 0.1.
    output_interval: float = 1.0

raps/run_fed.py

0 → 100644
+284 −0
Original line number Diff line number Diff line
import argparse
import json
import random
import time
from pathlib import Path
from typing import Any

from raps.fed_config import FederationConfig
from raps.metasched.iam import AccessToken, IAMPolicyEngine
from raps.metasched.metascheduler import MetaScheduler
from raps.metasched.types import FedJob
from raps.utils import SubParsers, read_yaml_parsed


def run_fed_add_parser(subparsers: SubParsers):
    """Register the ``run-fed`` subcommand and all of its options on *subparsers*.

    Every value option defaults to ``None`` so that ``run_fed`` can tell an
    option that was not given apart from an explicit value, and only override
    the YAML configuration in the latter case.
    """
    cmd = subparsers.add_parser("run-fed", description="Run a federated multi-site simulation")
    cmd.add_argument("config_file", nargs="?", default=None,
                     help="Federation YAML config (e.g. experiments/amsc.yaml)")

    # Boolean switches.
    switches = (
        ("--demo", "Use demo preset overrides for faster visible feedback"),
        ("--noui", "Disable Rich dashboard; print events to stdout"),
    )
    for flag, help_text in switches:
        cmd.add_argument(flag, action="store_true", help=help_text)

    # Value options, listed in the order they should appear in --help.
    value_options = (
        ("--sim-time", {"type": str}),
        ("--num-jobs", {"type": int}),
        ("--listen-seconds", {"type": int}),
        ("--submit-interval", {"type": float}),
        ("--dispatch-interval", {"type": float}),
        ("--max-dispatch", {"type": int}),
        ("--dispatch-policy", {"type": str, "choices": ["max_free", "round_robin", "random"]}),
        ("--max-waiting", {"type": int}),
        ("--node-choices", {"type": str, "help": "Comma-separated nodes choices"}),
        ("--output-json", {"type": str, "help": "Write JSONL output to this path"}),
        ("--output-mode", {"type": str, "choices": ["events", "snapshot", "both"]}),
        ("--output-interval", {"type": float}),
    )
    for flag, extra in value_options:
        cmd.add_argument(flag, default=None, **extra)

    # The lambda defers the lookup of run_fed until the command is invoked.
    cmd.set_defaults(impl=lambda parsed: run_fed(parsed))


def _build_token(token_cfg) -> AccessToken:
    """Create the AccessToken described by the ``token`` config section.

    List-valued fields become sets, an empty ``allowed_sites`` is passed on as
    ``None``, and the token expires one hour (3600 s) after creation.
    """
    site_filter = set(token_cfg.allowed_sites) if token_cfg.allowed_sites else None
    expires_at = int(time.time()) + 3600
    return AccessToken(
        subject=token_cfg.subject,
        roles=set(token_cfg.roles),
        scopes=set(token_cfg.scopes),
        allowed_sites=site_filter,
        max_nodes=token_cfg.max_nodes,
        max_wall_time_s=token_cfg.max_wall_time_s,
        expiry_epoch_s=expires_at,
    )


def _make_job(index: int, node_choices: list[int]) -> FedJob:
    """Build a synthetic federated job named ``fed-<index>``.

    The node count is drawn at random from *node_choices* and the wall time
    from 1800/3600/7200 seconds; every job requires trust tier 2.
    """
    # Draw nodes first, then wall time, so seeded runs stay reproducible.
    nodes = random.choice(node_choices)
    wall_time = random.choice([1800, 3600, 7200])
    job_meta = {
        "account": "federated",
        "required_trust_tier": 2,
    }
    return FedJob(nodes_required=nodes, wall_time_s=wall_time, name=f"fed-{index}", meta=job_meta)


def _run_text_mode(ms: MetaScheduler, token: AccessToken, num_jobs: int,
                   listen_seconds: int, json_out: dict[str, Any] | None,
                   node_choices: list[int], submit_interval: float):
    """Plain-text run: submit *num_jobs* jobs, then print site events.

    Jobs are submitted one by one with *submit_interval* seconds between them;
    a rejected submission is reported and does not stop the run. Afterwards
    the status queue is polled for *listen_seconds* seconds, selected events
    are printed, and every message is also logged to *json_out* when its mode
    includes events.
    """
    printed_events = ("HEARTBEAT", "METRICS", "ERROR", "IAM_DENY")

    # Phase 1: submit the jobs.
    for job_no in range(num_jobs):
        job = _make_job(job_no, node_choices)
        try:
            target_site = ms.submit(job, token=token)
            print(f"submitted {job.job_id} by {token.subject} -> {target_site}")
        except PermissionError as denied:
            print(f"DENIED {job.job_id}: {denied}")
        time.sleep(submit_interval)

    # Phase 2: listen for status messages.
    started = time.time()
    while time.time() - started < listen_seconds:
        for status in ms.poll_status():
            if status.get("event") in printed_events:
                print(status)
            if json_out and json_out["mode"] in ("events", "both"):
                record = {"type": "event", "ts": time.time(), "msg": status}
                json_out["write"](record)
        time.sleep(0.2)


def _run_dashboard(ms: MetaScheduler, token: AccessToken, num_jobs: int,
                   site_names: list[str], json_out: dict[str, Any] | None,
                   submit_interval: float, dispatch_interval: float,
                   max_dispatch: int, dispatch_policy: str,
                   node_choices: list[int], max_waiting: int):
    """Drive the Rich live dashboard until every site reaches a final state.

    Each loop pass may enqueue one new job (rate-limited by *submit_interval*
    and capped by *max_waiting*), asks the dashboard to dispatch waiting jobs,
    drains up to 200 status messages, refreshes the layout and optionally
    writes event/snapshot records to *json_out*. ``num_jobs == 0`` keeps
    generating jobs for the whole run. *token* is accepted but not used here.
    """
    from rich.live import Live
    from raps.ui.federation import FederationDashboard

    final_states = ("SIMULATION_COMPLETE", "STOPPED", "ERROR")

    board = FederationDashboard(site_names)
    board.update_layout()

    submitted_at = 0.0
    dispatched_at = 0.0
    job_no = 0

    with Live(board.layout, refresh_per_second=4, screen=True):
        while True:
            now = time.time()

            # Enqueue at most one new job per pass.
            jobs_remaining = num_jobs == 0 or job_no < num_jobs
            if jobs_remaining and (now - submitted_at) >= submit_interval:
                if max_waiting <= 0 or board.waiting_count() < max_waiting:
                    job = _make_job(job_no, node_choices)
                    board.enqueue_job(job.job_id, job.name, job.nodes_required)
                    if json_out and json_out["mode"] in ("events", "both"):
                        enqueue_record = {
                            "type": "local_enqueue",
                            "ts": now,
                            "job_id": job.job_id,
                            "name": job.name,
                            "nodes_required": job.nodes_required,
                            "wall_time_s": job.wall_time_s,
                        }
                        json_out["write"](enqueue_record)
                    job_no += 1
                    submitted_at = now

            # Hand waiting jobs to the sites.
            if dispatch_interval <= 0 or (now - dispatched_at) >= dispatch_interval:
                board.try_dispatch(ms, max_dispatch=max_dispatch, policy=dispatch_policy)
                dispatched_at = now

            # Drain pending status messages.
            for status in ms.poll_status(200):
                board.process_event(status)
                if json_out and json_out["mode"] in ("events", "both"):
                    json_out["write"]({"type": "event", "ts": time.time(), "msg": status})

            board.update_layout()

            # Periodic snapshot of the dashboard state.
            if json_out and json_out["mode"] in ("snapshot", "both"):
                if now >= json_out["next_snapshot"]:
                    snapshot_record = {"type": "snapshot", "ts": now, "state": board.snapshot()}
                    json_out["write"](snapshot_record)
                    json_out["next_snapshot"] = now + json_out["interval"]

            # Leave once every site has finished, stopped or failed.
            if all(board.sites[name].status in final_states for name in site_names):
                break

            time.sleep(0.05)


def _load_federation_config(config_file: str | None) -> FederationConfig:
    """Read *config_file* through ``read_yaml_parsed`` and validate it as a FederationConfig."""
    return FederationConfig.model_validate(read_yaml_parsed(FederationConfig, config_file))


def run_fed(args: argparse.Namespace):
    """Run a federated multi-site simulation from parsed ``run-fed`` arguments.

    Settings are resolved in three layers: the YAML config (or its defaults),
    then the ``demo`` preset when ``--demo`` is given, then any option passed
    explicitly on the command line. The metascheduler is started, the run is
    executed in text mode (``noui``) or with the Rich dashboard, and optional
    JSONL records ("start", events/snapshots, "end") are written to
    ``output_json``.

    Cleanup is guaranteed once the metascheduler has been started: the output
    file is closed and ``ms.stop()`` is called even if creating the output
    file, writing a record, or the run itself raises.

    Args:
        args: Namespace produced by the parser set up in ``run_fed_add_parser``.
    """
    cfg = _load_federation_config(args.config_file)

    # Layer 1/2: regular config sections, or the demo preset when requested.
    if args.demo:
        submit_interval = cfg.demo.submit_interval
        dispatch_interval = cfg.demo.dispatch_interval
        max_dispatch = cfg.demo.max_dispatch
        dispatch_policy = cfg.demo.dispatch_policy
        max_waiting = cfg.demo.max_waiting
        node_choices = list(cfg.demo.node_choices)
    else:
        submit_interval = cfg.submit.interval
        dispatch_interval = cfg.dispatch.interval
        max_dispatch = cfg.dispatch.max_per_cycle
        dispatch_policy = cfg.dispatch.policy
        max_waiting = cfg.submit.max_waiting
        node_choices = list(cfg.submit.node_choices)

    # Layer 3: explicit CLI options (None means "not given").
    if args.sim_time is not None:
        cfg.sim_time = args.sim_time
    if args.num_jobs is not None:
        cfg.submit.num_jobs = args.num_jobs
    if args.listen_seconds is not None:
        cfg.listen_seconds = args.listen_seconds
    if args.submit_interval is not None:
        submit_interval = args.submit_interval
    if args.dispatch_interval is not None:
        dispatch_interval = args.dispatch_interval
    if args.max_dispatch is not None:
        max_dispatch = args.max_dispatch
    if args.dispatch_policy is not None:
        dispatch_policy = args.dispatch_policy
    if args.max_waiting is not None:
        max_waiting = args.max_waiting
    if args.node_choices:
        node_choices = [int(x) for x in args.node_choices.split(",") if x.strip()]
    if args.output_json is not None:
        # An empty string on the command line disables JSON output.
        cfg.output_json = Path(args.output_json) if args.output_json else None
    if args.output_mode is not None:
        cfg.output_mode = args.output_mode
    if args.output_interval is not None:
        cfg.output_interval = args.output_interval
    if args.noui:
        cfg.noui = True

    iam = IAMPolicyEngine(site_trust_tiers=cfg.iam.site_trust_tiers)
    sites = {name: str(path) for name, path in cfg.sites.items()}
    ms = MetaScheduler(sites, iam_policy=iam, sim_time=cfg.sim_time)
    token = _build_token(cfg.token)

    ms.start()
    try:
        out_f = None
        json_out = None
        try:
            # Output setup happens inside the guarded region: a failure to
            # create the directory/file must not leave the sites running.
            if cfg.output_json:
                out_path = Path(cfg.output_json)
                out_path.parent.mkdir(parents=True, exist_ok=True)
                # Line-buffered so each JSONL record reaches disk promptly.
                out_f = out_path.open("w", buffering=1, encoding="utf-8")

                def _write(obj: dict[str, Any]):
                    out_f.write(json.dumps(obj) + "\n")

                json_out = {
                    "file": out_f,
                    "write": _write,
                    "mode": cfg.output_mode,
                    # Never snapshot more often than every 0.1 s.
                    "interval": max(0.1, cfg.output_interval),
                    "next_snapshot": 0.0,
                }
                json_out["write"]({
                    "type": "start",
                    "ts": time.time(),
                    "sites": list(sites.keys()),
                    "sim_time": cfg.sim_time,
                    "num_jobs": cfg.submit.num_jobs,
                    "noui": cfg.noui,
                })

            if cfg.noui:
                _run_text_mode(
                    ms,
                    token,
                    cfg.submit.num_jobs,
                    cfg.listen_seconds,
                    json_out,
                    node_choices,
                    submit_interval,
                )
            else:
                _run_dashboard(
                    ms,
                    token,
                    cfg.submit.num_jobs,
                    list(sites.keys()),
                    json_out,
                    submit_interval=submit_interval,
                    dispatch_interval=dispatch_interval,
                    max_dispatch=max_dispatch,
                    dispatch_policy=dispatch_policy,
                    node_choices=node_choices,
                    max_waiting=max_waiting,
                )
        finally:
            try:
                if json_out:
                    json_out["write"]({
                        "type": "end",
                        "ts": time.time(),
                    })
            finally:
                # Close the file even if the "end" record could not be written.
                if out_f is not None:
                    out_f.close()
    finally:
        # Always shut the site simulations down once they were started.
        ms.stop()


def main(cli_args: list[str] | None = None):
    """Standalone entry point that runs the ``run-fed`` command.

    Args:
        cli_args: Arguments for ``run-fed`` (without the subcommand name).
            When ``None``, the process arguments ``sys.argv[1:]`` are used, so
            running this module as a script honours its command line instead
            of silently ignoring it. Pass ``[]`` to run with no arguments.
    """
    if cli_args is None:
        import sys

        cli_args = sys.argv[1:]

    parser = argparse.ArgumentParser(description="RAPS Federation Demo")
    subparsers = parser.add_subparsers(required=True)
    run_fed_add_parser(subparsers)
    # The subcommand name is injected so callers only supply run-fed options.
    args = parser.parse_args(["run-fed", *cli_args])
    args.impl(args)


if __name__ == "__main__":
    main()

scripts/run_federation.py

100644 → 100755
+4 −263
Original line number Diff line number Diff line
#!/usr/bin/env python3
"""
Federation demo — launches 3 RAPS site simulations behind a MetaScheduler
and renders a Rich Live dashboard showing aggregate + per-site metrics.
"""Legacy wrapper: prefer `raps run-fed experiments/amsc.yaml`."""

    python scripts/run_federation.py              # full TUI
    python scripts/run_federation.py --noui       # text-only fallback
"""
import sys

import argparse
import json
import os
import random
import time
from typing import Optional

from raps.metasched.iam import AccessToken, IAMPolicyEngine
from raps.metasched.metascheduler import MetaScheduler
from raps.metasched.types import FedJob


def _build_token() -> AccessToken:
    """Return the fixed demo token for alice@federation, valid for one hour."""
    expires_at = int(time.time()) + 3600
    permitted_sites = {"frontier", "aurora", "perlmutter"}
    return AccessToken(
        subject="alice@federation",
        roles={"researcher"},
        scopes={"federation:submit"},
        allowed_sites=permitted_sites,
        max_nodes=128,
        max_wall_time_s=7200,
        expiry_epoch_s=expires_at,
    )


_NODE_CHOICES = [32, 64, 128, 256, 512, 1024]


def _make_job(index: int, node_choices: list[int]) -> FedJob:
    """Build a synthetic job ``fed-<index>`` with a random node count and wall time."""
    # Keep the draw order (nodes, then wall time) so seeded runs are unchanged.
    nodes = random.choice(node_choices)
    wall_time = random.choice([1800, 3600, 7200])
    job_meta = {
        "account": "federated",
        "required_trust_tier": 2,
    }
    return FedJob(nodes_required=nodes, wall_time_s=wall_time, name=f"fed-{index}", meta=job_meta)


# --------------------------------------------------------------------------
# Text-only fallback (--noui)
# --------------------------------------------------------------------------

def _run_text_mode(ms: MetaScheduler, token: AccessToken, num_jobs: int,
                   listen_seconds: int, json_out: Optional[dict],
                   node_choices: list[int]):
    """Text-only run: submit *num_jobs* jobs 0.2 s apart, then print site events.

    A rejected submission is reported and skipped. After submitting, the
    status queue is polled for *listen_seconds* seconds; selected events are
    printed and every message is logged to *json_out* when its mode includes
    events.
    """
    printed_events = ("HEARTBEAT", "METRICS", "ERROR", "IAM_DENY")

    for job_no in range(num_jobs):
        job = _make_job(job_no, node_choices)
        try:
            target_site = ms.submit(job, token=token)
            print(f"submitted {job.job_id} by {token.subject} -> {target_site}")
        except PermissionError as denied:
            print(f"DENIED {job.job_id}: {denied}")
        time.sleep(0.2)

    started = time.time()
    while time.time() - started < listen_seconds:
        for status in ms.poll_status():
            if status.get("event") in printed_events:
                print(status)
            if json_out and json_out["mode"] in ("events", "both"):
                record = {"type": "event", "ts": time.time(), "msg": status}
                json_out["write"](record)
        time.sleep(0.2)


# --------------------------------------------------------------------------
# Rich Live dashboard mode
# --------------------------------------------------------------------------

def _run_dashboard(ms: MetaScheduler, token: AccessToken, num_jobs: int,
                   site_names: list, json_out: Optional[dict],
                   submit_interval: float, dispatch_interval: float,
                   max_dispatch: int, dispatch_policy: str,
                   node_choices: list[int], max_waiting: int):
    """Drive the Rich live dashboard until every site reaches a final state.

    Each pass may enqueue one job, dispatches waiting jobs, drains up to 200
    status messages, refreshes the layout and optionally writes event and
    snapshot records to *json_out*. ``num_jobs == 0`` means continuous
    submission. *token* is accepted but not used in this function.
    """
    from rich.live import Live
    from raps.ui.federation import FederationDashboard

    final_states = ("SIMULATION_COMPLETE", "STOPPED", "ERROR")

    board = FederationDashboard(site_names)
    board.update_layout()

    submitted_at = 0.0
    dispatched_at = 0.0
    job_no = 0

    with Live(board.layout, refresh_per_second=4, screen=True):
        while True:
            now = time.time()

            # Periodically add one new job to the metascheduler queue.
            jobs_remaining = num_jobs == 0 or job_no < num_jobs
            if jobs_remaining and (now - submitted_at) >= submit_interval:
                if max_waiting <= 0 or board.waiting_count() < max_waiting:
                    job = _make_job(job_no, node_choices)
                    board.enqueue_job(job.job_id, job.name, job.nodes_required)
                    if json_out and json_out["mode"] in ("events", "both"):
                        enqueue_record = {
                            "type": "local_enqueue",
                            "ts": now,
                            "job_id": job.job_id,
                            "name": job.name,
                            "nodes_required": job.nodes_required,
                            "wall_time_s": job.wall_time_s,
                        }
                        json_out["write"](enqueue_record)
                    job_no += 1
                    submitted_at = now

            # Dispatch waiting jobs to sites with capacity.
            if dispatch_interval <= 0 or (now - dispatched_at) >= dispatch_interval:
                board.try_dispatch(ms, max_dispatch=max_dispatch, policy=dispatch_policy)
                dispatched_at = now

            # Drain the status queue.
            for status in ms.poll_status(200):
                board.process_event(status)
                if json_out and json_out["mode"] in ("events", "both"):
                    json_out["write"]({"type": "event", "ts": time.time(), "msg": status})

            board.update_layout()

            # Periodic snapshot of the dashboard state.
            if json_out and json_out["mode"] in ("snapshot", "both"):
                if now >= json_out["next_snapshot"]:
                    snapshot_record = {"type": "snapshot", "ts": now, "state": board.snapshot()}
                    json_out["write"](snapshot_record)
                    json_out["next_snapshot"] = now + json_out["interval"]

            # Stop once all site simulations have completed, stopped or failed.
            if all(board.sites[name].status in final_states for name in site_names):
                break

            time.sleep(0.05)


# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------

def main():
    """Parse the command line and run the three-site federation demo.

    Builds a MetaScheduler over the frontier/aurora/perlmutter site configs,
    starts it, and runs either the text mode (``--noui``) or the Rich
    dashboard. With ``--output-json`` a JSONL log ("start", events/snapshots,
    "end") is written.

    Cleanup is guaranteed once the metascheduler has been started: the output
    file is closed and ``ms.stop()`` is called even if creating the output
    file, writing a record, or the run itself raises.

    Note: ``--demo`` overwrites the submit/dispatch/node-choice values after
    parsing, so it takes precedence over those options when both are given.
    """
    parser = argparse.ArgumentParser(description="RAPS Federation Demo")
    parser.add_argument("--noui", action="store_true",
                        help="Disable Rich dashboard; print events to stdout")
    parser.add_argument("--num-jobs", type=int, default=0,
                        help="Number of jobs to submit (default: 0 = continuous)")
    parser.add_argument("--listen-seconds", type=int, default=60,
                        help="Seconds to listen after last submission in --noui mode (default: 60)")
    parser.add_argument("--sim-time", type=str, default="24h",
                        help="Simulated time per site (default: 24h)")
    parser.add_argument("--submit-interval", type=float, default=2.0,
                        help="Seconds between job submissions (default: 2.0)")
    parser.add_argument("--dispatch-interval", type=float, default=0.0,
                        help="Seconds between dispatch attempts (default: 0 = every loop)")
    parser.add_argument("--max-dispatch", type=int, default=0,
                        help="Max jobs to dispatch per attempt (default: 0 = unlimited)")
    parser.add_argument("--dispatch-policy", type=str, default="max_free",
                        choices=["max_free", "round_robin", "random"],
                        help="Dispatch site selection policy (default: max_free)")
    parser.add_argument("--max-waiting", type=int, default=0,
                        help="Cap WAITING jobs in metascheduler queue (default: 0 = unlimited)")
    parser.add_argument("--node-choices", type=str, default="",
                        help="Comma-separated nodes choices (default: built-in list)")
    parser.add_argument("--demo", action="store_true",
                        help="Use a preset that builds a visible queue for demos")
    parser.add_argument("--output-json", type=str, default="",
                        help="Write JSONL output to this path (default: disabled)")
    parser.add_argument("--output-mode", type=str, default="both",
                        choices=["events", "snapshot", "both"],
                        help="JSON output mode (default: both)")
    parser.add_argument("--output-interval", type=float, default=1.0,
                        help="Seconds between snapshots (default: 1.0)")
    args = parser.parse_args()

    sites = {
        "frontier": "config/frontier.yaml",
        "aurora": "config/aurora.yaml",
        "perlmutter": "config/perlmutter.yaml",
    }

    iam = IAMPolicyEngine(
        site_trust_tiers={
            "frontier": 3,
            "aurora": 2,
            "perlmutter": 2,
        }
    )
    if args.demo:
        args.submit_interval = 0.2
        args.dispatch_interval = 1.0
        args.max_dispatch = 2
        args.dispatch_policy = "round_robin"
        args.max_waiting = 10
        args.node_choices = "8,16,32,64,128"

    if args.node_choices:
        node_choices = [int(x) for x in args.node_choices.split(",") if x.strip()]
    else:
        node_choices = list(_NODE_CHOICES)

    ms = MetaScheduler(sites, iam_policy=iam, sim_time=args.sim_time)
    token = _build_token()

    ms.start()
    try:
        out_f = None
        json_out = None
        try:
            # Output setup happens inside the guarded region: a failure to
            # create the directory/file must not leave the sites running.
            if args.output_json:
                out_dir = os.path.dirname(args.output_json)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                # Line-buffered so each JSONL record reaches disk promptly.
                out_f = open(args.output_json, "w", buffering=1)

                def _write(obj: dict):
                    out_f.write(json.dumps(obj) + "\n")

                json_out = {
                    "file": out_f,
                    "write": _write,
                    "mode": args.output_mode,
                    # Never snapshot more often than every 0.1 s.
                    "interval": max(0.1, args.output_interval),
                    "next_snapshot": 0.0,
                }
                json_out["write"]({
                    "type": "start",
                    "ts": time.time(),
                    "sites": list(sites.keys()),
                    "sim_time": args.sim_time,
                    "num_jobs": args.num_jobs,
                    "noui": args.noui,
                })

            if args.noui:
                _run_text_mode(ms, token, args.num_jobs, args.listen_seconds, json_out,
                               node_choices)
            else:
                _run_dashboard(ms, token, args.num_jobs, list(sites.keys()), json_out,
                               submit_interval=args.submit_interval,
                               dispatch_interval=args.dispatch_interval,
                               max_dispatch=args.max_dispatch,
                               dispatch_policy=args.dispatch_policy,
                               node_choices=node_choices,
                               max_waiting=args.max_waiting)
        finally:
            try:
                if json_out:
                    json_out["write"]({
                        "type": "end",
                        "ts": time.time(),
                    })
            finally:
                # Close the file even if the "end" record could not be written.
                if out_f is not None:
                    out_f.close()
    finally:
        # Always shut the site simulations down once they were started.
        ms.stop()
from raps.run_fed import main


if __name__ == "__main__":
    main()
    main(sys.argv[1:])