Commit dc033a9f authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

move OIDC token refresh to a separate thread

parent ec1e881e
Loading
Loading
Loading
Loading
Loading
+31 −23
Original line number Diff line number Diff line
@@ -793,22 +793,6 @@
:Type: seq


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``user_config_templates_index_by``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:Description:
    Configure URIs for user object stores to use either the object ID
    ('id') or UUIDs ('uuid'). Either is fine really, Galaxy doesn't
    typically expose database objects by 'id' but there isn't any
    obvious disadvantage to doing it in this case and it keeps user
    exposed URIs much smaller. The default of UUID feels a little more
    like a typical way to do this within Galaxy though. Do not change
    this value once user object stores have been created.
:Default: ``uuid``
:Type: str


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``user_config_templates_use_saved_configuration``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -3758,25 +3742,39 @@
:Type: bool


~~~~~~~~~~~~~~~~~~~~~~~~~~~
``show_welcome_with_login``
~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
``hide_sign_out``
~~~~~~~~~~~~~~~~~

:Description:
    Show the site's welcome page (see welcome_url) alongside the login
    page (even if require_login is true).
    Hide the sign out link in the user menu (Useful if only third
    party authentication is enabled, and users should only sign out
    using the  authentication provider's sign out page.)
:Default: ``false``
:Type: bool

~~~~~~~~~~~~~~~~~~~~~~~~~~~

~~~~~~~~~~~~~~~~~~~~~~~~~~
``disable_internal_login``
~~~~~~~~~~~~~~~~~~~~~~~~~~

:Description:
    Hides internal Galaxy login form fields
:Default: ``false``
:Type: bool


~~~~~~~~~~~~~~~~~~~~~~~~~~~
``show_welcome_with_login``
~~~~~~~~~~~~~~~~~~~~~~~~~~~

:Description:
    Hides internal Galaxy login form fields.
    Show the site's welcome page (see welcome_url) alongside the login
    page (even if require_login is true).
:Default: ``false``
:Type: bool


~~~~~~~~~~~~~~~~~~~~~~~
``prefer_custos_login``
~~~~~~~~~~~~~~~~~~~~~~~
@@ -4169,6 +4167,16 @@
:Type: str


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``oidc_refresh_tokens_interval``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:Description:
    The interval in seconds between calls to refresh OIDC tokens
:Default: ``60``
:Type: int


~~~~~~~~~~~~~~~~~~~~
``auth_config_file``
~~~~~~~~~~~~~~~~~~~~
+11 −0
Original line number Diff line number Diff line
@@ -797,6 +797,17 @@ class UniverseApplication(StructuredApp, GalaxyManagerApplication):
                self, self.config.oidc_config_file, self.config.oidc_backends_config_file
            )

            self.refresh_oidc_tokens_task = IntervalTask(
                func=lambda: self.authnz_manager.refresh_expiring_oidc_tokens(self.model.session),
                name="RefreshOIDCTokensTask",
                interval=self.config.oidc_refresh_tokens_interval,
                immediate_start=True,
                time_execution=True,
            )
            self.application_stack.register_postfork_function(self.refresh_oidc_tokens_task.start)
            self.haltables.append(("RefreshOIDCTokensTask", self.refresh_oidc_tokens_task.shutdown))


            # If there is only a single external authentication provider in use
            # TODO: Future work will expand on this and provide an interface for
            # multiple auth providers allowing explicit authenticated association.
+21 −5
Original line number Diff line number Diff line
@@ -115,13 +115,14 @@ class OIDCAuthnzBase(IdentityProvider):
    def _decode_token_no_signature(self, token):
        return jwt.decode(token, audience=self.config.client_id, options={"verify_signature": False})

    def refresh(self, trans, custos_authnz_token):
    def refresh(self, sa_session, custos_authnz_token):
        if custos_authnz_token is None:
            raise exceptions.AuthenticationFailed("cannot find authorized user while refreshing token")
        id_token_decoded = self._decode_token_no_signature(custos_authnz_token.id_token)
        # do not refresh tokens if they didn't reach their half lifetime
        if int(id_token_decoded["iat"]) + int(id_token_decoded["exp"]) > 2 * int(time.time()):
            return False

        oauth2_session = self._create_oauth2_session()
        token_endpoint = self.config.token_endpoint
        if self.config.iam_client_secret:
@@ -136,7 +137,7 @@ class OIDCAuthnzBase(IdentityProvider):
        }

        token = oauth2_session.refresh_token(token_endpoint, **params)
        processed_token = self._process_token(trans, oauth2_session, token, False)
        processed_token = self._process_token_after_refresh(token)

        custos_authnz_token.access_token = processed_token["access_token"]
        if "id_token" in processed_token:
@@ -147,9 +148,12 @@ class OIDCAuthnzBase(IdentityProvider):
        custos_authnz_token.expiration_time = processed_token["expiration_time"]
        custos_authnz_token.refresh_expiration_time = processed_token["refresh_expiration_time"]

        trans.sa_session.add(custos_authnz_token)
        with transaction(trans.sa_session):
            trans.sa_session.commit()
        sa_session.add(custos_authnz_token)
        with transaction(sa_session):
            sa_session.commit()

        log.debug(f"Refreshed user token for {custos_authnz_token.external_user_id} via `{custos_authnz_token.provider}` identity provider")

        return True

    def _get_provider_specific_scopes(self):
@@ -182,6 +186,18 @@ class OIDCAuthnzBase(IdentityProvider):
        trans.set_cookie(value=nonce, name=NONCE_COOKIE_NAME)
        return authorization_url

    def _process_token_after_refresh(self, token):
        processed_token = {}
        processed_token["access_token"] = token["access_token"]
        processed_token["id_token"] = token["id_token"]
        processed_token["refresh_token"] = token["refresh_token"] if "refresh_token" in token else None
        processed_token["expiration_time"] = datetime.now() + timedelta(seconds=token.get("expires_in", 3600))
        processed_token["refresh_expiration_time"] = (
            (datetime.now() + timedelta(seconds=token["refresh_expires_in"])) if "refresh_expires_in" in token else None
        )
        return processed_token


    def _process_token(self, trans, oauth2_session, token, validate_nonce=True):
        processed_token = {}
        processed_token["access_token"] = token["access_token"]
+13 −13
Original line number Diff line number Diff line
import builtins
import copy
from datetime import datetime, timedelta
import json
import logging
import os
import random
import string
from sqlalchemy import select

from cloudauthz import CloudAuthz
from cloudauthz.exceptions import CloudAuthzBaseException
@@ -352,29 +354,27 @@ class AuthnzManager:
            raise exceptions.ItemAccessibilityException(msg)
        return qres

    def refresh_expiring_oidc_tokens_for_provider(self, trans, auth):
    def refresh_expiring_oidc_tokens_for_provider(self, sa_session, auth):
        try:
            success, message, backend = self._get_authnz_backend(auth.provider)
            if success is False:
                msg = f"An error occurred when refreshing user token on `{auth.provider}` identity provider: {message}"
                msg = f"An error occurred when getting backend for `{auth.provider}` identity provider: {message}"
                log.error(msg)
                return False
            refreshed = backend.refresh(trans, auth)
            if refreshed:
                log.debug(f"Refreshed user token via `{auth.provider}` identity provider")
            backend.refresh(sa_session, auth)
            return True
        except Exception:
            log.exception("An error occurred when refreshing user token")
            return False

    def refresh_expiring_oidc_tokens(self, trans, user=None):
        user = trans.user or user
        if not isinstance(user, model.User):
            return
    def refresh_expiring_oidc_tokens(self, sa_session):
        user_filter = datetime.now() - timedelta(days=7)
        all_users = sa_session.scalars(select(model.User).filter(model.User.update_time < user_filter)).all()
        for user in all_users:
            for auth in user.custos_auth or []:
            self.refresh_expiring_oidc_tokens_for_provider(trans, auth)
                self.refresh_expiring_oidc_tokens_for_provider(sa_session, auth)
            for auth in user.social_auth or []:
            self.refresh_expiring_oidc_tokens_for_provider(trans, auth)
                self.refresh_expiring_oidc_tokens_for_provider(sa_session, auth)

    def authenticate(self, provider, trans, idphint=None):
        """
+7 −3
Original line number Diff line number Diff line
@@ -176,7 +176,7 @@ class PSAAuthnz(IdentityProvider):
        extra_data["expires"] = int(expires - time.time())
        user_authnz_token.set_extra_data(extra_data)

    def refresh(self, trans, user_authnz_token):
    def refresh(self, sa_session, user_authnz_token):
        if not user_authnz_token or not user_authnz_token.extra_data:
            return False
        # refresh tokens if they reached their half lifetime
@@ -188,13 +188,17 @@ class PSAAuthnz(IdentityProvider):
            log.debug("No `expires` or `expires_in` key found in token extra data, cannot refresh")
            return False
        if int(user_authnz_token.extra_data["auth_time"]) + int(expires) / 2 <= int(time.time()):
            on_the_fly_config(trans.sa_session)
            on_the_fly_config(sa_session)
            if self.config["provider"] == "azure":
                self.refresh_azure(user_authnz_token)
            else:
                strategy = Strategy(trans.request, trans.session, Storage, self.config)
                strategy = Strategy(None, sa_session, Storage, self.config)
                user_authnz_token.refresh_token(strategy)
            log.debug(
                f"Refreshed user token for {user_authnz_token.uid} via `{user_authnz_token.provider}` identity provider")

            return True

        return False

    def authenticate(self, trans):
Loading