Commit 96e55dfb authored by Price, Zach's avatar Price, Zach
Browse files

Implement authentication

parent e0a5759d
Loading
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -80,6 +80,7 @@ deploy:
    --atomic
    --set config.db.username=$PGDB_USER
    --set config.db.password=$PGDB_PASS
    --set config.auth.oidc_secret=$OIDC_SECRET
    --set ingress.url=api-$CI_BUILD_REF_SLUG.k8s.arm.gov
    --set dataService.guc.loadBalancerIP=""

@@ -126,5 +127,6 @@ deploy:prod:
    --atomic
    --set config.db.username=$PGDB_USER
    --set config.db.password=$PGDB_PASS
    --set config.auth.oidc_secret=$OIDC_SECRET
    --set ingress.url=api.k8s.arm.gov
    # --set ingress.className=external

app/ADL/api/auth.py

0 → 100644
+42 −0
Original line number Diff line number Diff line
import time
from logging import getLogger

import jwt
import requests
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import (HTTPAuthorizationCredentials, HTTPBasic,
                              HTTPBearer)
from jwt import PyJWKClient
from jwt.exceptions import PyJWTError

from . import config

log = getLogger(__name__)


class JWTBearer(HTTPBearer):
    def __init__(self, auto_error: bool = True):
        super(JWTBearer, self).__init__(auto_error=auto_error)

    async def __call__(self, request: Request):
        credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
        if credentials:
            if not credentials.scheme == "Bearer":
                raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid authentication scheme.")

            try:
                openid_config = requests.get(f'{str(config.oidc_url)}/.well-known/openid-configuration').json()
                jwks_client = PyJWKClient(openid_config['jwks_uri'])
                signing_key = jwks_client.get_signing_key_from_jwt(credentials.credentials)
                auth_data = jwt.decode(
                    credentials.credentials,
                    signing_key.key,
                    algorithms=[config.jwt_algorithm],
                    audience=['ADL'],
                )
            except PyJWTError as e:
                log.exception('Unable to validate token')
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unable to validate token.")

JWTAuth = Depends(JWTBearer())
BasicAuth = Depends(HTTPBasic())
+7 −2
Original line number Diff line number Diff line
import sys
from logging import getLogger

from pydantic import (BaseSettings, DirectoryPath, FilePath, PostgresDsn,
                      SecretStr)
from pydantic import (AnyHttpUrl, BaseSettings, DirectoryPath, FilePath,
                      PostgresDsn, SecretStr)
from sqlalchemy.engine.url import URL

logger = getLogger(__name__)
@@ -24,6 +24,11 @@ class Settings(BaseSettings):
    db_host: str = "localhost"
    db_port: int = 5432
    db_name: str = "arm_all"
    jwt_algorithm: str = 'RS256'
    oidc_url: AnyHttpUrl = 'https://dex.k8s.arm.gov'
    oidc_client: str = 'ADL'
    oidc_secret: SecretStr = 'client_secret'
    host_url: AnyHttpUrl = 'https://api.k8s.arm.gov'
    # TODO: break this out into separate models so pieces can be overridden
    logging: dict = {
        "version": 1,
+11 −14
Original line number Diff line number Diff line
@@ -2,8 +2,9 @@ from datetime import datetime

import databases
import sqlalchemy
from ormar import Integer, Model, String, DateTime, Boolean, ForeignKey
from ormar import Boolean, DateTime, ForeignKey, Integer, Model, String
from ormar.decorators import property_field
from pydantic import AnyHttpUrl

from . import config

@@ -57,19 +58,6 @@ class Files(ARMPath):
        site = self.datastream[3:]
        return f'/{site}/{self.datastream}/{self.versioned_filename}'

# patched_field = ForeignKey(Files, related_name='contents')
# patched_field.primary_key=True

# class FileContents(Model):
#     class Meta:
#         tablename = 'file_contents'
#         metadata = dr_meta
#         database=database
#     versioned_filename: Files = patched_field
#     start_time: datetime = DateTime()
#     end_time: datetime = DateTime()
#     deleted: bool = Boolean()
#     n_samples: int = Integer()

class ARMFiles(Model):
    class Meta:
@@ -82,3 +70,12 @@ class ARMFiles(Model):
    md5_checksum: str = String(max_length=32, min_length=32, regex=r'[a-fA-F\d]+')
    start_time: datetime = DateTime()
    end_time: datetime = DateTime()

    @property_field
    def name(self) -> str:
        site = self.datastream[:3]
        return f'{site}/{self.datastream}/{self.versioned_filename}'

    @property_field
    def url(self) -> AnyHttpUrl:
        return f'{config.host_url}/fs/stream/{self.name}'
+35 −4
Original line number Diff line number Diff line
from collections.abc import Collection
from datetime import datetime
from logging import getLogger
from typing import List, Union
from collections.abc import Collection

import fsspec
import requests
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRouter
from fastapi_versioning import versioned_api_route
from starlette.types import Send

from . import config
from .auth import BasicAuth, JWTAuth
from .models import ARMFiles, Datastreams, Files, Sites

router = APIRouter(route_class=versioned_api_route(1))
@@ -66,7 +68,7 @@ class FsspecStreamingResponse(StreamingResponse):


@router.get('/fs/stream/{site}/{datastream}/{filename}', response_class=StreamingResponse)
async def stream(site: Sites.site_code, datastream: Datastreams.datastream, filename: Files.versioned_filename):
async def stream(site: Sites.site_code, datastream: Datastreams.datastream, filename: Files.versioned_filename, auth=JWTAuth):

    url = f'simplecache::guc:///f1/arm/{site}/{datastream}/{filename}'
    log.debug('streaming %s', url)
@@ -90,7 +92,7 @@ async def stream(site: Sites.site_code, datastream: Datastreams.datastream, file


@router.get('/fs/file_list/')
async def file_list(datastreams: Datastreams.datastream, start_time: datetime, end_time: datetime, include_deleted: bool=False):
async def file_list(datastreams: Datastreams.datastream, start_time: datetime, end_time: datetime):
    if not isinstance(datastreams, Collection):
        datastreams = [datastreams]

@@ -99,3 +101,32 @@ async def file_list(datastreams: Datastreams.datastream, start_time: datetime, e
        start_time__gte=start_time,
        end_time__lte=end_time,
    ).all()


@router.post('/get_token')
async def get_token(auth=BasicAuth):
    openid_config = requests.get(f'{str(config.oidc_url)}.well-known/openid-configuration').json()
    token_url = openid_config['token_endpoint']
    return requests.post(
        token_url,
        data={
            'grant_type': 'password',
            'username': auth.username,
            'password': auth.password,
            'scope': ' '.join((scope for scope in openid_config['scopes_supported'])),
        },
        auth=(config.oidc_client, config.oidc_secret.get_secret_value())
    ).json()

@router.post('/refresh_token')
async def refresh_token(refresh_token: str):
    openid_config = requests.get(f'{str(config.oidc_url)}.well-known/openid-configuration').json()
    token_url = openid_config['token_endpoint']
    return requests.post(
        token_url,
        data={
            'grant_type': 'refresh_token',
            'refresh_token': refresh_token,
        },
        auth=(config.oidc_client, config.oidc_secret.get_secret_value())
    ).json()
Loading