Loading .gitlab-ci.yml +2 −0 Original line number Diff line number Diff line Loading @@ -80,6 +80,7 @@ deploy: --atomic --set config.db.username=$PGDB_USER --set config.db.password=$PGDB_PASS --set config.auth.oidc_secret=$OIDC_SECRET --set ingress.url=api-$CI_BUILD_REF_SLUG.k8s.arm.gov --set dataService.guc.loadBalancerIP="" Loading Loading @@ -126,5 +127,6 @@ deploy:prod: --atomic --set config.db.username=$PGDB_USER --set config.db.password=$PGDB_PASS --set config.auth.oidc_secret=$OIDC_SECRET --set ingress.url=api.k8s.arm.gov # --set ingress.className=external app/ADL/api/auth.py 0 → 100644 +42 −0 Original line number Diff line number Diff line import time from logging import getLogger import jwt import requests from fastapi import Depends, HTTPException, Request, status from fastapi.security import (HTTPAuthorizationCredentials, HTTPBasic, HTTPBearer) from jwt import PyJWKClient from jwt.exceptions import PyJWTError from . import config log = getLogger(__name__) class JWTBearer(HTTPBearer): def __init__(self, auto_error: bool = True): super(JWTBearer, self).__init__(auto_error=auto_error) async def __call__(self, request: Request): credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request) if credentials: if not credentials.scheme == "Bearer": raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid authentication scheme.") try: openid_config = requests.get(f'{str(config.oidc_url)}/.well-known/openid-configuration').json() jwks_client = PyJWKClient(openid_config['jwks_uri']) signing_key = jwks_client.get_signing_key_from_jwt(credentials.credentials) auth_data = jwt.decode( credentials.credentials, signing_key.key, algorithms=[config.jwt_algorithm], audience=['ADL'], ) except PyJWTError as e: log.exception('Unable to validate token') raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unable to validate token.") JWTAuth = Depends(JWTBearer()) BasicAuth = Depends(HTTPBasic()) app/ADL/api/config.py +7 −2 Original line number Diff line number Diff line import sys from logging import getLogger from pydantic import (BaseSettings, DirectoryPath, FilePath, PostgresDsn, SecretStr) from pydantic import (AnyHttpUrl, BaseSettings, DirectoryPath, FilePath, PostgresDsn, SecretStr) from sqlalchemy.engine.url import URL logger = getLogger(__name__) Loading @@ -24,6 +24,11 @@ class Settings(BaseSettings): db_host: str = "localhost" db_port: int = 5432 db_name: str = "arm_all" jwt_algorithm: str = 'RS256' oidc_url: AnyHttpUrl = 'https://dex.k8s.arm.gov' oidc_client: str = 'ADL' oidc_secret: SecretStr = 'client_secret' host_url: AnyHttpUrl = 'https://api.k8s.arm.gov' # TODO: break this out into separate models so pieces can be overridden logging: dict = { "version": 1, Loading app/ADL/api/models.py +11 −14 Original line number Diff line number Diff line Loading @@ -2,8 +2,9 @@ from datetime import datetime import databases import sqlalchemy from ormar import Integer, Model, String, DateTime, Boolean, ForeignKey from ormar import Boolean, DateTime, ForeignKey, Integer, Model, String from ormar.decorators import property_field from pydantic import AnyHttpUrl from . import config Loading Loading @@ -57,19 +58,6 @@ class Files(ARMPath): site = self.datastream[3:] return f'/{site}/{self.datastream}/{self.versioned_filename}' # patched_field = ForeignKey(Files, related_name='contents') # patched_field.primary_key=True # class FileContents(Model): # class Meta: # tablename = 'file_contents' # metadata = dr_meta # database=database # versioned_filename: Files = patched_field # start_time: datetime = DateTime() # end_time: datetime = DateTime() # deleted: bool = Boolean() # n_samples: int = Integer() class ARMFiles(Model): class Meta: Loading @@ -82,3 +70,12 @@ class ARMFiles(Model): md5_checksum: str = String(max_length=32, min_length=32, regex=r'[a-fA-F\d]+') start_time: datetime = DateTime() end_time: datetime = DateTime() @property_field def name(self) -> str: site = self.datastream[:3] return f'{site}/{self.datastream}/{self.versioned_filename}' @property_field def url(self) -> AnyHttpUrl: return f'{config.host_url}/fs/stream/{self.name}' app/ADL/api/v1.py +35 −4 Original line number Diff line number Diff line from collections.abc import Collection from datetime import datetime from logging import getLogger from typing import List, Union from collections.abc import Collection import fsspec import requests from fastapi.responses import StreamingResponse from fastapi.routing import APIRouter from fastapi_versioning import versioned_api_route from starlette.types import Send from . import config from .auth import BasicAuth, JWTAuth from .models import ARMFiles, Datastreams, Files, Sites router = APIRouter(route_class=versioned_api_route(1)) Loading Loading @@ -66,7 +68,7 @@ class FsspecStreamingResponse(StreamingResponse): @router.get('/fs/stream/{site}/{datastream}/{filename}', response_class=StreamingResponse) async def stream(site: Sites.site_code, datastream: Datastreams.datastream, filename: Files.versioned_filename): async def stream(site: Sites.site_code, datastream: Datastreams.datastream, filename: Files.versioned_filename, auth=JWTAuth): url = f'simplecache::guc:///f1/arm/{site}/{datastream}/{filename}' log.debug('streaming %s', url) Loading @@ -90,7 +92,7 @@ async def stream(site: Sites.site_code, datastream: Datastreams.datastream, file @router.get('/fs/file_list/') async def file_list(datastreams: Datastreams.datastream, start_time: datetime, end_time: datetime, include_deleted: bool=False): async def file_list(datastreams: Datastreams.datastream, start_time: datetime, end_time: datetime): if not isinstance(datastreams, Collection): datastreams = [datastreams] Loading @@ -99,3 +101,32 @@ async def file_list(datastreams: Datastreams.datastream, start_time: datetime, e start_time__gte=start_time, end_time__lte=end_time, ).all() @router.post('/get_token') async def get_token(auth=BasicAuth): openid_config = requests.get(f'{str(config.oidc_url)}.well-known/openid-configuration').json() token_url = openid_config['token_endpoint'] return requests.post( token_url, data={ 'grant_type': 'password', 'username': auth.username, 'password': auth.password, 'scope': ' '.join((scope for scope in openid_config['scopes_supported'])), }, auth=(config.oidc_client, config.oidc_secret.get_secret_value()) ).json() @router.post('/refresh_token') async def refresh_token(refresh_token: str): openid_config = requests.get(f'{str(config.oidc_url)}.well-known/openid-configuration').json() token_url = openid_config['token_endpoint'] return requests.post( token_url, data={ 'grant_type': 'refresh_token', 'refresh_token': refresh_token, }, auth=(config.oidc_client, config.oidc_secret.get_secret_value()) ).json() Loading
.gitlab-ci.yml +2 −0 Original line number Diff line number Diff line Loading @@ -80,6 +80,7 @@ deploy: --atomic --set config.db.username=$PGDB_USER --set config.db.password=$PGDB_PASS --set config.auth.oidc_secret=$OIDC_SECRET --set ingress.url=api-$CI_BUILD_REF_SLUG.k8s.arm.gov --set dataService.guc.loadBalancerIP="" Loading Loading @@ -126,5 +127,6 @@ deploy:prod: --atomic --set config.db.username=$PGDB_USER --set config.db.password=$PGDB_PASS --set config.auth.oidc_secret=$OIDC_SECRET --set ingress.url=api.k8s.arm.gov # --set ingress.className=external
app/ADL/api/auth.py 0 → 100644 +42 −0 Original line number Diff line number Diff line import time from logging import getLogger import jwt import requests from fastapi import Depends, HTTPException, Request, status from fastapi.security import (HTTPAuthorizationCredentials, HTTPBasic, HTTPBearer) from jwt import PyJWKClient from jwt.exceptions import PyJWTError from . import config log = getLogger(__name__) class JWTBearer(HTTPBearer): def __init__(self, auto_error: bool = True): super(JWTBearer, self).__init__(auto_error=auto_error) async def __call__(self, request: Request): credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request) if credentials: if not credentials.scheme == "Bearer": raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid authentication scheme.") try: openid_config = requests.get(f'{str(config.oidc_url)}/.well-known/openid-configuration').json() jwks_client = PyJWKClient(openid_config['jwks_uri']) signing_key = jwks_client.get_signing_key_from_jwt(credentials.credentials) auth_data = jwt.decode( credentials.credentials, signing_key.key, algorithms=[config.jwt_algorithm], audience=['ADL'], ) except PyJWTError as e: log.exception('Unable to validate token') raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unable to validate token.") JWTAuth = Depends(JWTBearer()) BasicAuth = Depends(HTTPBasic())
app/ADL/api/config.py +7 −2 Original line number Diff line number Diff line import sys from logging import getLogger from pydantic import (BaseSettings, DirectoryPath, FilePath, PostgresDsn, SecretStr) from pydantic import (AnyHttpUrl, BaseSettings, DirectoryPath, FilePath, PostgresDsn, SecretStr) from sqlalchemy.engine.url import URL logger = getLogger(__name__) Loading @@ -24,6 +24,11 @@ class Settings(BaseSettings): db_host: str = "localhost" db_port: int = 5432 db_name: str = "arm_all" jwt_algorithm: str = 'RS256' oidc_url: AnyHttpUrl = 'https://dex.k8s.arm.gov' oidc_client: str = 'ADL' oidc_secret: SecretStr = 'client_secret' host_url: AnyHttpUrl = 'https://api.k8s.arm.gov' # TODO: break this out into separate models so pieces can be overridden logging: dict = { "version": 1, Loading
app/ADL/api/models.py +11 −14 Original line number Diff line number Diff line Loading @@ -2,8 +2,9 @@ from datetime import datetime import databases import sqlalchemy from ormar import Integer, Model, String, DateTime, Boolean, ForeignKey from ormar import Boolean, DateTime, ForeignKey, Integer, Model, String from ormar.decorators import property_field from pydantic import AnyHttpUrl from . import config Loading Loading @@ -57,19 +58,6 @@ class Files(ARMPath): site = self.datastream[3:] return f'/{site}/{self.datastream}/{self.versioned_filename}' # patched_field = ForeignKey(Files, related_name='contents') # patched_field.primary_key=True # class FileContents(Model): # class Meta: # tablename = 'file_contents' # metadata = dr_meta # database=database # versioned_filename: Files = patched_field # start_time: datetime = DateTime() # end_time: datetime = DateTime() # deleted: bool = Boolean() # n_samples: int = Integer() class ARMFiles(Model): class Meta: Loading @@ -82,3 +70,12 @@ class ARMFiles(Model): md5_checksum: str = String(max_length=32, min_length=32, regex=r'[a-fA-F\d]+') start_time: datetime = DateTime() end_time: datetime = DateTime() @property_field def name(self) -> str: site = self.datastream[:3] return f'{site}/{self.datastream}/{self.versioned_filename}' @property_field def url(self) -> AnyHttpUrl: return f'{config.host_url}/fs/stream/{self.name}'
app/ADL/api/v1.py +35 −4 Original line number Diff line number Diff line from collections.abc import Collection from datetime import datetime from logging import getLogger from typing import List, Union from collections.abc import Collection import fsspec import requests from fastapi.responses import StreamingResponse from fastapi.routing import APIRouter from fastapi_versioning import versioned_api_route from starlette.types import Send from . import config from .auth import BasicAuth, JWTAuth from .models import ARMFiles, Datastreams, Files, Sites router = APIRouter(route_class=versioned_api_route(1)) Loading Loading @@ -66,7 +68,7 @@ class FsspecStreamingResponse(StreamingResponse): @router.get('/fs/stream/{site}/{datastream}/{filename}', response_class=StreamingResponse) async def stream(site: Sites.site_code, datastream: Datastreams.datastream, filename: Files.versioned_filename): async def stream(site: Sites.site_code, datastream: Datastreams.datastream, filename: Files.versioned_filename, auth=JWTAuth): url = f'simplecache::guc:///f1/arm/{site}/{datastream}/{filename}' log.debug('streaming %s', url) Loading @@ -90,7 +92,7 @@ async def stream(site: Sites.site_code, datastream: Datastreams.datastream, file @router.get('/fs/file_list/') async def file_list(datastreams: Datastreams.datastream, start_time: datetime, end_time: datetime, include_deleted: bool=False): async def file_list(datastreams: Datastreams.datastream, start_time: datetime, end_time: datetime): if not isinstance(datastreams, Collection): datastreams = [datastreams] Loading @@ -99,3 +101,32 @@ async def file_list(datastreams: Datastreams.datastream, start_time: datetime, e start_time__gte=start_time, end_time__lte=end_time, ).all() @router.post('/get_token') async def get_token(auth=BasicAuth): openid_config = requests.get(f'{str(config.oidc_url)}.well-known/openid-configuration').json() token_url = openid_config['token_endpoint'] return requests.post( token_url, data={ 'grant_type': 'password', 'username': auth.username, 'password': auth.password, 'scope': ' '.join((scope for scope in openid_config['scopes_supported'])), }, auth=(config.oidc_client, config.oidc_secret.get_secret_value()) ).json() @router.post('/refresh_token') async def refresh_token(refresh_token: str): openid_config = requests.get(f'{str(config.oidc_url)}.well-known/openid-configuration').json() token_url = openid_config['token_endpoint'] return requests.post( token_url, data={ 'grant_type': 'refresh_token', 'refresh_token': refresh_token, }, auth=(config.oidc_client, config.oidc_secret.get_secret_value()) ).json()