Commit fc70a6f4 authored by Cage, Gregory's avatar Cage, Gregory
Browse files

Merge branch '112-psa_authnz_token_decode' into 'dev'

Add jwt decode to psa auth manager and rework scope checking

Closes #112

See merge request !86
parents 4408e794 74b30ec7
Loading
Loading
Loading
Loading
Loading
+9 −5
Original line number Diff line number Diff line
@@ -545,10 +545,10 @@ class OIDCAuthnzBase(IdentityProvider):
                options={
                    "verify_signature": True,
                    "verify_exp": True,
                    "verify_nbf": True,
                    "verify_iat": True,
                    "verify_aud": bool(self.config.accepted_audiences),
                    "verify_iss": True,
                    "verify_nbf": False,
                    "verify_iat": False,
                    "verify_aud": False,
                    "verify_iss": False,
                },
            )
        except jwt.exceptions.PyJWKClientError:
@@ -559,7 +559,11 @@ class OIDCAuthnzBase(IdentityProvider):
            # All other exceptions are bubbled up
            return None, None
        # jwt verified, we can now fetch the user
        try:
            user_id = decoded_jwt["sub"]
        except:
            user_id = decoded_jwt["subject"]

        custos_authnz_token = self._get_custos_authnz_token(sa_session, user_id, self.config.provider)
        user = custos_authnz_token.user if custos_authnz_token else None
        return user, decoded_jwt
+14 −6
Original line number Diff line number Diff line
@@ -183,6 +183,10 @@ class AuthnzManager:
            rtv["checkin_env"] = config_xml.find("checkin_env").text
        if config_xml.find("alias") is not None:
            rtv["alias"] = config_xml.find("alias").text
        if config_xml.find("well_known_oidc_config_uri") is not None:
            rtv["well_known_oidc_config_uri"] = config_xml.find("well_known_oidc_config_uri").text
        if config_xml.find("required_scope") is not None:
            rtv["required_scope"] = config_xml.find("required_scope").text

        return rtv

@@ -216,6 +220,8 @@ class AuthnzManager:
            rtv["user_extra_authorization_script"] = config_xml.find("user_extra_authorization_script").text
        if config_xml.find("accepted_audiences") is not None:
            rtv["accepted_audiences"] = config_xml.find("accepted_audiences").text
        if config_xml.find("required_scope") is not None:
            rtv["required_scope"] = config_xml.find("required_scope").text
        return rtv

    def get_allowed_idps(self):
@@ -403,6 +409,11 @@ class AuthnzManager:
            log.exception(msg)
            return False, msg, None

    def _validate_permissions(self, user, jwt, provider):
        # Get required scope if provided in config, else use the configured scope prefix
        required_scopes = [f"{self.oidc_backends_config[provider].get('required_scope', f'{self.app.config.oidc_scope_prefix}:*')}"]
        self._assert_jwt_contains_scopes(user, jwt, required_scopes)

    def callback(self, provider, state_token, authz_code, trans, login_redirect_url, idphint=None):
        try:
            success, message, backend = self._get_authnz_backend(provider, idphint=idphint)
@@ -435,16 +446,13 @@ class AuthnzManager:
            raise exceptions.AuthenticationFailed(
                err_msg=f"User: {user.username} does not have the required scopes: [{required_scopes}]"
            )
        scopes = jwt.get("scope") or ""
        scopes = f"{jwt.get('scope')} {jwt.get('scp')}" or ""

        if not set(required_scopes).issubset(scopes.split(" ")):
            raise exceptions.AuthenticationFailed(
                err_msg=f"User: {user.username} has JWT with scopes: [{scopes}] but not required scopes: [{required_scopes}]"
            )

    def _validate_permissions(self, user, jwt):
        required_scopes = [f"{self.app.config.oidc_scope_prefix}:*"]
        self._assert_jwt_contains_scopes(user, jwt, required_scopes)

    def _match_access_token_to_user_in_provider(self, sa_session, provider, access_token):
        try:
            success, message, backend = self._get_authnz_backend(provider)
@@ -459,7 +467,7 @@ class AuthnzManager:
                log.exception("Could not decode access token")
                raise exceptions.AuthenticationFailed(err_msg="Invalid access token or an unexpected error occurred.")
            if user and jwt:
                self._validate_permissions(user, jwt)
                self._validate_permissions(user, jwt, provider)
                return user
            elif not user and jwt:
                # jwt was decoded, but no user could be matched
+86 −2
Original line number Diff line number Diff line
@@ -56,6 +56,8 @@ BACKENDS_NAME = {
    "egi_checkin": "egi-checkin",
}

AZURE_USERINFO_ENDPOINT = "https://graph.microsoft.com/oidc/userinfo"

AUTH_PIPELINE = (
    # Get the information we can about the user and return it in a simple
    # format to create the user instance later. On some cases the details are
@@ -141,6 +143,8 @@ class PSAAuthnz(IdentityProvider):
            self.config[setting_name("API_URL")] = oidc_backend_config.get("api_url")
        if oidc_backend_config.get("url") is not None:
            self.config[setting_name("URL")] = oidc_backend_config.get("url")
        if oidc_backend_config.get("well_known_oidc_config_uri") is not None:
            self.config["well_known_oidc_config_uri"] = oidc_backend_config.get("well_known_oidc_config_uri")

    def _get_helper(self, name, do_import=False):
        this_config = self.config.get(setting_name(name), DEFAULTS.get(name, None))
@@ -236,6 +240,86 @@ class PSAAuthnz(IdentityProvider):
            return True, "", response
        return response.get("success", False), response.get("message", ""), ""

    def decode_user_access_token(self, sa_session, access_token):
        """Verifies and decodes an access token against this provider, returning the user and
        a dict containing the decoded token data.

        :type  sa_session:      sqlalchemy.orm.scoping.scoped_session
        :param sa_session:      SQLAlchemy database handle.

        :type  access_token: string
        :param access_token: An OIDC access token

        :return: A tuple containing the user and decoded jwt data or [None, None]
                 if the access token does not belong to this provider.
        :rtype: Tuple[User, dict]
        """
        well_known_oidc_config_uri = self.config["well_known_oidc_config_uri"] if self.config.get(
            "well_known_oidc_config_uri", None) else self._get_well_known_uri_from_url(self.config["provider"])
        well_known_oidc_config = None
        try:
            well_known_oidc_config = requests.get(
                well_known_oidc_config_uri,
                headers={},
                verify=True,
                params={},
            ).json()
        except Exception:
            log.error(f"Failed to load well-known OIDC config URI: {well_known_oidc_config_uri}")
            raise
        jwks_client = jwt.PyJWKClient(well_known_oidc_config["jwks_uri"])

        try:
            signing_key = jwks_client.get_signing_key_from_jwt(access_token)
            accepted_aud = self.config.get("accepted_audiences", None)
            headers = jwt.get_unverified_header(access_token)
            verify_signature = True
            if headers.get("nonce", None) and self.config["provider"] == "azure":
                # Tokens with Nonce in header are not supposed to be verified
                verify_signature = False
                r = requests.get(AZURE_USERINFO_ENDPOINT, headers={"Authorization": f"Bearer {access_token}"})
                r.raise_for_status()

            decoded_jwt = jwt.decode(
                access_token,
                signing_key.key,
                algorithms=["RS256"],
                issuer=well_known_oidc_config["issuer"],
                audience=accepted_aud,
                options={
                    "verify_signature": verify_signature,
                    "verify_exp": True,
                    "verify_nbf": True,
                    "verify_iat": True,
                    "verify_aud": bool(accepted_aud),
                    "verify_iss": True,
                },
            )
        except jwt.exceptions.PyJWKClientError:
            log.debug(
                f"Could not get signing keys for access token with provider: {self.config['provider']}. Ignoring...")
            return None, None
        except jwt.exceptions.InvalidIssuerError:
            # An Invalid issuer means that the access token is not relevant to this provider.
            # All other exceptions are bubbled up
            return None, None
        # jwt verified, we can now fetch the user
        user_id = decoded_jwt["unique_name"]
        authnz_token = self._get_authnz_token(sa_session, user_id, self.config["provider"])
        user = authnz_token.user if authnz_token else None
        return user, decoded_jwt

    @staticmethod
    def _get_authnz_token(sa_session, user_id, provider):
        return sa_session.query(UserAuthnzToken).filter_by(uid=user_id).one_or_none()

    def _get_well_known_uri_from_url(self, provider):
        # TODO: Look up this URL from a Python library
        base_url = self.config["SOCIAL_AUTH_URL"]
        # Remove potential trailing slash to avoid "//realms"
        base_url = base_url if base_url[-1] != "/" else base_url[:-1]
        return f"{base_url}/.well-known/openid-configuration"


class Strategy(BaseStrategy):
    def __init__(self, request, session, storage, config, tpl=None):
+7 −0
Original line number Diff line number Diff line
@@ -149,6 +149,13 @@
                                    </xs:documentation>
                                </xs:annotation>
                            </xs:element>
                            <xs:element name="required_scope" minOccurs="0" type="xs:string">
                                <xs:annotation>
                                    <xs:documentation>
                                        Specifies scope to be used for authorization for this provider.
                                    </xs:documentation>
                                </xs:annotation>
                            </xs:element>
                        </xs:all>
                        <xs:attribute name="name" type="xs:string" use="required">
                            <xs:annotation>