Commit 82f89257 authored by Brewer, Wes
Browse files

Get RL working again...

parent af3a9171
Loading
Loading
Loading
Loading
+18 −9
Original line number Diff line number Diff line
@@ -61,12 +61,13 @@ class RAPSEnv(gym.Env):
        self.action_space = spaces.Discrete(max_jobs)

    def _create_engine(self):
        """
        Build a fresh simulation engine and the generator that drives it.

        Side effects:
            - ``self.engine`` is bound to the new engine.
            - ``engine.scheduler.env`` is pointed back at this environment.
            - ``self.jobs`` is set to the workload's job list.
            - ``self.generator`` is set to the value returned by
              ``engine.run_simulation`` over the workload's telemetry window.

        Returns:
            The newly created engine, so callers may also write
            ``self.engine = self._create_engine()``.
        """
        # Construct the engine exactly once so that self.engine, the returned
        # engine and the engine behind self.generator are the same instance.
        engine, workload_data, time_delta = Engine.from_sim_config(self.sim_config)
        engine.scheduler.env = self
        self.engine = engine
        self.jobs = workload_data.jobs
        self.generator = engine.run_simulation(
            self.jobs,
            workload_data.telemetry_start,
            workload_data.telemetry_end,
            time_delta,
        )
        return engine

    def _build_jobs(self):
        """
@@ -116,8 +117,7 @@ class RAPSEnv(gym.Env):
    def reset(self, **kwargs):
        """
        Start a new episode backed by a freshly created engine.

        Args:
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            tuple: ``(obs, info)`` where ``obs`` is the value of
            ``self._get_state()`` taken after the engine is rebuilt and
            ``info`` is an empty dict.
        """
        self.engine = self._create_engine()
        obs = self._get_state()
        info = {}
        return obs, info

    def _compute_reward(self, tick_data):
        """
@@ -145,6 +145,9 @@ class RAPSEnv(gym.Env):
        return reward

    def step(self, action):
        if self.engine is None:
            raise RuntimeError("Engine not initialized. Did you forget to call reset()?")

        queue = self.engine.queue
        invalid_action = False

@@ -153,11 +156,17 @@ class RAPSEnv(gym.Env):
            invalid_action = True
        else:
            job = queue[int(action)]
            available = len(self.engine.scheduler.resource_manager.available_nodes)
            if job.nodes_required <= available:
                # Valid scheduling
            available_nodes = self.engine.scheduler.resource_manager.available_nodes

            if job.nodes_required <= len(available_nodes):
                # Just pick the first available node (simplest placement policy)
                node_id = available_nodes[0]
                self.engine.scheduler.place_job_and_manage_queues(
                    job, queue, self.engine.running, self.engine.current_timestep
                    job,
                    queue,
                    self.engine.running,
                    self.engine.current_timestep,
                    node_id,
                )
            else:
                invalid_action = True
+0 −2
Original line number Diff line number Diff line
@@ -24,7 +24,6 @@ def train_rl_add_parser(subparsers: SubParsers):

def train_rl(rl_config: SingleSimConfig):
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_checker import check_env
    from raps.envs.raps_env import RAPSEnv

    args_dict = rl_config.get_legacy_args_dict()
@@ -33,7 +32,6 @@ def train_rl(rl_config: SingleSimConfig):
    args_dict['args'] = rl_config.get_legacy_args()

    env = RAPSEnv(rl_config)
    check_env(RAPSEnv(env))

    model = PPO(
        "MlpPolicy",