Fix CORS headers not set on exceptions

This commit is contained in:
Niels Dewulf
2023-11-28 15:49:33 +01:00
parent bbd085bd39
commit a0f2647541
3 changed files with 84 additions and 3 deletions

View File

@@ -16,6 +16,7 @@ from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import SpecMiddleware
from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.server_error import ServerErrorMiddleware
from connexion.middleware.lifespan import Lifespan, LifespanMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
@@ -92,6 +93,17 @@ class _Options:
class MiddlewarePosition(enum.Enum):
"""Positions to insert a middleware"""
BEFORE_EXCEPTION = ExceptionMiddleware
"""Add before the :class:`ExceptionMiddleware`. This is useful if you want your changes to
affect the way exceptions are handled, such as a custom error handler.
Be mindful that security has not yet been applied at this stage.
Additionally, the inserted middleware is positioned before the RoutingMiddleware, so you cannot
leverage any routing information yet and should implement your middleware to work globally
instead of on an operation level.
Usefull for CORS middleware which should be applied before the exception middleware.
"""
BEFORE_SWAGGER = SwaggerUIMiddleware
"""Add before the :class:`SwaggerUIMiddleware`. This is useful if you want your changes to
affect the Swagger UI, such as a path altering middleware that should also alter the paths
@@ -164,6 +176,7 @@ class ConnexionMiddleware:
provided application."""
default_middlewares = [
ServerErrorMiddleware,
ExceptionMiddleware,
SwaggerUIMiddleware,
RoutingMiddleware,

View File

@@ -0,0 +1,68 @@
import asyncio
import functools
import logging
import typing as t
from starlette.concurrency import run_in_threadpool
from starlette.middleware.errors import (
ServerErrorMiddleware as StarletteServerErrorMiddleware,
)
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.exceptions import InternalServerError
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.types import MaybeAwaitable
logger = logging.getLogger(__name__)
def connexion_wrapper(
handler: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
]
) -> t.Callable[[StarletteRequest, Exception], t.Awaitable[StarletteResponse]]:
"""Wrapper that translates Starlette requests to Connexion requests before passing
them to the error handler, and translates the returned Connexion responses to
Starlette responses."""
@functools.wraps(handler)
async def wrapper(request: StarletteRequest, exc: Exception) -> StarletteResponse:
request = ConnexionRequest.from_starlette_request(request)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc) # type: ignore
else:
response = await run_in_threadpool(handler, request, exc)
while asyncio.iscoroutine(response):
response = await response
return StarletteResponse(
content=response.body,
status_code=response.status_code,
media_type=response.mimetype,
headers=response.headers,
)
return wrapper
class ServerErrorMiddleware(StarletteServerErrorMiddleware):
"""Subclass of starlette ServerErrorMiddleware to change handling of Unhandled Server
exceptions to existing connexion behavior."""
def __init__(self, next_app: ASGIApp):
super().__init__(next_app)
@staticmethod
def error_response(
_request: StarletteRequest, exc: Exception
) -> ConnexionResponse:
"""Default handler for any unhandled Exception"""
logger.error("%r", exc, exc_info=exc)
return InternalServerError().to_problem()
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await super().__call__(scope, receive, send)

View File

@@ -28,7 +28,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
@@ -62,7 +62,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
@@ -96,7 +96,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],