Commit dc5a4b62 authored by Codex's avatar Codex Committed by Brewer, Wes
Browse files

Add IAM policy enforcement for federated job submission



Introduce a lightweight IAM model for federation with token-based
authorization and site trust-tier checks.

- add AccessToken/IAMPolicyEngine primitives
- enforce authorization in MetaScheduler.submit(job, token)
- emit IAM_DENY events with denial reasons
- wire IAM demo flow into scripts/run_federation.py
- add unit tests for allow/deny paths and unauthorized submission rejection

Signed-off-by: default avatarCodex <codex@openai.com>
parent 0dfc13f0
Loading
Loading
Loading
Loading

raps/metasched/iam.py

0 → 100644
+71 −0
Original line number Diff line number Diff line
from dataclasses import dataclass, field
from typing import Dict, Iterable, Optional, Set
import time

from .types import FedJob


@dataclass(frozen=True)
class AccessToken:
    """
    Lightweight token model for federation simulation.
    """
    subject: str
    roles: Set[str]
    allowed_sites: Optional[Set[str]] = None
    max_nodes: Optional[int] = None
    max_wall_time_s: Optional[int] = None
    expiry_epoch_s: Optional[int] = None
    scopes: Set[str] = field(default_factory=set)


@dataclass(frozen=True)
class IAMDecision:
    allowed: bool
    reason: str = ""


class IAMPolicyEngine:
    """
    Enforces authentication/authorization checks for federated submissions.

    Site trust tiers are optional and use larger number = higher trust.
    Jobs can request a required trust tier using job.meta["required_trust_tier"].
    """

    def __init__(self, site_trust_tiers: Optional[Dict[str, int]] = None):
        self.site_trust_tiers = site_trust_tiers or {}

    def authorize(self, job: FedJob, token: Optional[AccessToken], site: str) -> IAMDecision:
        if token is None:
            return IAMDecision(False, "missing access token")

        now_s = int(time.time())
        if token.expiry_epoch_s is not None and now_s >= token.expiry_epoch_s:
            return IAMDecision(False, "token expired")

        if "federation:submit" not in token.scopes:
            return IAMDecision(False, "missing scope federation:submit")

        if token.allowed_sites is not None and site not in token.allowed_sites:
            return IAMDecision(False, f"site {site} not allowed for subject {token.subject}")

        if token.max_nodes is not None and job.nodes_required > token.max_nodes:
            return IAMDecision(False, f"nodes_required={job.nodes_required} exceeds token max_nodes={token.max_nodes}")

        if token.max_wall_time_s is not None and job.wall_time_s > token.max_wall_time_s:
            return IAMDecision(False, f"wall_time_s={job.wall_time_s} exceeds token max_wall_time_s={token.max_wall_time_s}")

        required_tier = int(job.meta.get("required_trust_tier", 0))
        site_tier = int(self.site_trust_tiers.get(site, 0))
        if site_tier < required_tier:
            return IAMDecision(False, f"site trust tier {site_tier} is below required_trust_tier {required_tier}")

        return IAMDecision(True)

    def eligible_sites(self, job: FedJob, token: Optional[AccessToken], sites: Iterable[str]) -> Set[str]:
        allowed = set()
        for site in sites:
            if self.authorize(job, token, site).allowed:
                allowed.add(site)
        return allowed
+38 −6
Original line number Diff line number Diff line
from multiprocessing import Process, Queue
from typing import Dict, List
from typing import Dict, List, Optional
from .types import FedJob
from .site_worker import site_worker_main
from .iam import AccessToken, IAMPolicyEngine

class MetaScheduler:
    def __init__(self, sites: Dict[str, str]):
    def __init__(self, sites: Dict[str, str], iam_policy: Optional[IAMPolicyEngine] = None):
        """
        sites: {site_name: sim_config_path}
        """
@@ -15,6 +16,7 @@ class MetaScheduler:
        self.procs: Dict[str, Process] = {}
        self._rr = 0
        self._site_list = list(self.sites.keys())
        self.iam_policy = iam_policy
        assert len(self._site_list) > 0, "No sites configured"

    def start(self):
@@ -42,7 +44,7 @@ class MetaScheduler:
            p.join(timeout=5)

    # ---- policies ----
    def choose_site(self, job: FedJob) -> str:
    def choose_site(self, job: FedJob, candidates: Optional[List[str]] = None) -> str:
        """
        Simplest policy: choose the site with the shortest inbound queue length.
        Queue.qsize() is approximate on some platforms, but adequate for a first pass.
@@ -56,14 +58,44 @@ class MetaScheduler:
        #        best_site = site
        #assert best_site is not None
        #return best_site
        site = self._site_list[self._rr % len(self._site_list)]
        candidate_sites = candidates if candidates is not None else self._site_list
        if not candidate_sites:
            raise RuntimeError(f"No candidate sites available for job {job.job_id}")
        site = candidate_sites[self._rr % len(candidate_sites)]
        self._rr += 1
        return site

    def submit(self, job: FedJob) -> str:
        site = self.choose_site(job)
    def submit(self, job: FedJob, token: Optional[AccessToken] = None) -> str:
        candidate_sites = self._site_list
        if self.iam_policy is not None:
            eligible = sorted(self.iam_policy.eligible_sites(job, token, self._site_list))
            if not eligible:
                reason = "no authorized site for job"
                self.status_q.put({
                    "event": "IAM_DENY",
                    "job_id": job.job_id,
                    "subject": getattr(token, "subject", "unknown"),
                    "reason": reason,
                })
                raise PermissionError(f"IAM authorization denied: {reason}")
            candidate_sites = eligible

        site = self.choose_site(job, candidate_sites)
        if site is None:
            raise RuntimeError(f"choose_site() returned None; _site_list={getattr(self,'_site_list',None)}")

        if self.iam_policy is not None:
            decision = self.iam_policy.authorize(job, token, site)
            if not decision.allowed:
                self.status_q.put({
                    "event": "IAM_DENY",
                    "job_id": job.job_id,
                    "site": site,
                    "subject": getattr(token, "subject", "unknown"),
                    "reason": decision.reason,
                })
                raise PermissionError(f"IAM authorization denied: {decision.reason}")

        self.job_queues[site].put(job.__dict__)
        return site

+35 −5
Original line number Diff line number Diff line
#!/usr/bin/env python3
import time

from raps.metasched.iam import AccessToken, IAMPolicyEngine
from raps.metasched.metascheduler import MetaScheduler
from raps.metasched.types import FedJob

@@ -12,16 +13,45 @@ def main():
        "perlmutter": "config/perlmutter.yaml",
    }

    ms = MetaScheduler(sites)
    iam = IAMPolicyEngine(
        site_trust_tiers={
            "frontier": 3,
            "aurora": 2,
            "perlmutter": 1,
        }
    )
    ms = MetaScheduler(sites, iam_policy=iam)

    # Demo identity: can submit only to frontier/aurora and has resource limits.
    token = AccessToken(
        subject="alice@federation",
        roles={"researcher"},
        scopes={"federation:submit"},
        allowed_sites={"frontier", "aurora"},
        max_nodes=128,
        max_wall_time_s=7200,
        expiry_epoch_s=int(time.time()) + 3600,
    )

    # Start all site workers; MetaScheduler.start() blocks until all are READY
    ms.start()

    # Submit a few jobs (arrival stream)
    for i in range(10):
        job = FedJob(nodes_required=64, wall_time_s=3600, name=f"fed-{i}")
        site = ms.submit(job)
        print(f"submitted {job.job_id} -> {site}")
        job = FedJob(
            nodes_required=64,
            wall_time_s=3600,
            name=f"fed-{i}",
            meta={
                "account": "federated",
                "required_trust_tier": 2,  # This denies perlmutter (tier 1)
            },
        )
        try:
            site = ms.submit(job, token=token)
            print(f"submitted {job.job_id} by {token.subject} -> {site}")
        except PermissionError as exc:
            print(f"DENIED {job.job_id}: {exc}")
        time.sleep(0.2)

    # Listen for events for a while.
@@ -33,7 +63,7 @@ def main():
            ev = msg.get("event")

            # Print the signals that prove each site is alive / progressing
            if ev in ("HEARTBEAT", "METRICS", "ERROR"):
            if ev in ("HEARTBEAT", "METRICS", "ERROR", "IAM_DENY"):
                print(msg)

            # Optional: show enqueue events too (can be noisy)
+54 −0
Original line number Diff line number Diff line
import time
import queue

import pytest

from raps.metasched.iam import AccessToken, IAMPolicyEngine
from raps.metasched.metascheduler import MetaScheduler
from raps.metasched.types import FedJob


def _token(**kwargs):
    defaults = {
        "subject": "alice@federation",
        "roles": {"researcher"},
        "scopes": {"federation:submit"},
        "allowed_sites": {"frontier", "aurora"},
        "max_nodes": 128,
        "max_wall_time_s": 7200,
        "expiry_epoch_s": int(time.time()) + 3600,
    }
    defaults.update(kwargs)
    return AccessToken(**defaults)


def test_iam_policy_allows_and_denies_expected_cases():
    policy = IAMPolicyEngine(site_trust_tiers={"frontier": 3, "aurora": 2, "perlmutter": 1})
    job = FedJob(nodes_required=64, wall_time_s=3600, meta={"required_trust_tier": 2})
    token = _token()

    assert policy.authorize(job, token, "frontier").allowed
    assert policy.authorize(job, token, "aurora").allowed
    assert not policy.authorize(job, token, "perlmutter").allowed

    expired = _token(expiry_epoch_s=int(time.time()) - 1)
    assert not policy.authorize(job, expired, "frontier").allowed

    oversized = FedJob(nodes_required=1024, wall_time_s=3600)
    assert not policy.authorize(oversized, token, "frontier").allowed


def test_submit_raises_when_no_authorized_site(monkeypatch):
    # Use stdlib queues in tests to avoid OS semaphore requirements.
    monkeypatch.setattr("raps.metasched.metascheduler.Queue", queue.Queue)

    policy = IAMPolicyEngine(site_trust_tiers={"frontier": 3, "aurora": 2})
    ms = MetaScheduler(
        {"frontier": "config/frontier.yaml", "aurora": "config/aurora.yaml"},
        iam_policy=policy,
    )
    token = _token(allowed_sites={"frontier"}, max_nodes=32)
    job = FedJob(nodes_required=64, wall_time_s=3600)

    with pytest.raises(PermissionError):
        ms.submit(job, token=token)