Commit a546f993 authored by Brewer, Wes

A number of simplifications to raps_env.py

parent 05527c71
raps/envs/raps_env.py +15 −67
@@ -4,23 +4,14 @@ from gym import spaces
import numpy as np

from raps.engine import Engine
from raps.power import PowerManager, compute_node_power
from raps.flops import FLOPSManager
from raps.telemetry import Telemetry
from raps.workload import Workload
from raps.ui import LayoutManager
from raps.schedulers.rl import Scheduler
# from raps.resmgr.default import MultiTenantResourceManager as ResourceManager
from raps.resmgr.default import ExclusiveNodeResourceManager as ResourceManager
from raps.stats import get_engine_stats, get_job_stats, get_scheduler_stats, get_network_stats

from stable_baselines3.common.logger import Logger, HumanOutputFormat
import sys

logger = Logger(
    folder=None,  # no log file, just stdout
    output_formats=[HumanOutputFormat(sys.stdout)]
)
logger = Logger(folder=None, output_formats=[HumanOutputFormat(sys.stdout)])


def print_stats(stats, step=0):
@@ -56,65 +47,11 @@ class RAPSEnv(gym.Env):

    metadata = {"render.modes": ["human"]}

    def __init__(self, **kwargs):
    def __init__(self, sim_config):
        super().__init__()
        # Store everything in self.args
        self.args_dict = kwargs  # dict
        self.cli_args = kwargs.get("args")  # Namespace
        self.config = kwargs.get("config")
        if self.cli_args is None:
            raise ValueError("RAPSEnv requires 'args' (argparse.Namespace) in kwargs")
        if self.config is None:
            raise ValueError("RAPSEnv requires 'config' in kwargs")

        # --- managers (minimal versions) ---
        self.power_manager = PowerManager(compute_node_power, **self.config)
        self.flops_manager = FLOPSManager(**self.args_dict)
        self.telemetry = Telemetry(**self.args_dict)

        # --- Build initial jobs & time bounds ---
        self.jobs, self.timestep_start, self.timestep_end = self._build_jobs()
        self.original_jobs = self.jobs               # keep pristine version

        self.engine = Engine(
            power_manager=self.power_manager,
            flops_manager=self.flops_manager,
            jobs=self.jobs,
            **self.args_dict
        )

        resmgr = ResourceManager(
            total_nodes=self.config["TOTAL_NODES"],
            down_nodes=self.config.get("DOWN_NODES", []),
            config=self.config
        )

        # Plug in RL scheduler
        self.scheduler = Scheduler(
            config=self.config,
            policy="fcfs",   # or None if you want no heuristic fallback
            resource_manager=resmgr,
            env=self
        )
        self.engine.scheduler = self.scheduler

        self.layout_manager = LayoutManager(
            self.args_dict.get("layout"), engine=self.engine,
            debug=self.args_dict.get("debug", False),
            total_timesteps=self.args_dict.get("time", 1000),
            args_dict=self.args_dict,
            **self.config
        )

        self.timestep_start = 0
        self.timestep_end = getattr(self.cli_args, "episode_length")

        self.generator = self.layout_manager.run_stepwise(
            self.jobs,
            timestep_start=self.timestep_start,
            timestep_end=self.timestep_end,
            time_delta=self.args_dict.get("time_delta"),
        )
        self.sim_config = sim_config
        self.engine = self._create_engine()

        # --- RL spaces ---
        max_jobs = 100
@@ -124,6 +61,14 @@ class RAPSEnv(gym.Env):
        )
        self.action_space = spaces.Discrete(max_jobs)

    def _create_engine(self):
        # Engine.from_sim_config builds the engine, workload, and timing in one call
        engine, workload_data, time_delta = Engine.from_sim_config(self.sim_config)
        engine.scheduler.env = self
        jobs = workload_data.jobs
        timestep_start = workload_data.telemetry_start
        timestep_end = workload_data.telemetry_end
        self.generator = engine.run_simulation(jobs, timestep_start, timestep_end, time_delta)
        # Return the engine so `self.engine = self._create_engine()` works at call sites
        return engine

    def _build_jobs(self):
        """
        Build a job list either from synthetic workload (--workload)
@@ -204,6 +149,9 @@ class RAPSEnv(gym.Env):
#        return self._get_state(), {}

    def reset(self, **kwargs):
        self.engine = self._create_engine()
        return self._get_state(), {}

    def reset2(self, **kwargs):
        completed = [j.id for j in self.jobs if j.current_state.name == "COMPLETED"]
        print(f"[RESET] Jobs already completed before deepcopy: {len(completed)}")

+11 −13
@@ -2,12 +2,6 @@ from raps.sim_config import SingleSimConfig, SIM_SHORTCUTS
from raps.utils import SubParsers, pydantic_add_args, read_yaml


class RLConfig(SingleSimConfig):
    # Reinforcement Learning
    episode_length: int = 1000
    """ Number of timesteps per RL episode (default 1000) """


def train_rl_add_parser(subparsers: SubParsers):
    parser = subparsers.add_parser("train-rl", description="""
        Example usage:
@@ -17,15 +11,18 @@ def train_rl_add_parser(subparsers: SubParsers):
        YAML sim config file, can be used to configure an experiment instead of using CLI
        flags. Pass "-" to read from stdin.
    """)
    model_validate = pydantic_add_args(parser, RLConfig, model_config={
    model_validate = pydantic_add_args(parser, SingleSimConfig, model_config={
        "cli_shortcuts": SIM_SHORTCUTS,
    })
    parser.set_defaults(
        impl=lambda args: train_rl(model_validate(args, read_yaml(args.config_file)))
    )

    def impl(args):
        model = model_validate(args, read_yaml(args.config_file))
        model.scheduler = "rl"
        train_rl(model)
    parser.set_defaults(impl=impl)


def train_rl(rl_config: RLConfig):
def train_rl(rl_config: SingleSimConfig):
    from stable_baselines3 import PPO
    from raps.envs.raps_env import RAPSEnv

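The impl closure replaces the earlier set_defaults lambda so the validated model can be mutated (scheduler forced to "rl") before dispatch. A self-contained sketch of the same set_defaults(impl=...) dispatch pattern with the pydantic validation stubbed out — everything below is illustrative, not RAPS API:

import argparse

parser = argparse.ArgumentParser(prog="raps")
subparsers = parser.add_subparsers(dest="command")
train = subparsers.add_parser("train-rl")
train.add_argument("--config-file", default=None)

def impl(args):
    # Stand-in for: model = model_validate(args, read_yaml(args.config_file))
    print(f"validated {args.config_file}; forcing scheduler='rl' before training")

train.set_defaults(impl=impl)

args = parser.parse_args(["train-rl", "--config-file", "experiment.yaml"])
args.impl(args)  # dispatches to the handler registered by the subcommand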
@@ -34,14 +31,14 @@ def train_rl(rl_config: RLConfig):
    args_dict['config'] = config
    args_dict['args'] = rl_config.get_legacy_args()

    env = RAPSEnv(**args_dict)
    env = RAPSEnv(rl_config)

    model = PPO(
        "MlpPolicy",
        env,
        n_steps=512,         # shorter rollouts (quicker feedback loop)
        batch_size=128,      # must divide n_steps evenly
        n_epochs=10,         # # of minibatch passes per update
        n_epochs=10,         # number of minibatch passes per update
        gamma=0.99,          # discount (keeps long-term credit)
        learning_rate=3e-4,  # default Adam lr, can try 1e-4 if unstable
        ent_coef=0.01,       # encourage exploration
@@ -53,6 +50,7 @@ def train_rl(rl_config: RLConfig):

    # Output stats
    stats = env.get_stats()
    print(stats)

    # Save trained model
    model.save("ppo_raps")
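
A possible follow-up to model.save("ppo_raps"), reloading the policy for a quick sanity check — these are standard Stable-Baselines3 calls, and deterministic=True is a common evaluation choice rather than anything this commit sets:

model = PPO.load("ppo_raps")                        # reload the saved policy
obs, info = env.reset()                             # env from train_rl above
action, _ = model.predict(obs, deterministic=True)  # greedy action for obs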