Commit edf48be9 authored by Yakubov, Sergey's avatar Yakubov, Sergey
Browse files

refactor token refreshing threshold

parent 4529dfaf
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -40,7 +40,7 @@ class IdentityProvider:
        """
        raise NotImplementedError()

    def refresh(self, trans, token):
    def refresh(self, session, token, skip_old_tokens_threshold_days):
        raise NotImplementedError()

    def authenticate(self, provider, trans):
+6 −1
Original line number Diff line number Diff line
@@ -115,7 +115,7 @@ class OIDCAuthnzBase(IdentityProvider):
    def _decode_token_no_signature(self, token):
        return jwt.decode(token, audience=self.config.client_id, options={"verify_signature": False})

    def refresh(self, sa_session, custos_authnz_token):
    def refresh(self, sa_session, custos_authnz_token, skip_old_tokens_threshold_days):
        if custos_authnz_token is None:
            raise exceptions.AuthenticationFailed("cannot find authorized user while refreshing token")
        id_token_decoded = self._decode_token_no_signature(custos_authnz_token.id_token)
@@ -123,6 +123,11 @@ class OIDCAuthnzBase(IdentityProvider):
        if int(id_token_decoded["iat"]) + int(id_token_decoded["exp"]) > 2 * int(time.time()):
            return False

        # do not refresh tokens if last token is too old
        skip_old_tokens_threshold_seconds = skip_old_tokens_threshold_days * 86400  # 86400 seconds in a day
        if int(id_token_decoded["iat"]) + skip_old_tokens_threshold_seconds < int(time.time()):
            return False

        oauth2_session = self._create_oauth2_session()
        token_endpoint = self.config.token_endpoint
        if self.config.iam_client_secret:
+4 −5
Original line number Diff line number Diff line
@@ -361,22 +361,21 @@ class AuthnzManager:
                msg = f"An error occurred when getting backend for `{auth.provider}` identity provider: {message}"
                log.error(msg)
                return False
            backend.refresh(sa_session, auth)
            backend.refresh(sa_session, auth, skip_old_tokens_threshold_days = 30)
            return True
        except Exception:
            log.exception("An error occurred when refreshing user token")
            return False

    def refresh_expiring_oidc_tokens(self, sa_session):
        # Galaxy starts multiple RefreshOIDCTokensTask (one for each handler and workes). Until we found a better way
        # to deal with it, we check the server name here and only run refresh for one worker.
        if (self.app.config.server_name != self.app.config.base_server_name
                and self.app.config.server_name != f"{self.app.config.base_server_name}.1"):
            return
        user_filter = datetime.now() - timedelta(days=90)

        all_users = sa_session.scalars(select(model.User)).all()
        for user in all_users:
            if not user.galaxy_sessions or user.current_galaxy_session.update_time < user_filter:
                log.debug(f"skipping token refresh for user {user.username}")
                continue
            for auth in user.custos_auth or []:
                self.refresh_expiring_oidc_tokens_for_provider(sa_session, auth)
            for auth in user.social_auth or []:
+8 −1
Original line number Diff line number Diff line
@@ -176,7 +176,7 @@ class PSAAuthnz(IdentityProvider):
        extra_data["expires"] = int(expires - time.time())
        user_authnz_token.set_extra_data(extra_data)

    def refresh(self, sa_session, user_authnz_token):
    def refresh(self, sa_session, user_authnz_token, skip_old_tokens_threshold_days):
        if not user_authnz_token or not user_authnz_token.extra_data:
            return False
        # refresh tokens if they reached their half lifetime
@@ -187,6 +187,13 @@ class PSAAuthnz(IdentityProvider):
        else:
            log.debug("No `expires` or `expires_in` key found in token extra data, cannot refresh")
            return False

        # do not refresh tokens if last token is too old
        skip_old_tokens_threshold_seconds = skip_old_tokens_threshold_days * 86400  # 86400 seconds in a day
        if int(user_authnz_token.extra_data["auth_time"]) + skip_old_tokens_threshold_seconds < int(time.time()):
            return False


        if int(user_authnz_token.extra_data["auth_time"]) + int(expires) / 2 <= int(time.time()):
            on_the_fly_config(sa_session)
            log.debug(f"Refreshing user token for {user_authnz_token.uid} via `{user_authnz_token.provider}` identity provider")