Commit 540ac15b authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

Merge branch '116-refresh-tokens-in-a-separate-thread' into 'dev'

Create a separate task to refresh tokens

Closes #116

See merge request !93
parents ec1e881e 9709c1d8
Loading
Loading
Loading
Loading
Loading
+31 −23
Original line number Diff line number Diff line
@@ -793,22 +793,6 @@
:Type: seq


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``user_config_templates_index_by``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:Description:
    Configure URIs for user object stores to use either the object ID
    ('id') or UUIDs ('uuid'). Either is fine really, Galaxy doesn't
    typically expose database objects by 'id' but there isn't any
    obvious disadvantage to doing it in this case and it keeps user
    exposed URIs much smaller. The default of UUID feels a little more
    like a typical way to do this within Galaxy though. Do not change
    this value once user object stores have been created.
:Default: ``uuid``
:Type: str


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``user_config_templates_use_saved_configuration``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -3758,25 +3742,39 @@
:Type: bool


~~~~~~~~~~~~~~~~~~~~~~~~~~~
``show_welcome_with_login``
~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~
``hide_sign_out``
~~~~~~~~~~~~~~~~~

:Description:
    Show the site's welcome page (see welcome_url) alongside the login
    page (even if require_login is true).
    Hide the sign out link in the user menu (Useful if only third
    party authentication is enabled, and users should only sign out
    using the  authentication provider's sign out page.)
:Default: ``false``
:Type: bool

~~~~~~~~~~~~~~~~~~~~~~~~~~~

~~~~~~~~~~~~~~~~~~~~~~~~~~
``disable_internal_login``
~~~~~~~~~~~~~~~~~~~~~~~~~~

:Description:
    Hides internal Galaxy login form fields
:Default: ``false``
:Type: bool


~~~~~~~~~~~~~~~~~~~~~~~~~~~
``show_welcome_with_login``
~~~~~~~~~~~~~~~~~~~~~~~~~~~

:Description:
    Hides internal Galaxy login form fields.
    Show the site's welcome page (see welcome_url) alongside the login
    page (even if require_login is true).
:Default: ``false``
:Type: bool


~~~~~~~~~~~~~~~~~~~~~~~
``prefer_custos_login``
~~~~~~~~~~~~~~~~~~~~~~~
@@ -4169,6 +4167,16 @@
:Type: str


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``oidc_refresh_tokens_interval``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

:Description:
    The interval in seconds between calls to refresh OIDC tokens
:Default: ``300``
:Type: int


~~~~~~~~~~~~~~~~~~~~
``auth_config_file``
~~~~~~~~~~~~~~~~~~~~
+10 −0
Original line number Diff line number Diff line
@@ -796,6 +796,16 @@ class UniverseApplication(StructuredApp, GalaxyManagerApplication):
            self.authnz_manager = managers.AuthnzManager(
                self, self.config.oidc_config_file, self.config.oidc_backends_config_file
            )
            if self.is_webapp:
                self.refresh_oidc_tokens_task = IntervalTask(
                    func=lambda: self.authnz_manager.refresh_expiring_oidc_tokens(self.model.session),
                    name="RefreshOIDCTokensTask",
                    interval=self.config.oidc_refresh_tokens_interval,
                    immediate_start=True,
                    time_execution=True,
                )
                self.application_stack.register_postfork_function(self.refresh_oidc_tokens_task.start)
                self.haltables.append(("RefreshOIDCTokensTask", self.refresh_oidc_tokens_task.shutdown))

            # If there is only a single external authentication provider in use
            # TODO: Future work will expand on this and provide an interface for
+1 −1
Original line number Diff line number Diff line
@@ -40,7 +40,7 @@ class IdentityProvider:
        """
        raise NotImplementedError()

    def refresh(self, trans, token):
    def refresh(self, session, token, skip_old_tokens_threshold_days):
        raise NotImplementedError()

    def authenticate(self, provider, trans):
+30 −5
Original line number Diff line number Diff line
@@ -115,13 +115,19 @@ class OIDCAuthnzBase(IdentityProvider):
    def _decode_token_no_signature(self, token):
        return jwt.decode(token, audience=self.config.client_id, options={"verify_signature": False})

    def refresh(self, trans, custos_authnz_token):
    def refresh(self, sa_session, custos_authnz_token, skip_old_tokens_threshold_days):
        if custos_authnz_token is None:
            raise exceptions.AuthenticationFailed("cannot find authorized user while refreshing token")
        id_token_decoded = self._decode_token_no_signature(custos_authnz_token.id_token)
        # do not refresh tokens if they didn't reach their half lifetime
        if int(id_token_decoded["iat"]) + int(id_token_decoded["exp"]) > 2 * int(time.time()):
            return False

        # do not refresh tokens if last token is too old
        skip_old_tokens_threshold_seconds = skip_old_tokens_threshold_days * 86400  # 86400 seconds in a day
        if int(id_token_decoded["iat"]) + skip_old_tokens_threshold_seconds < int(time.time()):
            return False

        oauth2_session = self._create_oauth2_session()
        token_endpoint = self.config.token_endpoint
        if self.config.iam_client_secret:
@@ -135,8 +141,11 @@ class OIDCAuthnzBase(IdentityProvider):
            "refresh_token": custos_authnz_token.refresh_token,
        }

        log.debug(
            f"Refreshing user token for {custos_authnz_token.external_user_id} via `{custos_authnz_token.provider}` identity provider"
        )
        token = oauth2_session.refresh_token(token_endpoint, **params)
        processed_token = self._process_token(trans, oauth2_session, token, False)
        processed_token = self._process_token_after_refresh(token)

        custos_authnz_token.access_token = processed_token["access_token"]
        if "id_token" in processed_token:
@@ -147,9 +156,14 @@ class OIDCAuthnzBase(IdentityProvider):
        custos_authnz_token.expiration_time = processed_token["expiration_time"]
        custos_authnz_token.refresh_expiration_time = processed_token["refresh_expiration_time"]

        trans.sa_session.add(custos_authnz_token)
        with transaction(trans.sa_session):
            trans.sa_session.commit()
        sa_session.add(custos_authnz_token)
        with transaction(sa_session):
            sa_session.commit()

        log.debug(
            f"Refreshed user token for {custos_authnz_token.external_user_id} via `{custos_authnz_token.provider}` identity provider"
        )

        return True

    def _get_provider_specific_scopes(self):
@@ -182,6 +196,17 @@ class OIDCAuthnzBase(IdentityProvider):
        trans.set_cookie(value=nonce, name=NONCE_COOKIE_NAME)
        return authorization_url

    def _process_token_after_refresh(self, token):
        processed_token = {}
        processed_token["access_token"] = token["access_token"]
        processed_token["id_token"] = token["id_token"]
        processed_token["refresh_token"] = token["refresh_token"] if "refresh_token" in token else None
        processed_token["expiration_time"] = datetime.now() + timedelta(seconds=token.get("expires_in", 3600))
        processed_token["refresh_expiration_time"] = (
            (datetime.now() + timedelta(seconds=token["refresh_expires_in"])) if "refresh_expires_in" in token else None
        )
        return processed_token

    def _process_token(self, trans, oauth2_session, token, validate_nonce=True):
        processed_token = {}
        processed_token["access_token"] = token["access_token"]
+25 −13
Original line number Diff line number Diff line
@@ -5,9 +5,14 @@ import logging
import os
import random
import string
from datetime import (
    datetime,
    timedelta,
)

from cloudauthz import CloudAuthz
from cloudauthz.exceptions import CloudAuthzBaseException
from sqlalchemy import select

from galaxy import (
    exceptions,
@@ -352,29 +357,34 @@ class AuthnzManager:
            raise exceptions.ItemAccessibilityException(msg)
        return qres

    def refresh_expiring_oidc_tokens_for_provider(self, trans, auth):
    def refresh_expiring_oidc_tokens_for_provider(self, sa_session, auth):
        try:
            success, message, backend = self._get_authnz_backend(auth.provider)
            if success is False:
                msg = f"An error occurred when refreshing user token on `{auth.provider}` identity provider: {message}"
                msg = f"An error occurred when getting backend for `{auth.provider}` identity provider: {message}"
                log.error(msg)
                return False
            refreshed = backend.refresh(trans, auth)
            if refreshed:
                log.debug(f"Refreshed user token via `{auth.provider}` identity provider")
            backend.refresh(sa_session, auth, skip_old_tokens_threshold_days=30)
            return True
        except Exception:
            log.exception("An error occurred when refreshing user token")
            return False

    def refresh_expiring_oidc_tokens(self, trans, user=None):
        user = trans.user or user
        if not isinstance(user, model.User):
    def refresh_expiring_oidc_tokens(self, sa_session):
        # Galaxy starts multiple RefreshOIDCTokensTask (one for each handler and workes). Until we found a better way
        # to deal with it, we check the server name here and only run refresh for one worker.
        if (
            self.app.config.server_name != self.app.config.base_server_name
            and self.app.config.server_name != f"{self.app.config.base_server_name}.1"
        ):
            return

        all_users = sa_session.scalars(select(model.User)).all()
        for user in all_users:
            for auth in user.custos_auth or []:
            self.refresh_expiring_oidc_tokens_for_provider(trans, auth)
                self.refresh_expiring_oidc_tokens_for_provider(sa_session, auth)
            for auth in user.social_auth or []:
            self.refresh_expiring_oidc_tokens_for_provider(trans, auth)
                self.refresh_expiring_oidc_tokens_for_provider(sa_session, auth)

    def authenticate(self, provider, trans, idphint=None):
        """
@@ -411,7 +421,9 @@ class AuthnzManager:

    def _validate_permissions(self, user, jwt, provider):
        # Get required scope if provided in config, else use the configured scope prefix
        required_scopes = [f"{self.oidc_backends_config[provider].get('required_scope', f'{self.app.config.oidc_scope_prefix}:*')}"]
        required_scopes = [
            f"{self.oidc_backends_config[provider].get('required_scope', f'{self.app.config.oidc_scope_prefix}:*')}"
        ]
        self._assert_jwt_contains_scopes(user, jwt, required_scopes)

    def callback(self, provider, state_token, authz_code, trans, login_redirect_url, idphint=None):
Loading