mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 04:19:26 +00:00
Fix CORS headers not set on exceptions (#1821)
Fixes #1820. Correct error handling in response to CORS. Changes proposed in this pull request: - Add a MiddlewarePosition before Exception handling so CORS is always returned - Add ServerError Middleware to handle unhandled errors between the ServerError- and ExceptionMiddleware - Update corresponding docs --------- Co-authored-by: Robbe Sneyders <robbe.sneyders@ml6.eu>
This commit is contained in:
@@ -21,6 +21,7 @@ from connexion.middleware.request_validation import RequestValidationMiddleware
|
||||
from connexion.middleware.response_validation import ResponseValidationMiddleware
|
||||
from connexion.middleware.routing import RoutingMiddleware
|
||||
from connexion.middleware.security import SecurityMiddleware
|
||||
from connexion.middleware.server_error import ServerErrorMiddleware
|
||||
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
||||
from connexion.options import SwaggerUIOptions
|
||||
from connexion.resolver import Resolver
|
||||
@@ -93,6 +94,21 @@ class _Options:
|
||||
class MiddlewarePosition(enum.Enum):
|
||||
"""Positions to insert a middleware"""
|
||||
|
||||
BEFORE_EXCEPTION = ExceptionMiddleware
|
||||
"""Add before the :class:`ExceptionMiddleware`. This is useful if you want your changes to
|
||||
affect the way exceptions are handled, such as a custom error handler.
|
||||
|
||||
Be mindful that security has not yet been applied at this stage.
|
||||
Additionally, the inserted middleware is positioned before the RoutingMiddleware, so you cannot
|
||||
leverage any routing information yet and should implement your middleware to work globally
|
||||
instead of on an operation level.
|
||||
|
||||
Useful for middleware which should also be applied to error responses. Note that errors
|
||||
raised here will not be handled by the exception handlers and will always result in an
|
||||
internal server error response.
|
||||
|
||||
:meta hide-value:
|
||||
"""
|
||||
BEFORE_SWAGGER = SwaggerUIMiddleware
|
||||
"""Add before the :class:`SwaggerUIMiddleware`. This is useful if you want your changes to
|
||||
affect the Swagger UI, such as a path altering middleware that should also alter the paths
|
||||
@@ -165,6 +181,7 @@ class ConnexionMiddleware:
|
||||
provided application."""
|
||||
|
||||
default_middlewares = [
|
||||
ServerErrorMiddleware,
|
||||
ExceptionMiddleware,
|
||||
SwaggerUIMiddleware,
|
||||
RoutingMiddleware,
|
||||
|
||||
36
connexion/middleware/server_error.py
Normal file
36
connexion/middleware/server_error.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from starlette.middleware.errors import (
|
||||
ServerErrorMiddleware as StarletteServerErrorMiddleware,
|
||||
)
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from connexion.exceptions import InternalServerError
|
||||
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
|
||||
from connexion.middleware.exceptions import connexion_wrapper
|
||||
from connexion.types import MaybeAwaitable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServerErrorMiddleware(StarletteServerErrorMiddleware):
|
||||
"""Subclass of starlette ServerErrorMiddleware to change handling of Unhandled Server
|
||||
exceptions to existing connexion behavior."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
next_app: ASGIApp,
|
||||
handler: t.Optional[
|
||||
t.Callable[[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]]
|
||||
] = None,
|
||||
):
|
||||
handler = connexion_wrapper(handler) if handler else None
|
||||
super().__init__(next_app, handler=handler)
|
||||
|
||||
@staticmethod
|
||||
@connexion_wrapper
|
||||
def error_response(_request: ConnexionRequest, exc: Exception) -> ConnexionResponse:
|
||||
"""Default handler for any unhandled Exception"""
|
||||
logger.error("%r", exc, exc_info=exc)
|
||||
return InternalServerError().to_problem()
|
||||
@@ -28,7 +28,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
position=MiddlewarePosition.BEFORE_ROUTING,
|
||||
position=MiddlewarePosition.BEFORE_EXCEPTION,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
@@ -62,7 +62,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
position=MiddlewarePosition.BEFORE_ROUTING,
|
||||
position=MiddlewarePosition.BEFORE_EXCEPTION,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
@@ -96,7 +96,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
position=MiddlewarePosition.BEFORE_ROUTING,
|
||||
position=MiddlewarePosition.BEFORE_EXCEPTION,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
|
||||
@@ -8,6 +8,8 @@ following order:
|
||||
.. csv-table::
|
||||
:widths: 30, 70
|
||||
|
||||
**ServerErrorMiddleware**, "Returns server errors for any exceptions not caught by the
|
||||
ExceptionMiddleware"
|
||||
**ExceptionMiddleware**, Handles exceptions raised by the middleware stack or application
|
||||
**SwaggerUIMiddleware**, Adds a Swagger UI to your application
|
||||
**RoutingMiddleware**, "Routes incoming requests to the right operation defined in the
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from connexion.middleware import MiddlewarePosition
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from conftest import FIXTURES_FOLDER, OPENAPI3_SPEC, build_app_from_fixture
|
||||
from conftest import OPENAPI3_SPEC, build_app_from_fixture
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -20,6 +22,24 @@ def simple_openapi_app(app_class):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def cors_openapi_app(app_class):
|
||||
app = build_app_from_fixture(
|
||||
"simple",
|
||||
app_class=app_class,
|
||||
spec_file=OPENAPI3_SPEC,
|
||||
validate_responses=True,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
position=MiddlewarePosition.BEFORE_EXCEPTION,
|
||||
allow_origins=["http://localhost"],
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def reverse_proxied_app(spec, app_class):
|
||||
class ReverseProxied:
|
||||
|
||||
44
tests/api/test_cors.py
Normal file
44
tests/api/test_cors.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import json
|
||||
|
||||
|
||||
def test_cors_valid(cors_openapi_app):
|
||||
app_client = cors_openapi_app.test_client()
|
||||
origin = "http://localhost"
|
||||
response = app_client.post("/v1.0/goodday/dan", data={}, headers={"Origin": origin})
|
||||
assert response.status_code == 201
|
||||
assert "Access-Control-Allow-Origin" in response.headers
|
||||
assert origin == response.headers["Access-Control-Allow-Origin"]
|
||||
|
||||
|
||||
def test_cors_invalid(cors_openapi_app):
|
||||
app_client = cors_openapi_app.test_client()
|
||||
response = app_client.options(
|
||||
"/v1.0/goodday/dan",
|
||||
headers={"Origin": "http://0.0.0.0", "Access-Control-Request-Method": "POST"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Access-Control-Allow-Origin" not in response.headers
|
||||
|
||||
|
||||
def test_cors_validation_error(cors_openapi_app):
|
||||
app_client = cors_openapi_app.test_client()
|
||||
origin = "http://localhost"
|
||||
response = app_client.post(
|
||||
"/v1.0/body-not-allowed-additional-properties",
|
||||
data={},
|
||||
headers={"Origin": origin},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Access-Control-Allow-Origin" in response.headers
|
||||
assert origin == response.headers["Access-Control-Allow-Origin"]
|
||||
|
||||
|
||||
def test_cors_server_error(cors_openapi_app):
|
||||
app_client = cors_openapi_app.test_client()
|
||||
origin = "http://localhost"
|
||||
response = app_client.post(
|
||||
"/v1.0/goodday/noheader", data={}, headers={"Origin": origin}
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert "Access-Control-Allow-Origin" in response.headers
|
||||
assert origin == response.headers["Access-Control-Allow-Origin"]
|
||||
Reference in New Issue
Block a user