Commit acd65488 authored by Duggan, John's avatar Duggan, John
Browse files

Restructure UI and set up access token exchange

parent 1a3fac87
Loading
Loading
Loading
Loading
Loading
+334 −285

File changed.

Preview size limit exceeded, changes collapsed.

+1 −1
Original line number Diff line number Diff line
@@ -28,7 +28,7 @@ trame-datagrid = ">=0.2.2"
trame-code = ">=1.0.2"
nova-galaxy = ">=0.11.1"
netcdf4 = ">=1.7.2"

Authlib = "*"

[tool.pixi.feature.dev.pypi-dependencies]
mypy = ">=1.10.0"
+21 −4
Original line number Diff line number Diff line
@@ -8,7 +8,8 @@ from typing import Callable, List, Tuple
from nova.galaxy import Connection, Dataset, Parameters, Tool
from nova.galaxy.interfaces import BasicTool

from ips_fastran_gui.app.models.main_model import MainModel
from ips_fastran_gui.app.models.main_model import MainModel, RunLocationOption
from ips_fastran_gui.app.models.superfacility import SuperfacilityTool

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -25,10 +26,17 @@ class IPSFastranTool(BasicTool):
        self.store = connection.get_data_store("ips_fastran")

    def prepare_tool(self) -> Tuple[Tool, Parameters]:
        # TODO: This method needs to be rewritten to support the SF API for use when not connected to Galaxy.
        # Need to talk with others about how they want to authenticate (access token provided or generated) before
        # doing this.
        match self.model.resource_params.run_location:
            case RunLocationOption.local:
                return None, None
            case RunLocationOption.sf_perlmutter:
                return self.prepare_superfacility()
            case RunLocationOption.galaxy_perlmutter:
                return self.prepare_galaxy()

        return None, None

    def prepare_galaxy(self) -> Tuple[Tool, Parameters]:
        # Prepare file ingestion into Galaxy
        self.inputs_dataset = Dataset(name="inputs.zip")
        with zipfile.ZipFile("input.zip", "w") as zip_obj:
@@ -54,6 +62,15 @@ class IPSFastranTool(BasicTool):

        return self.tool, tool_params

    def prepare_superfacility(self) -> Tuple[Tool, Parameters]:
        self.tool = SuperfacilityTool(
            self.model.resource_params.token_url,
            self.model.resource_params.client_id,
            self.model.resource_params.private_key,
        )

        return self.tool, Parameters()

    def get_output_paths(self) -> List[str]:
        outputs = self.tool.get_results()
        collection = outputs.get_collection("outputs")
+21 −1
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@

import json
import zipfile
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List
@@ -29,15 +30,34 @@ class Config(BaseModel):
    result_files: List[str] = Field(default=[])


class RunLocationOption(str, Enum):
    """Defines available locations for running IPS Fastran."""

    local = "Local Machine"
    galaxy_perlmutter = "Perlmutter (via Galaxy)"  # Perlmutter via the Galaxy API
    sf_perlmutter = "Perlmutter (via Superfacility API)"  # Perlmutter via the Superfacility API


class ResourceParameters(BaseModel):
    """Contains resource parameters for running the job."""

    access_token: str = Field(default="", title="Superfacility API Access Token")
    # General Options
    run_location: RunLocationOption = Field(default=RunLocationOption.local, title="Where Will You Run IPS Fastran?")

    # Local Run Options
    executable: str = Field(default="", title="IPS Fastran Executable (use full path)")

    # Perlmutter Options
    partition: str = Field(default="debug", title="Partition")
    number_of_nodes: int = Field(default=1, ge=1, le=2, title="Number of Nodes")
    time_limit: int = Field(default=10, ge=1, le=10080, title="Time Limit (minutes)")
    tasks_per_node: int = Field(default=1, ge=0, le=128, title="Number of Tasks Per Node (0 to disable)")

    # Superfacility API Options
    client_id: str = Field(default="", title="Superfacility API Client ID")
    private_key: str = Field(default="", title="Superfacility API Private Key")
    token_url: str = Field(default="https://oidc.nersc.gov/c2id/token")


class PlotJSON(BaseModel):
    """Contains the plot.json file contents."""
+47 −0
Original line number Diff line number Diff line
"""Tool class for running via the Superfacility API."""

from time import time
from typing import Any, Dict

from authlib.integrations.requests_client import OAuth2Session
from authlib.oauth2.rfc7523 import PrivateKeyJWT
from nova.galaxy.job import JobStatus

# Pull new access token 60 seconds before expiration
REFRESH_BUFFER = 60


class SuperfacilityTool:
    """Tool class for running via the Superfacility API."""

    def __init__(self, token_url: str, client_id: str, private_key: str) -> None:
        self.access_token_obj: Dict[str, Any] = {}

        self.token_url = token_url
        self.client_id = client_id
        self.private_key = private_key

    def get_access_token(self) -> str:
        expiration = self.access_token_obj.get("expires_at", 0) - REFRESH_BUFFER
        current_time = int(time())
        if not self.access_token_obj or current_time > expiration:
            session = OAuth2Session(
                self.client_id,
                self.private_key,
                PrivateKeyJWT(self.token_url),
                grant_type="client_credentials",
                token_endpoint=self.token_url,
            )
            self.access_token_obj = session.fetch_token()

        return self.access_token_obj.get("access_token", "")

    def get_full_status(self) -> JobStatus:
        _ = self.get_access_token()

        return JobStatus()

    # def run(self) -> None:
    #     access_token = self.get_access_token()

    #     print("running with access token", access_token)
Loading