Unverified Commit 35e2f759 authored by David López's avatar David López Committed by GitHub
Browse files

Merge pull request #14202 from mvdbeek/empty_response_middleware

[22.05] Add ``SuppressNoResponseReturnedMiddleware``
parents 533a33f9 84ea1076
Loading
Loading
Loading
Loading
+22 −6
Original line number Diff line number Diff line
from fastapi import (
    FastAPI,
    Request,
    status,
)
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.middleware.base import (
    BaseHTTPMiddleware,
    RequestResponseEndpoint,
)
from starlette.responses import Response

try:
from starlette_context.middleware import RawContextMiddleware
from starlette_context.plugins import RequestIdPlugin
except ImportError:
    pass

from galaxy.exceptions import MessageException
from galaxy.web.framework.base import walk_controller_modules
@@ -20,6 +21,21 @@ from galaxy.web.framework.decorators import (
)


# Copied from https://stackoverflow.com/questions/71222144/runtimeerror-no-response-returned-in-fastapi-when-refresh-request/72677699#72677699
class SuppressNoResponseReturnedMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        try:
            return await call_next(request)
        except RuntimeError as exc:
            if str(exc) == "No response returned." and await request.is_disconnected():
                return Response(status_code=status.HTTP_204_NO_CONTENT)
            raise


def add_empty_response_middleware(app: FastAPI) -> None:
    app.add_middleware(SuppressNoResponseReturnedMiddleware)


def add_exception_handler(app: FastAPI) -> None:
    @app.exception_handler(RequestValidationError)
    async def validate_exception_middleware(request: Request, exc: RequestValidationError) -> Response:
+2 −0
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ from starlette.responses import (

from galaxy.version import VERSION
from galaxy.webapps.base.api import (
    add_empty_response_middleware,
    add_exception_handler,
    add_request_id_middleware,
    include_all_package_routers,
@@ -176,6 +177,7 @@ def initialize_fast_app(gx_wsgi_webapp, gx_app):
    wsgi_handler = WSGIMiddleware(gx_wsgi_webapp)
    gx_app.haltables.append(("WSGI Middleware threadpool", wsgi_handler.executor.shutdown))
    app.mount("/", wsgi_handler)
    add_empty_response_middleware(app)
    if gx_app.config.galaxy_url_prefix != "/":
        parent_app = FastAPI()
        parent_app.mount(gx_app.config.galaxy_url_prefix, app=app)
+2 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ from a2wsgi import WSGIMiddleware
from fastapi import FastAPI

from galaxy.webapps.base.api import (
    add_empty_response_middleware,
    add_exception_handler,
    add_request_id_middleware,
    include_all_package_routers,
@@ -22,4 +23,5 @@ def initialize_fast_app(gx_webapp):
    include_all_package_routers(app, "galaxy.webapps.reports.api")
    wsgi_handler = WSGIMiddleware(gx_webapp)
    app.mount("/", wsgi_handler)
    add_empty_response_middleware(app)
    return app
+2 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ from a2wsgi import WSGIMiddleware
from fastapi import FastAPI

from galaxy.webapps.base.api import (
    add_empty_response_middleware,
    add_exception_handler,
    add_request_id_middleware,
    include_all_package_routers,
@@ -20,6 +21,7 @@ def initialize_fast_app(gx_webapp, tool_shed_app):
    wsgi_handler = WSGIMiddleware(gx_webapp)
    tool_shed_app.haltables.append(("WSGI Middleware threadpool", wsgi_handler.executor.shutdown))
    app.mount("/", wsgi_handler)
    add_empty_response_middleware(app)
    return app


+119 −0
Original line number Diff line number Diff line
import asyncio
import contextlib
import threading
import time
from typing import Optional

import pytest
import requests
import uvicorn
from fastapi import status
from fastapi.applications import FastAPI
from requests import ReadTimeout
from starlette.middleware.base import BaseHTTPMiddleware

from galaxy.util import sockets
from galaxy.webapps.base.api import add_empty_response_middleware

error_encountered: Optional[str] = None
error_handled = False


@pytest.fixture()
def reset_global_vars():
    global error_encountered
    global error_handled
    error_encountered = None
    error_handled = False


class SomeMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            global error_encountered
            error_encountered = str(e)
            raise


class OuterMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        response = await call_next(request)
        assert response.status_code == status.HTTP_204_NO_CONTENT
        global error_handled
        error_handled = True
        return response


class Server(uvicorn.Server):
    """Uvicorn server in a thread.

    https://stackoverflow.com/a/66589593
    """

    def install_signal_handlers(self):
        pass

    @contextlib.contextmanager
    def run_in_thread(self):
        thread = threading.Thread(target=self.run)
        thread.start()
        try:
            while not self.started:
                time.sleep(1e-3)
            yield
        finally:
            self.should_exit = True
            thread.join()


def setup_fastAPI(add_middleware=True):
    app = FastAPI()
    # Looks weird, but we need at least 2 middlewares based on BaseHTTPMiddleware to trigger this.
    # xref https://github.com/encode/starlette/discussions/1527#discussion-3893922
    app.add_middleware(SomeMiddleware)
    app.add_middleware(SomeMiddleware)
    if add_middleware:
        add_empty_response_middleware(app)
    app.add_middleware(OuterMiddleware)

    @app.get("/")
    async def index():
        await asyncio.sleep(1)
        return

    return app


def test_client_disconnect_with_middleware(reset_global_vars):
    app = setup_fastAPI()
    port = sockets.unused_port()
    server = Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port))
    with server.run_in_thread():
        try:
            requests.get(f"http://127.0.0.1:{port}/", timeout=0.1)
        except ReadTimeout:
            pass

    assert error_encountered == "No response returned."
    assert error_handled


def test_client_disconnect_raises_error_without_middleware(reset_global_vars):
    app = setup_fastAPI(add_middleware=False)
    port = sockets.unused_port()
    server = Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port))
    with server.run_in_thread():
        try:
            requests.get(f"http://127.0.0.1:{port}/", timeout=0.1)
        except ReadTimeout:
            pass

    assert error_encountered == "No response returned."
    try:
        assert not error_handled
    except AssertionError:
        raise Exception(
            "add_empty_response_middleware not required anymore, bug likely fixed upstream. You can revert https://github.com/galaxyproject/galaxy/pull/14202"
        )