Commit 6dc3ffa2 authored by Hines, Jesse's avatar Hines, Jesse
Browse files

Add skeleton for dataset download

parent c95115b2
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ def main(cli_args: list[str] | None = None):

    from raps.run_sim import run_sim_add_parser, run_parts_sim_add_parser, show_add_parser
    from raps.workloads import run_workload_add_parser
    from raps.telemetry import run_telemetry_add_parser
    from raps.telemetry import run_telemetry_add_parser, run_download_add_parser
    from raps.train_rl import train_rl_add_parser

    parser = argparse.ArgumentParser(
@@ -85,6 +85,7 @@ def main(cli_args: list[str] | None = None):
    show_add_parser(subparsers)
    run_workload_add_parser(subparsers)
    run_telemetry_add_parser(subparsers)
    run_download_add_parser(subparsers)
    train_rl_add_parser(subparsers)
    shell_completion_add_parser(subparsers)

+30 −1
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ helper functions for data encryption and conversion between node name and index
from typing import Literal
import random
from pathlib import Path
# import json
from datetime import datetime
from typing import Optional
from types import ModuleType
import importlib
@@ -21,6 +21,7 @@ from pydantic import model_validator
from raps.sim_config import SimConfig
from raps.system_config import get_system_config
from raps.job import Job, job_dict
from raps.utils import AutoAwareDatetime
import matplotlib.pyplot as plt
from raps.plotting import (
    plot_jobs_gantt,
@@ -183,6 +184,13 @@ class Telemetry:
        assert self.dataloader
        return self.dataloader.load_live_data(**self.kwargs)

    def download_data(self, dest: Path, start: datetime | None, end: datetime | None):
        """Load telemetry data using custom data loaders."""
        assert self.dataloader
        if not hasattr(self.dataloader, "download"):
            raise ValueError("Dataloader does not support download")
        return self.dataloader.download(dest, start, end)

    def node_index_to_name(self, index: int):
        """ Convert node index into a name"""
        assert self.dataloader
@@ -359,3 +367,24 @@ def run_telemetry(args: TelemetryArgs):
        print(f"Saved to: {filename}")
    else:
        plt.show()


class DownloadArgs(RAPSBaseModel):
    system: str
    dest: ResolvedPath
    start: AutoAwareDatetime | None = None
    end: AutoAwareDatetime | None = None


def run_download_add_parser(subparsers: SubParsers):
    parser = subparsers.add_parser("download", description="""
        Download telemetry data
    """)
    model_validate = pydantic_add_args(parser, DownloadArgs)
    parser.set_defaults(impl=lambda args: run_download(model_validate(args, {})))


def run_download(args: DownloadArgs):
    config = get_system_config(args.system).get_legacy()
    td = Telemetry(system = args.system, config = config)
    td.download_data(args.dest, args.start, args.end)