Unverified Commit f6c8c636 authored by Marius van den Beek's avatar Marius van den Beek Committed by GitHub
Browse files

Merge pull request #18629 from mvdbeek/fix_install_model_session_scope

[24.0] Close install model session when request ends
parents 42829a5c d20b3331
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -104,6 +104,7 @@ class WebApplication:
        self.transaction_factory = DefaultWebTransaction
        # Set if trace logging is enabled
        self.trace_logger = None
        self.session_factories = []

    def add_ui_controller(self, controller_name, controller):
        """
@@ -170,10 +171,12 @@ class WebApplication:
        path_info = environ.get("PATH_INFO", "")

        try:
            self._model.set_request_id(request_id)  # Start SQLAlchemy session scope
            for session_factory in self.session_factories:
                session_factory.set_request_id(request_id)  # Start SQLAlchemy session scope
            return self.handle_request(request_id, path_info, environ, start_response)
        finally:
            self._model.unset_request_id(request_id)  # End SQLAlchemy session scope
            for session_factory in self.session_factories:
                session_factory.unset_request_id(request_id)  # End SQLAlchemy session scope
            self.trace(message="Handle request finished")
            if self.trace_logger:
                self.trace_logger.context_remove("request_id")
+1 −1
Original line number Diff line number Diff line
@@ -120,7 +120,7 @@ class WebApplication(base.WebApplication):

        # We need this to set the REQUEST_ID contextvar in model.base *BEFORE* a GalaxyWebTransaction is created.
        # This will ensure a SQLAlchemy session is request-scoped for legacy (non-fastapi) endpoints.
        self._model = galaxy_app.model
        self.session_factories.append(galaxy_app.model)

    def build_apispec(self):
        """
+6 −3
Original line number Diff line number Diff line
@@ -98,10 +98,12 @@ async def get_app_with_request_session() -> AsyncGenerator[StructuredApp, None]:
    app = get_app()
    request_id = request_context.data["X-Request-ID"]
    app.model.set_request_id(request_id)
    app.install_model.set_request_id(request_id)
    try:
        yield app
    finally:
        app.model.unset_request_id(request_id)
        app.install_model.unset_request_id(request_id)


DependsOnApp = cast(StructuredApp, Depends(get_app_with_request_session))
@@ -118,9 +120,10 @@ class GalaxyTypeDepends(Depends):
        self.galaxy_type_depends = dep_type


def depends(dep_type: Type[T], get_app=get_app) -> T:
    def _do_resolve(request: Request):
        return get_app().resolve(dep_type)
def depends(dep_type: Type[T], app=get_app_with_request_session) -> T:
    async def _do_resolve(request: Request):
        async for _dep in app():
            yield _dep.resolve(dep_type)

    return cast(T, GalaxyTypeDepends(_do_resolve, dep_type))

+8 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import logging
import sys
import threading
import traceback
from typing import Optional
from urllib.parse import urljoin

from paste import httpexceptions
@@ -20,6 +21,7 @@ import galaxy.web.framework
import galaxy.webapps.base.webapp
from galaxy import util
from galaxy.security.validate_user_input import VALID_PUBLICNAME_RE
from galaxy.structured_app import MinimalApp
from galaxy.util import asbool
from galaxy.util.properties import load_app_properties
from galaxy.web.framework.middleware.error import ErrorMiddleware
@@ -34,6 +36,12 @@ log = logging.getLogger(__name__)
class GalaxyWebApplication(galaxy.webapps.base.webapp.WebApplication):
    injection_aware = True

    def __init__(
        self, galaxy_app: MinimalApp, session_cookie: str = "galaxysession", name: Optional[str] = None
    ) -> None:
        super().__init__(galaxy_app, session_cookie, name)
        self.session_factories.append(galaxy_app.install_model)


def app_factory(*args, **kwargs):
    """
+1 −1
Original line number Diff line number Diff line
@@ -81,7 +81,7 @@ api_key_cookie = APIKeyCookie(name=AUTH_COOKIE_NAME, auto_error=False)


def depends(dep_type: Type[T]) -> T:
    return framework_depends(dep_type, get_app=get_app)
    return framework_depends(dep_type, app=get_app_with_request_session)


def get_api_user(