Commit af3a9171 authored by Brewer, Wes

Clean up raps_env.py. Add in check_env. Add sample command to README.md. Not working yet.

parent cfac2e28
README.md +4 −1
@@ -74,6 +74,9 @@ For MIT Supercloud
    # Synthetic tests for verification studies:
    raps run-parts -x mit_supercloud -w multitenant

    # Reinforcement learning test case:
    python main.py train-rl --system mit_supercloud/part-cpu -f /opt/data/mit_supercloud/202201

For Lumi

    # Synthetic test for Lumi:
raps/envs/raps_env.py +8 −102
import copy
import gym
from gym import spaces
import numpy as np
@@ -64,10 +63,10 @@ class RAPSEnv(gym.Env):
    def _create_engine(self):
        self.engine, workload_data, time_delta = Engine.from_sim_config(self.sim_config)
        self.engine.scheduler.env = self
-        jobs = workload_data.jobs
+        self.jobs = workload_data.jobs
        timestep_start = workload_data.telemetry_start
        timestep_end = workload_data.telemetry_end
-        self.generator = self.engine.run_simulation(jobs, timestep_start, timestep_end, time_delta)
+        self.generator = self.engine.run_simulation(self.jobs, timestep_start, timestep_end, time_delta)
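
    # Sketch (not part of this commit): step() consumes the generator built
    # above one tick at a time. Assuming run_simulation yields one tick_data
    # object per step and signals the end of the telemetry window by
    # exhausting the generator, the consuming side looks roughly like this
    # (the method name and StopIteration convention are assumptions):
    def _advance_one_tick_sketch(self):
        try:
            tick_data = next(self.generator)  # advance the simulation one tick
            return tick_data, False           # (tick_data, terminated)
        except StopIteration:
            return None, True                 # window exhausted; episode done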

    def _build_jobs(self):
        """
@@ -114,72 +113,11 @@ class RAPSEnv(gym.Env):
        else:
            raise ValueError("RAPSEnv requires either --workload or --replay to build jobs.")

#    def reset(self, seed=None, options=None):
#        super().reset(seed=seed)
#
#        self.jobs = copy.deepcopy(self.original_jobs)  # working copy
#
#        # Reset engine
#        self.engine.current_timestep = 0
#        #self.engine.reset()  # or clear state manually
#        power_manager = PowerManager(compute_node_power, **self.config)
#        flops_manager = FLOPSManager(**self.args_dict)
#        telemetry = Telemetry(**self.args_dict)
#        jobs, timestep_start, timestep_end = self._build_jobs()
#
#        self.engine = Engine(
#            power_manager=power_manager,
#            flops_manager=flops_manager,
#            jobs=jobs,
#            **self.args_dict
#        )
#
#        self.engine.timestep_start = timestep_start
#        self.engine.timestep_end = timestep_end
#        #self.engine.current_timestep = timestep_start
#
#        # Restart generator
#        self.generator = self.layout_manager.run_stepwise(
#            self.jobs,
#            timestep_start=self.timestep_start,
#            timestep_end=self.timestep_end,
#            time_delta=self.args_dict.get("time_delta"),
#        )
#
#        return self._get_state(), {}

    def reset(self, **kwargs):
        # _create_engine assigns self.engine and self.generator itself and
        # returns None, so its result must not be assigned back to self.engine.
        self._create_engine()
        return self._get_state(), {}

    def reset2(self, **kwargs):
        completed = [j.id for j in self.jobs if j.current_state.name == "COMPLETED"]
        print(f"[RESET] Jobs already completed before deepcopy: {len(completed)}")

        super().reset(seed=42)
        # self.engine.jobs = self.jobs
        self.jobs = copy.deepcopy(self.original_jobs)  # working copy

        # self.engine.timestep_start = self.timestep_start
        # self.engine.timestep_end = self.timestep_end
        # self.engine.reset(self.jobs, self.timestep_start, self.timestep_end)

        # self.engine.current_timestep = self.timestep_start

        # self.engine.jobs = self.jobs  # repoint engine to fresh jobs
        # self.engine.completed_jobs = []
        # self.engine.queue.clear()
        # self.engine.running.clear()
        # self.engine.power_manager.history.clear()
        # self.engine.jobs_completed = 0

        self.generator = self.layout_manager.run_stepwise(
            self.jobs,
            timestep_start=self.timestep_start,
            timestep_end=self.timestep_end,
            time_delta=self.args_dict.get("time_delta", 1),
        )

-        return self._get_state()
+        obs = self._get_state()
+        info = {}
+        return obs, info

    def _compute_reward(self, tick_data):
        """
@@ -206,41 +144,6 @@ class RAPSEnv(gym.Env):

        return reward

#    def _compute_reward(self, tick_data):
#        """
#        Reward function: minimize carbon footprint per job completed.
#        Encourages the agent to complete jobs while keeping emissions low.
#        """
#        reward = 0.0
#
#        # Jobs completed this tick
#        jobs_completed = len(getattr(tick_data, "completed", []))
#
#        # Carbon emitted so far (metric tons CO2)
#        carbon_so_far = getattr(self.engine, "carbon_emissions", 0.0)
#
#        if jobs_completed > 0:
#            # Reward is higher when more jobs finish with less carbon
#            reward = jobs_completed / (carbon_so_far + 1e-6)
#        else:
#            # Small penalty if no jobs finished (encourages progress)
#            reward = -0.01
#
#        return reward

    def _compute_reward2(self, tick_data, alpha=10.0, beta=1.0, gamma=2.0):
        completed = getattr(tick_data, "completed", None)
        jobs_completed = len(completed) if completed else 0
        power = self.power_manager.history[-1][1]
        queue_len = len(self.engine.queue)

        reward = alpha * jobs_completed - beta * power - gamma * queue_len

        print(f"[t={self.engine.current_timestep}] jobs_completed={jobs_completed}, "
              f"power={power}, queue_len={queue_len}, reward={reward}")

        return reward
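
    # Sketch (not part of this commit): the raw power reading above will
    # typically dominate the job and queue terms by orders of magnitude, and
    # would saturate a [-10, 10] clip like the one applied in step(). One
    # option is to scale each term first; power_scale and queue_scale below
    # are illustrative assumptions, not values from this repo.
    def _compute_reward2_scaled(self, tick_data, alpha=10.0, beta=1.0,
                                gamma=2.0, power_scale=1e6, queue_scale=100.0):
        completed = getattr(tick_data, "completed", None) or []
        power = self.power_manager.history[-1][1] / power_scale  # e.g. W -> MW
        queue_frac = len(self.engine.queue) / queue_scale
        return alpha * len(completed) - beta * power - gamma * queue_frac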

    def step(self, action):
        queue = self.engine.queue
        invalid_action = False
@@ -268,6 +171,9 @@ class RAPSEnv(gym.Env):
        else:
            reward = self._compute_reward(tick_data)

        # clip reward
        reward = np.clip(reward, -10.0, 10.0)

        # Print stats
        stats = self.get_stats()
        print_stats(stats)
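
One plausible reason for the "Not working yet" note: this file imports legacy gym, while the check_env call added below comes from stable_baselines3, whose 2.x releases validate against the Gymnasium API. A minimal self-contained sketch of the shape check_env expects (ContractSketchEnv is a hypothetical example, assuming SB3 >= 2.0):

    import gymnasium as gym
    import numpy as np
    from gymnasium import spaces

    class ContractSketchEnv(gym.Env):
        observation_space = spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
        action_space = spaces.Discrete(2)

        def reset(self, seed=None, options=None):
            super().reset(seed=seed)
            return self.observation_space.sample(), {}  # (obs, info) 2-tuple

        def step(self, action):
            obs = self.observation_space.sample()
            # obs, reward, terminated, truncated, info: the Gymnasium 5-tuple
            return obs, 0.0, False, False, {}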
+2 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ def train_rl_add_parser(subparsers: SubParsers):

def train_rl(rl_config: SingleSimConfig):
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_checker import check_env
    from raps.envs.raps_env import RAPSEnv

    args_dict = rl_config.get_legacy_args_dict()
@@ -32,6 +33,7 @@ def train_rl(rl_config: SingleSimConfig):
    args_dict['args'] = rl_config.get_legacy_args()

    env = RAPSEnv(rl_config)
    check_env(env)

    model = PPO(
        "MlpPolicy",