Unverified Commit 942ddea3 authored by Marius van den Beek's avatar Marius van den Beek Committed by GitHub
Browse files

Merge pull request #18255 from davelopez/24.0_fix_invenio_credentials_handling

[24.0] Fix Invenio credentials handling
parents 4c8b4bf2 7cfba26f
Loading
Loading
Loading
Loading
+1 −6
Original line number Diff line number Diff line
@@ -7,7 +7,6 @@ from typing import (

from typing_extensions import Unpack

from galaxy.exceptions import AuthenticationRequired
from galaxy.files import ProvidesUserFileSourcesUserContext
from galaxy.files.sources import (
    BaseFilesSource,
@@ -193,15 +192,11 @@ class RDMFilesSource(BaseFilesSource):
            effective_props[key] = self._evaluate_prop(val, user_context=user_context)
        return effective_props

    def get_authorization_token(self, user_context: OptionalUserContext) -> str:
    def get_authorization_token(self, user_context: OptionalUserContext) -> Optional[str]:
        token = None
        if user_context:
            effective_props = self._serialization_props(user_context)
            token = effective_props.get("token")
        if not token:
            raise AuthenticationRequired(
                f"Please provide a personal access token in your user's preferences for '{self.label}'"
            )
        return token

    def get_public_name(self, user_context: OptionalUserContext) -> Optional[str]:
+18 −13
Original line number Diff line number Diff line
@@ -217,12 +217,7 @@ class InvenioRepositoryInteractor(RDMRepositoryInteractor):
            },
        }

        headers = self._get_request_headers(user_context)
        if "Authorization" not in headers:
            raise Exception(
                "Cannot create record without authentication token. Please set your personal access token in your Galaxy preferences."
            )

        headers = self._get_request_headers(user_context, auth_required=True)
        response = requests.post(self.records_url, json=create_record_request, headers=headers)
        self._ensure_response_has_expected_status_code(response, 201)
        record = response.json()
@@ -238,7 +233,7 @@ class InvenioRepositoryInteractor(RDMRepositoryInteractor):
    ):
        record = self._get_draft_record(record_id, user_context=user_context)
        upload_file_url = record["links"]["files"]
        headers = self._get_request_headers(user_context)
        headers = self._get_request_headers(user_context, auth_required=True)

        # Add file metadata entry
        response = requests.post(upload_file_url, json=[{"key": filename}], headers=headers)
@@ -394,28 +389,38 @@ class InvenioRepositoryInteractor(RDMRepositoryInteractor):
        }

    def _get_response(
        self, user_context: OptionalUserContext, request_url: str, params: Optional[Dict[str, Any]] = None
        self,
        user_context: OptionalUserContext,
        request_url: str,
        params: Optional[Dict[str, Any]] = None,
        auth_required: bool = False,
    ) -> dict:
        headers = self._get_request_headers(user_context)
        headers = self._get_request_headers(user_context, auth_required)
        response = requests.get(request_url, params=params, headers=headers)
        self._ensure_response_has_expected_status_code(response, 200)
        return response.json()

    def _get_request_headers(self, user_context: OptionalUserContext):
    def _get_request_headers(self, user_context: OptionalUserContext, auth_required: bool = False):
        token = self.plugin.get_authorization_token(user_context)
        headers = {"Authorization": f"Bearer {token}"} if token else {}
        if auth_required and token is None:
            self._raise_auth_required()
        return headers

    def _ensure_response_has_expected_status_code(self, response, expected_status_code: int):
        if response.status_code == 403:
            record_url = response.url.replace("/api", "").replace("/files", "")
            raise AuthenticationRequired(f"Please make sure you have the necessary permissions to access: {record_url}")
        if response.status_code != expected_status_code:
            if response.status_code == 403:
                self._raise_auth_required()
            error_message = self._get_response_error_message(response)
            raise Exception(
                f"Request to {response.url} failed with status code {response.status_code}: {error_message}"
            )

    def _raise_auth_required(self):
        raise AuthenticationRequired(
            f"Please provide a personal access token in your user's preferences for '{self.plugin.label}'"
        )

    def _get_response_error_message(self, response):
        response_json = response.json()
        error_message = response_json.get("message") if response.status_code == 400 else response.text