Commit ceb90f65 authored by Duggan, John's avatar Duggan, John
Browse files

Fix mock tool code

parent acd65488
Loading
Loading
Loading
Loading
Loading
+1 −5
Original line number Diff line number Diff line
@@ -63,11 +63,7 @@ class IPSFastranTool(BasicTool):
        return self.tool, tool_params

    def prepare_superfacility(self) -> Tuple[Tool, Parameters]:
        self.tool = SuperfacilityTool(
            self.model.resource_params.token_url,
            self.model.resource_params.client_id,
            self.model.resource_params.private_key,
        )
        self.tool = SuperfacilityTool(self.model)

        return self.tool, Parameters()

+1 −0
Original line number Diff line number Diff line
@@ -56,6 +56,7 @@ class ResourceParameters(BaseModel):
    # Superfacility API Options
    client_id: str = Field(default="", title="Superfacility API Client ID")
    private_key: str = Field(default="", title="Superfacility API Private Key")
    project_id: str = Field(default="", title="NERSC Project ID")
    token_url: str = Field(default="https://oidc.nersc.gov/c2id/token")


+170 −21
Original line number Diff line number Diff line
"""Tool class for running via the Superfacility API."""

from time import time
import json
from time import sleep, time
from typing import Any, Dict

from authlib.integrations.requests_client import OAuth2Session
from authlib.oauth2.rfc7523 import PrivateKeyJWT
from nova.galaxy.job import JobStatus
from nova.galaxy.job import JobStatus, WorkState

from ips_fastran_gui.app.models.main_model import MainModel

# Pull new access token 60 seconds before expiration
REFRESH_BUFFER = 60

# Check task status every 10 seconds
STATUS_INTERVAL = 10


class SuperfacilityTool:
    """Tool class for running via the Superfacility API."""

    def __init__(self, token_url: str, client_id: str, private_key: str) -> None:
    def __init__(self, model: MainModel) -> None:
        self.model = model

        self.access_token_obj: Dict[str, Any] = {}
        self.session: OAuth2Session = OAuth2Session(
            self.model.resource_params.client_id,
            self.model.resource_params.private_key,
            PrivateKeyJWT(self.model.resource_params.token_url),
            grant_type="client_credentials",
            token_endpoint=self.model.resource_params.token_url,
        )
        self.state = JobStatus()
        self.job_id = ""
        self.task_id = ""

        self.token_url = token_url
        self.client_id = client_id
        self.private_key = private_key
        self.stdout = ""
        self.stderr = ""
        self.last_status_check = 0
        self.last_stdout_check = 0
        self.last_stderr_check = 0

    def get_access_token(self) -> str:
    def _refresh_token(self) -> None:
        expiration = self.access_token_obj.get("expires_at", 0) - REFRESH_BUFFER
        current_time = int(time())
        if not self.access_token_obj or current_time > expiration:
            session = OAuth2Session(
                self.client_id,
                self.private_key,
                PrivateKeyJWT(self.token_url),
                grant_type="client_credentials",
                token_endpoint=self.token_url,
            )
            self.access_token_obj = session.fetch_token()
            self.access_token_obj = self.session.fetch_token()

    def cancel(self) -> None:
        if not self.job_id:
            return

        self.state.state = WorkState.CANCELING

        self._refresh_token()
        self.session.delete(f"https://api.nersc.gov/api/v1.2/compute/jobs/perlmutter/{self.job_id}")

        return self.access_token_obj.get("access_token", "")
        self.state.state = WorkState.CANCELED

    def get_full_status(self) -> JobStatus:
        _ = self.get_access_token()
        current_time = int(time())
        if not self.task_id or current_time < self.last_status_check + STATUS_INTERVAL:
            return self.state
        self.last_status_check = current_time

        if self.state.state == WorkState.CANCELING:
            pass
        elif not self.job_id:
            self.get_submission_status()
        else:
            self.get_job_status()

        return self.state

    def get_job_status(self) -> None:
        self._refresh_token()
        response = self.session.get(f"https://api.nersc.gov/api/v1.2/compute/jobs/perlmutter/{self.job_id}?cached=0")
        data = response.json()
        output = data.get("output", [])

        if not output:
            self.state.state = WorkState.FINISHED
            return

        if output[0]["state"] == "PENDING":
            return
        if output[0]["state"] == "RUNNING":
            self.state.state = WorkState.RUNNING

    def get_stderr(self, *args: Any, **kwargs: Any) -> str:
        current_time = int(time())
        if self.state.state != WorkState.RUNNING or current_time < self.last_stderr_check + STATUS_INTERVAL:
            return self.stderr
        self.last_stderr_check = current_time

        new_content = self.read_file("ips.err")
        return_value = new_content.removeprefix(self.stderr)

        self.stderr = new_content

        return JobStatus()
        return return_value

    def get_stdout(self, *args: Any, **kwargs: Any) -> str:
        current_time = int(time())
        if self.state.state != WorkState.RUNNING or current_time < self.last_stdout_check + STATUS_INTERVAL:
            return self.stdout
        self.last_stdout_check = current_time

        new_content = self.read_file("ips.out")
        return_value = new_content.removeprefix(self.stdout)

        self.stdout = new_content

        return return_value

    def get_submission_status(self) -> None:
        self._refresh_token()
        response = self.session.get(f"https://api.nersc.gov/api/v1.2/tasks/{self.task_id}")

        data = response.json()
        match data["status"]:
            case "new":
                self.state.state = WorkState.QUEUED
            case "completed":
                result = json.loads(data["result"])
                if result["status"] == "error":
                    self.state.state = WorkState.ERROR
                    self.state.details = {"message": result["error"]}
                elif result["status"] == "ok":
                    self.job_id = result["jobid"]
                else:
                    raise Exception("Completed task has unexpected status:", result["status"])
            case _:
                raise Exception("Unexpected submission task status:", data["status"])

    def read_file(self, path: str) -> str:
        self._refresh_token()
        response = self.session.post(
            "https://api.nersc.gov/api/v1.2/utilities/command/perlmutter", data={"executable": f"cat {path}"}
        )
        task = response.json()
        task_id = task["task_id"]

        complete = False
        data = {}
        while not complete:
            self._refresh_token()
            response = self.session.get(f"https://api.nersc.gov/api/v1.2/tasks/{task_id}")
            data = response.json()

            complete = data["status"] == "completed"
            if not complete:
                sleep(0.1)

        result = json.loads(data["result"])
        if result["status"] == "ok":
            return result["output"]
        return result["error"]

    def run(self, *args: Any, **kwargs: Any) -> None:
        hours, minutes = divmod(self.model.resource_params.time_limit, 60)

        self._refresh_token()
        response = self.session.post(
            "https://api.nersc.gov/api/v1.2/compute/jobs/perlmutter",
            data={
                "isPath": False,
                "job": f"""#!/bin/bash -l
#SBATCH -p {self.model.resource_params.partition}
#SBATCH -N {self.model.resource_params.number_of_nodes}
#SBATCH -t {hours:02d}:{minutes:02d}:00
#SBATCH -A {self.model.resource_params.project_id}
#SBATCH -J ips_fastran
#SBATCH -e ips.err
#SBATCH -o ips.out
#SBATCH -C cpu

module load python
source activate /global/common/software/atom/perlmutter/cesol/conda/dev

export SHOT_NUMBER={self.model.config.shot_number}
export TIME_ID={self.model.config.time_id}

ips.py --simulation=fastran_scenario.config --platform=perlmutter_cpu_node.conf --log=ips.log 1>ips.out 2>ips.err

conda deactivate
""",
            },
        )

    # def run(self) -> None:
    #     access_token = self.get_access_token()
        result = response.json()
        if result["status"] != "OK":
            raise Exception(result["error"])

    #     print("running with access token", access_token)
        self.task_id = result["task_id"]
+11 −6
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@

from nova.trame.view.components import InputField
from nova.trame.view.layouts import GridLayout, VBoxLayout
from trame.widgets import html


class ResourcesTab:
@@ -26,11 +27,15 @@ class ResourcesTab:
            InputField(v_model="resource_params.partition")

        with GridLayout(
            v_if="resource_params.run_location.includes('Superfacility')",
            columns=2,
            gap="0.5em",
            stretch=True,
            valign="start",
            v_if="resource_params.run_location.includes('Superfacility')", columns=2, gap="0.5em", stretch=True
        ):
            InputField(v_model="resource_params.client_id", classes="align-self-start")
            with VBoxLayout(classes="align-self-start"):
                InputField(v_model="resource_params.client_id")
                InputField(v_model="resource_params.project_id")
                html.A(
                    "How Do I Get Superfacility API Credentials?",
                    classes="pl-2",
                    href="https://docs.nersc.gov/services/sfapi/authentication/#client",
                    target="_blank",
                )
            InputField(v_model="resource_params.private_key", type="textarea")