mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 12:27:45 +00:00
Centralize error handling in ExceptionMiddleware (#1754)
I was writing the documentation on exception handling, and I noticed that it was very hard to explain our current behavior. Error handlers can be registered either on the internal Flask app (not the Starlette one) or on the Connexion app, which leads to some undefined (actually just really hard to explain) behavior. Eg. - Registering error handlers on a status code would capture `starlette.HTTPException` errors on the Connexion app, and `werkzeug.HTTPException` errors on the Flask App, which means that registering an error handler on a status code doesn't catch all the errors with that status code. - Flask does some default error handling which leads to some exceptions never reaching the error handlers registered on the Connexion app. So I made the following changes: - Replaced the default error handlers we registered on the Flask app with a default handler on the `ExceptionMiddleware` that takes into account other handlers registered on status codes. - Configured Flask to propagate exceptions instead of catching them. - Abstracted away the Starlette `Request` and `Response` types, so users can and must now use `ConnexionRequest` and `ConnexionResponse` types in error handlers. - Renamed the `ASGIRequest` class to `ConnexionRequest` since it is the only Request class part of the high level Connexion interface. We could also rename `ConnexionRequest` and `ConnexionResponse` to just `Request` and `Response`. Wdyt?
This commit is contained in:
@@ -10,10 +10,12 @@ from starlette.testclient import TestClient
|
|||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
from connexion.jsonifier import Jsonifier
|
from connexion.jsonifier import Jsonifier
|
||||||
|
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
|
||||||
from connexion.middleware import ConnexionMiddleware, MiddlewarePosition, SpecMiddleware
|
from connexion.middleware import ConnexionMiddleware, MiddlewarePosition, SpecMiddleware
|
||||||
from connexion.middleware.lifespan import Lifespan
|
from connexion.middleware.lifespan import Lifespan
|
||||||
from connexion.options import SwaggerUIOptions
|
from connexion.options import SwaggerUIOptions
|
||||||
from connexion.resolver import Resolver
|
from connexion.resolver import Resolver
|
||||||
|
from connexion.types import MaybeAwaitable
|
||||||
from connexion.uri_parsing import AbstractURIParser
|
from connexion.uri_parsing import AbstractURIParser
|
||||||
|
|
||||||
|
|
||||||
@@ -250,14 +252,18 @@ class AbstractApp:
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def add_error_handler(
|
def add_error_handler(
|
||||||
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
|
self,
|
||||||
|
code_or_exception: t.Union[int, t.Type[Exception]],
|
||||||
|
function: t.Callable[
|
||||||
|
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
|
||||||
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Register a callable to handle application errors.
|
Register a callable to handle application errors.
|
||||||
|
|
||||||
:param code_or_exception: An exception class or the status code of HTTP exceptions to
|
:param code_or_exception: An exception class or the status code of HTTP exceptions to
|
||||||
handle.
|
handle.
|
||||||
:param function: Callable that will handle exception.
|
:param function: Callable that will handle exception, may be async.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_client(self, **kwargs):
|
def test_client(self, **kwargs):
|
||||||
|
|||||||
@@ -14,11 +14,13 @@ from starlette.types import Receive, Scope, Send
|
|||||||
from connexion.apps.abstract import AbstractApp
|
from connexion.apps.abstract import AbstractApp
|
||||||
from connexion.decorators import StarletteDecorator
|
from connexion.decorators import StarletteDecorator
|
||||||
from connexion.jsonifier import Jsonifier
|
from connexion.jsonifier import Jsonifier
|
||||||
|
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
|
||||||
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
|
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
|
||||||
from connexion.middleware.lifespan import Lifespan
|
from connexion.middleware.lifespan import Lifespan
|
||||||
from connexion.operations import AbstractOperation
|
from connexion.operations import AbstractOperation
|
||||||
from connexion.options import SwaggerUIOptions
|
from connexion.options import SwaggerUIOptions
|
||||||
from connexion.resolver import Resolver
|
from connexion.resolver import Resolver
|
||||||
|
from connexion.types import MaybeAwaitable
|
||||||
from connexion.uri_parsing import AbstractURIParser
|
from connexion.uri_parsing import AbstractURIParser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -88,7 +90,7 @@ class AsyncApi(RoutedAPI[AsyncOperation]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncMiddlewareApp(RoutedMiddleware[AsyncApi]):
|
class AsyncASGIApp(RoutedMiddleware[AsyncApi]):
|
||||||
|
|
||||||
api_cls = AsyncApi
|
api_cls = AsyncApi
|
||||||
|
|
||||||
@@ -176,7 +178,7 @@ class AsyncApp(AbstractApp):
|
|||||||
:param security_map: A dictionary of security handlers to use. Defaults to
|
:param security_map: A dictionary of security handlers to use. Defaults to
|
||||||
:obj:`security.SECURITY_HANDLERS`
|
:obj:`security.SECURITY_HANDLERS`
|
||||||
"""
|
"""
|
||||||
self._middleware_app: AsyncMiddlewareApp = AsyncMiddlewareApp()
|
self._middleware_app: AsyncASGIApp = AsyncASGIApp()
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
import_name,
|
import_name,
|
||||||
@@ -205,6 +207,10 @@ class AsyncApp(AbstractApp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def add_error_handler(
|
def add_error_handler(
|
||||||
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
|
self,
|
||||||
|
code_or_exception: t.Union[int, t.Type[Exception]],
|
||||||
|
function: t.Callable[
|
||||||
|
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
|
||||||
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.middleware.add_error_handler(code_or_exception, function)
|
self.middleware.add_error_handler(code_or_exception, function)
|
||||||
|
|||||||
@@ -6,23 +6,22 @@ import pathlib
|
|||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
import flask
|
import flask
|
||||||
import werkzeug.exceptions
|
|
||||||
from a2wsgi import WSGIMiddleware
|
from a2wsgi import WSGIMiddleware
|
||||||
from flask import Response as FlaskResponse
|
from flask import Response as FlaskResponse
|
||||||
from flask import signals
|
|
||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
from connexion.apps.abstract import AbstractApp
|
from connexion.apps.abstract import AbstractApp
|
||||||
from connexion.decorators import FlaskDecorator
|
from connexion.decorators import FlaskDecorator
|
||||||
from connexion.exceptions import InternalServerError, ProblemException, ResolverError
|
from connexion.exceptions import ResolverError
|
||||||
from connexion.frameworks import flask as flask_utils
|
from connexion.frameworks import flask as flask_utils
|
||||||
from connexion.jsonifier import Jsonifier
|
from connexion.jsonifier import Jsonifier
|
||||||
|
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
|
||||||
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
|
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
|
||||||
from connexion.middleware.lifespan import Lifespan
|
from connexion.middleware.lifespan import Lifespan
|
||||||
from connexion.operations import AbstractOperation
|
from connexion.operations import AbstractOperation
|
||||||
from connexion.options import SwaggerUIOptions
|
from connexion.options import SwaggerUIOptions
|
||||||
from connexion.problem import problem
|
|
||||||
from connexion.resolver import Resolver
|
from connexion.resolver import Resolver
|
||||||
|
from connexion.types import MaybeAwaitable
|
||||||
from connexion.uri_parsing import AbstractURIParser
|
from connexion.uri_parsing import AbstractURIParser
|
||||||
|
|
||||||
|
|
||||||
@@ -117,44 +116,20 @@ class FlaskApi(AbstractRoutingAPI):
|
|||||||
return self.blueprint.add_url_rule(rule, endpoint, view_func, **options)
|
return self.blueprint.add_url_rule(rule, endpoint, view_func, **options)
|
||||||
|
|
||||||
|
|
||||||
class FlaskMiddlewareApp(SpecMiddleware):
|
class FlaskASGIApp(SpecMiddleware):
|
||||||
def __init__(self, import_name, server_args: dict, **kwargs):
|
def __init__(self, import_name, server_args: dict, **kwargs):
|
||||||
self.app = flask.Flask(import_name, **server_args)
|
self.app = flask.Flask(import_name, **server_args)
|
||||||
self.app.json = flask_utils.FlaskJSONProvider(self.app)
|
self.app.json = flask_utils.FlaskJSONProvider(self.app)
|
||||||
self.app.url_map.converters["float"] = flask_utils.NumberConverter
|
self.app.url_map.converters["float"] = flask_utils.NumberConverter
|
||||||
self.app.url_map.converters["int"] = flask_utils.IntegerConverter
|
self.app.url_map.converters["int"] = flask_utils.IntegerConverter
|
||||||
|
|
||||||
self.set_errors_handlers()
|
# Propagate Errors so we can handle them in the middleware
|
||||||
|
self.app.config["PROPAGATE_EXCEPTIONS"] = True
|
||||||
|
self.app.config["TRAP_BAD_REQUEST_ERRORS"] = True
|
||||||
|
self.app.config["TRAP_HTTP_EXCEPTIONS"] = True
|
||||||
|
|
||||||
self.asgi_app = WSGIMiddleware(self.app.wsgi_app)
|
self.asgi_app = WSGIMiddleware(self.app.wsgi_app)
|
||||||
|
|
||||||
def set_errors_handlers(self):
|
|
||||||
for error_code in werkzeug.exceptions.default_exceptions:
|
|
||||||
self.app.register_error_handler(error_code, self.common_error_handler)
|
|
||||||
|
|
||||||
self.app.register_error_handler(ProblemException, self.common_error_handler)
|
|
||||||
|
|
||||||
def common_error_handler(self, exception: Exception) -> FlaskResponse:
|
|
||||||
"""Default error handler."""
|
|
||||||
if isinstance(exception, ProblemException):
|
|
||||||
response = exception.to_problem()
|
|
||||||
else:
|
|
||||||
if not isinstance(exception, werkzeug.exceptions.HTTPException):
|
|
||||||
exception = InternalServerError()
|
|
||||||
|
|
||||||
response = problem(
|
|
||||||
title=exception.name,
|
|
||||||
detail=exception.description,
|
|
||||||
status=exception.code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code >= 500:
|
|
||||||
signals.got_request_exception.send(self.app, exception=exception)
|
|
||||||
|
|
||||||
return flask.make_response(
|
|
||||||
(response.body, response.status_code, response.headers)
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_api(self, specification, *, name: str = None, **kwargs):
|
def add_api(self, specification, *, name: str = None, **kwargs):
|
||||||
api = FlaskApi(specification, **kwargs)
|
api = FlaskApi(specification, **kwargs)
|
||||||
|
|
||||||
@@ -177,7 +152,7 @@ class FlaskMiddlewareApp(SpecMiddleware):
|
|||||||
class FlaskApp(AbstractApp):
|
class FlaskApp(AbstractApp):
|
||||||
"""Connexion Application based on ConnexionMiddleware wrapping a Flask application."""
|
"""Connexion Application based on ConnexionMiddleware wrapping a Flask application."""
|
||||||
|
|
||||||
_middleware_app: FlaskMiddlewareApp
|
_middleware_app: FlaskASGIApp
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -237,7 +212,7 @@ class FlaskApp(AbstractApp):
|
|||||||
:param security_map: A dictionary of security handlers to use. Defaults to
|
:param security_map: A dictionary of security handlers to use. Defaults to
|
||||||
:obj:`security.SECURITY_HANDLERS`
|
:obj:`security.SECURITY_HANDLERS`
|
||||||
"""
|
"""
|
||||||
self._middleware_app = FlaskMiddlewareApp(import_name, server_args or {})
|
self._middleware_app = FlaskASGIApp(import_name, server_args or {})
|
||||||
self.app = self._middleware_app.app
|
self.app = self._middleware_app.app
|
||||||
super().__init__(
|
super().__init__(
|
||||||
import_name,
|
import_name,
|
||||||
@@ -266,6 +241,10 @@ class FlaskApp(AbstractApp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def add_error_handler(
|
def add_error_handler(
|
||||||
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
|
self,
|
||||||
|
code_or_exception: t.Union[int, t.Type[Exception]],
|
||||||
|
function: t.Callable[
|
||||||
|
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
|
||||||
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.app.register_error_handler(code_or_exception, function)
|
self.app.register_error_handler(code_or_exception, function)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from contextvars import ContextVar
|
|||||||
from starlette.types import Receive, Scope
|
from starlette.types import Receive, Scope
|
||||||
from werkzeug.local import LocalProxy
|
from werkzeug.local import LocalProxy
|
||||||
|
|
||||||
from connexion.lifecycle import ASGIRequest
|
from connexion.lifecycle import ConnexionRequest
|
||||||
from connexion.operations import AbstractOperation
|
from connexion.operations import AbstractOperation
|
||||||
|
|
||||||
UNBOUND_MESSAGE = (
|
UNBOUND_MESSAGE = (
|
||||||
@@ -25,5 +25,5 @@ _scope: ContextVar[Scope] = ContextVar("SCOPE")
|
|||||||
scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE)
|
scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE)
|
||||||
|
|
||||||
request = LocalProxy(
|
request = LocalProxy(
|
||||||
lambda: ASGIRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
|
lambda: ConnexionRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import inflection
|
|||||||
from connexion.context import context, operation
|
from connexion.context import context, operation
|
||||||
from connexion.frameworks.abstract import Framework
|
from connexion.frameworks.abstract import Framework
|
||||||
from connexion.http_facts import FORM_CONTENT_TYPES
|
from connexion.http_facts import FORM_CONTENT_TYPES
|
||||||
from connexion.lifecycle import ASGIRequest, WSGIRequest
|
from connexion.lifecycle import ConnexionRequest, WSGIRequest
|
||||||
from connexion.operations import AbstractOperation, Swagger2Operation
|
from connexion.operations import AbstractOperation, Swagger2Operation
|
||||||
from connexion.utils import (
|
from connexion.utils import (
|
||||||
deep_merge,
|
deep_merge,
|
||||||
@@ -43,7 +43,7 @@ class BaseParameterDecorator:
|
|||||||
|
|
||||||
def _maybe_get_body(
|
def _maybe_get_body(
|
||||||
self,
|
self,
|
||||||
request: t.Union[WSGIRequest, ASGIRequest],
|
request: t.Union[WSGIRequest, ConnexionRequest],
|
||||||
*,
|
*,
|
||||||
arguments: t.List[str],
|
arguments: t.List[str],
|
||||||
has_kwargs: bool,
|
has_kwargs: bool,
|
||||||
@@ -95,7 +95,7 @@ class AsyncParameterDecorator(BaseParameterDecorator):
|
|||||||
arguments, has_kwargs = inspect_function_arguments(unwrapped_function)
|
arguments, has_kwargs = inspect_function_arguments(unwrapped_function)
|
||||||
|
|
||||||
@functools.wraps(function)
|
@functools.wraps(function)
|
||||||
async def wrapper(request: ASGIRequest) -> t.Any:
|
async def wrapper(request: ConnexionRequest) -> t.Any:
|
||||||
request_body = self._maybe_get_body(
|
request_body = self._maybe_get_body(
|
||||||
request, arguments=arguments, has_kwargs=has_kwargs
|
request, arguments=arguments, has_kwargs=has_kwargs
|
||||||
)
|
)
|
||||||
@@ -118,7 +118,7 @@ class AsyncParameterDecorator(BaseParameterDecorator):
|
|||||||
|
|
||||||
|
|
||||||
def prep_kwargs(
|
def prep_kwargs(
|
||||||
request: t.Union[WSGIRequest, ASGIRequest],
|
request: t.Union[WSGIRequest, ConnexionRequest],
|
||||||
*,
|
*,
|
||||||
request_body: t.Any,
|
request_body: t.Any,
|
||||||
files: t.Dict[str, t.Any],
|
files: t.Dict[str, t.Any],
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from starlette.responses import Response as StarletteResponse
|
|||||||
from starlette.types import Receive, Scope
|
from starlette.types import Receive, Scope
|
||||||
|
|
||||||
from connexion.frameworks.abstract import Framework
|
from connexion.frameworks.abstract import Framework
|
||||||
from connexion.lifecycle import ASGIRequest
|
from connexion.lifecycle import ConnexionRequest
|
||||||
from connexion.uri_parsing import AbstractURIParser
|
from connexion.uri_parsing import AbstractURIParser
|
||||||
|
|
||||||
|
|
||||||
@@ -48,8 +48,8 @@ class Starlette(Framework):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ASGIRequest: # type: ignore
|
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ConnexionRequest: # type: ignore
|
||||||
return ASGIRequest(scope, receive, uri_parser=uri_parser)
|
return ConnexionRequest(scope, receive, uri_parser=uri_parser)
|
||||||
|
|
||||||
|
|
||||||
PATH_PARAMETER = re.compile(r"\{([^}]*)\}")
|
PATH_PARAMETER = re.compile(r"\{([^}]*)\}")
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ class WSGIRequest(_RequestInterface):
|
|||||||
return getattr(self._werkzeug_request, item)
|
return getattr(self._werkzeug_request, item)
|
||||||
|
|
||||||
|
|
||||||
class ASGIRequest(_RequestInterface):
|
class ConnexionRequest(_RequestInterface):
|
||||||
"""
|
"""
|
||||||
Implementation of the Connexion :code:`_RequestInterface` representing an ASGI request.
|
Implementation of the Connexion :code:`_RequestInterface` representing an ASGI request.
|
||||||
|
|
||||||
@@ -142,6 +142,8 @@ class ASGIRequest(_RequestInterface):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, uri_parser=None, **kwargs):
|
def __init__(self, *args, uri_parser=None, **kwargs):
|
||||||
|
# Might be set in `from_starlette_request` class method
|
||||||
|
if not hasattr(self, "_starlette_request"):
|
||||||
self._starlette_request = StarletteRequest(*args, **kwargs)
|
self._starlette_request = StarletteRequest(*args, **kwargs)
|
||||||
self.uri_parser = uri_parser
|
self.uri_parser = uri_parser
|
||||||
|
|
||||||
@@ -152,6 +154,16 @@ class ASGIRequest(_RequestInterface):
|
|||||||
self._form = None
|
self._form = None
|
||||||
self._files = None
|
self._files = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_starlette_request(
|
||||||
|
cls, request: StarletteRequest, uri_parser=None
|
||||||
|
) -> "ConnexionRequest":
|
||||||
|
# Instantiate the class, and set the `_starlette_request` property before initializing.
|
||||||
|
self = cls.__new__(cls)
|
||||||
|
self._starlette_request = request
|
||||||
|
self.__init__(uri_parser=uri_parser) # type: ignore
|
||||||
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def context(self):
|
def context(self):
|
||||||
if self._context is None:
|
if self._context is None:
|
||||||
@@ -226,6 +238,7 @@ class ASGIRequest(_RequestInterface):
|
|||||||
return await self.body() or None
|
return await self.body() or None
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
|
if self.__getattribute__("_starlette_request"):
|
||||||
return getattr(self._starlette_request, item)
|
return getattr(self._starlette_request, item)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,68 +1,114 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
import werkzeug.exceptions
|
||||||
|
from starlette.concurrency import run_in_threadpool
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from starlette.middleware.exceptions import (
|
from starlette.middleware.exceptions import (
|
||||||
ExceptionMiddleware as StarletteExceptionMiddleware,
|
ExceptionMiddleware as StarletteExceptionMiddleware,
|
||||||
)
|
)
|
||||||
from starlette.requests import Request as StarletteRequest
|
from starlette.requests import Request as StarletteRequest
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response as StarletteResponse
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
from connexion.exceptions import InternalServerError, ProblemException, problem
|
from connexion.exceptions import InternalServerError, ProblemException, problem
|
||||||
|
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
|
||||||
|
from connexion.types import MaybeAwaitable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def connexion_wrapper(
|
||||||
|
handler: t.Callable[
|
||||||
|
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
|
||||||
|
]
|
||||||
|
) -> t.Callable[[StarletteRequest, Exception], t.Awaitable[StarletteResponse]]:
|
||||||
|
"""Wrapper that translates Starlette requests to Connexion requests before passing
|
||||||
|
them to the error handler, and translates the returned Connexion responses to
|
||||||
|
Starlette responses."""
|
||||||
|
|
||||||
|
async def wrapper(request: StarletteRequest, exc: Exception) -> StarletteResponse:
|
||||||
|
request = ConnexionRequest.from_starlette_request(request)
|
||||||
|
|
||||||
|
if asyncio.iscoroutinefunction(handler):
|
||||||
|
response = await handler(request, exc) # type: ignore
|
||||||
|
else:
|
||||||
|
response = await run_in_threadpool(handler, request, exc)
|
||||||
|
|
||||||
|
return StarletteResponse(
|
||||||
|
content=response.body,
|
||||||
|
status_code=response.status_code,
|
||||||
|
media_type=response.mimetype,
|
||||||
|
headers=response.headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class ExceptionMiddleware(StarletteExceptionMiddleware):
|
class ExceptionMiddleware(StarletteExceptionMiddleware):
|
||||||
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
|
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
|
||||||
existing connexion behavior."""
|
existing connexion behavior."""
|
||||||
|
|
||||||
def __init__(self, next_app: ASGIApp):
|
def __init__(self, next_app: ASGIApp):
|
||||||
super().__init__(next_app)
|
super().__init__(next_app)
|
||||||
self.add_exception_handler(ProblemException, self.problem_handler)
|
self.add_exception_handler(ProblemException, self.problem_handler) # type: ignore
|
||||||
|
self.add_exception_handler(
|
||||||
|
werkzeug.exceptions.HTTPException, self.flask_error_handler
|
||||||
|
)
|
||||||
self.add_exception_handler(Exception, self.common_error_handler)
|
self.add_exception_handler(Exception, self.common_error_handler)
|
||||||
|
|
||||||
|
def add_exception_handler(
|
||||||
|
self,
|
||||||
|
exc_class_or_status_code: t.Union[int, t.Type[Exception]],
|
||||||
|
handler: t.Callable[[ConnexionRequest, Exception], StarletteResponse],
|
||||||
|
) -> None:
|
||||||
|
super().add_exception_handler(
|
||||||
|
exc_class_or_status_code, handler=connexion_wrapper(handler)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def problem_handler(_request: StarletteRequest, exc: ProblemException):
|
def problem_handler(_request: ConnexionRequest, exc: ProblemException):
|
||||||
|
"""Default handler for Connexion ProblemExceptions"""
|
||||||
logger.error("%r", exc)
|
logger.error("%r", exc)
|
||||||
|
return exc.to_problem()
|
||||||
response = exc.to_problem()
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
content=response.body,
|
|
||||||
status_code=response.status_code,
|
|
||||||
media_type=response.mimetype,
|
|
||||||
headers=response.headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def http_exception(_request: StarletteRequest, exc: HTTPException) -> Response:
|
@connexion_wrapper
|
||||||
|
def http_exception(
|
||||||
|
_request: StarletteRequest, exc: HTTPException, **kwargs
|
||||||
|
) -> StarletteResponse:
|
||||||
|
"""Default handler for Starlette HTTPException"""
|
||||||
logger.error("%r", exc)
|
logger.error("%r", exc)
|
||||||
|
return problem(
|
||||||
headers = exc.headers
|
title=exc.detail,
|
||||||
|
detail=exc.detail,
|
||||||
connexion_response = problem(
|
status=exc.status_code,
|
||||||
title=exc.detail, detail=exc.detail, status=exc.status_code, headers=headers
|
headers=exc.headers,
|
||||||
)
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
content=connexion_response.body,
|
|
||||||
status_code=connexion_response.status_code,
|
|
||||||
media_type=connexion_response.mimetype,
|
|
||||||
headers=connexion_response.headers,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def common_error_handler(_request: StarletteRequest, exc: Exception) -> Response:
|
def common_error_handler(
|
||||||
|
_request: StarletteRequest, exc: Exception
|
||||||
|
) -> ConnexionResponse:
|
||||||
|
"""Default handler for any unhandled Exception"""
|
||||||
logger.error("%r", exc, exc_info=exc)
|
logger.error("%r", exc, exc_info=exc)
|
||||||
|
return InternalServerError().to_problem()
|
||||||
|
|
||||||
response = InternalServerError().to_problem()
|
def flask_error_handler(
|
||||||
|
self, request: StarletteRequest, exc: werkzeug.exceptions.HTTPException
|
||||||
|
) -> ConnexionResponse:
|
||||||
|
"""Default handler for Flask / werkzeug HTTPException"""
|
||||||
|
# If a handler is registered for the received status_code, call it instead.
|
||||||
|
# This is only done automatically for Starlette HTTPExceptions
|
||||||
|
if handler := self._status_handlers.get(exc.code):
|
||||||
|
starlette_exception = HTTPException(exc.code, detail=exc.description)
|
||||||
|
return handler(request, starlette_exception)
|
||||||
|
|
||||||
return Response(
|
return problem(
|
||||||
content=response.body,
|
title=exc.name,
|
||||||
status_code=response.status_code,
|
detail=exc.description,
|
||||||
media_type=response.mimetype,
|
status=exc.code,
|
||||||
headers=response.headers,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
|
|||||||
from connexion import utils
|
from connexion import utils
|
||||||
from connexion.handlers import ResolverErrorHandler
|
from connexion.handlers import ResolverErrorHandler
|
||||||
from connexion.jsonifier import Jsonifier
|
from connexion.jsonifier import Jsonifier
|
||||||
|
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
|
||||||
from connexion.middleware.abstract import SpecMiddleware
|
from connexion.middleware.abstract import SpecMiddleware
|
||||||
from connexion.middleware.context import ContextMiddleware
|
from connexion.middleware.context import ContextMiddleware
|
||||||
from connexion.middleware.exceptions import ExceptionMiddleware
|
from connexion.middleware.exceptions import ExceptionMiddleware
|
||||||
@@ -23,6 +24,7 @@ from connexion.middleware.security import SecurityMiddleware
|
|||||||
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
||||||
from connexion.options import SwaggerUIOptions
|
from connexion.options import SwaggerUIOptions
|
||||||
from connexion.resolver import Resolver
|
from connexion.resolver import Resolver
|
||||||
|
from connexion.types import MaybeAwaitable
|
||||||
from connexion.uri_parsing import AbstractURIParser
|
from connexion.uri_parsing import AbstractURIParser
|
||||||
from connexion.utils import inspect_function_arguments
|
from connexion.utils import inspect_function_arguments
|
||||||
|
|
||||||
@@ -419,14 +421,18 @@ class ConnexionMiddleware:
|
|||||||
self.apis.append(api)
|
self.apis.append(api)
|
||||||
|
|
||||||
def add_error_handler(
|
def add_error_handler(
|
||||||
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
|
self,
|
||||||
|
code_or_exception: t.Union[int, t.Type[Exception]],
|
||||||
|
function: t.Callable[
|
||||||
|
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
|
||||||
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Register a callable to handle application errors.
|
Register a callable to handle application errors.
|
||||||
|
|
||||||
:param code_or_exception: An exception class or the status code of HTTP exceptions to
|
:param code_or_exception: An exception class or the status code of HTTP exceptions to
|
||||||
handle.
|
handle.
|
||||||
:param function: Callable that will handle exception.
|
:param function: Callable that will handle exception, may be async.
|
||||||
"""
|
"""
|
||||||
if self.middleware_stack is not None:
|
if self.middleware_stack is not None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from collections import defaultdict
|
|||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
from connexion.exceptions import ProblemException
|
from connexion.exceptions import ProblemException
|
||||||
from connexion.lifecycle import ASGIRequest
|
from connexion.lifecycle import ConnexionRequest
|
||||||
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
|
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
|
||||||
from connexion.operations import AbstractOperation
|
from connexion.operations import AbstractOperation
|
||||||
from connexion.security import SecurityHandlerFactory
|
from connexion.security import SecurityHandlerFactory
|
||||||
@@ -95,7 +95,7 @@ class SecurityOperation:
|
|||||||
await self.next_app(scope, receive, send)
|
await self.next_app(scope, receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
request = ASGIRequest(scope)
|
request = ConnexionRequest(scope)
|
||||||
await self.verification_fn(request)
|
await self.verification_fn(request)
|
||||||
await self.next_app(scope, receive, send)
|
await self.next_app(scope, receive, send)
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ import httpx
|
|||||||
|
|
||||||
from connexion.decorators.parameter import inspect_function_arguments
|
from connexion.decorators.parameter import inspect_function_arguments
|
||||||
from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem
|
from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem
|
||||||
from connexion.lifecycle import ASGIRequest
|
from connexion.lifecycle import ConnexionRequest
|
||||||
from connexion.utils import get_function_from_name
|
from connexion.utils import get_function_from_name
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -248,7 +248,7 @@ class ApiKeySecurityHandler(AbstractSecurityHandler):
|
|||||||
def _get_verify_func(self, api_key_info_func, loc, name):
|
def _get_verify_func(self, api_key_info_func, loc, name):
|
||||||
check_api_key_func = self.check_api_key(api_key_info_func)
|
check_api_key_func = self.check_api_key(api_key_info_func)
|
||||||
|
|
||||||
def wrapper(request: ASGIRequest):
|
def wrapper(request: ConnexionRequest):
|
||||||
if loc == "query":
|
if loc == "query":
|
||||||
api_key = request.query_params.get(name)
|
api_key = request.query_params.get(name)
|
||||||
elif loc == "header":
|
elif loc == "header":
|
||||||
|
|||||||
4
connexion/types.py
Normal file
4
connexion/types.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
import typing as t
|
||||||
|
|
||||||
|
ReturnType = t.TypeVar("ReturnType")
|
||||||
|
MaybeAwaitable = t.Union[t.Awaitable[ReturnType], ReturnType]
|
||||||
@@ -24,12 +24,13 @@ See below for an explanation of the different variables.
|
|||||||
request
|
request
|
||||||
-------
|
-------
|
||||||
|
|
||||||
A ``Request`` object representing the incoming request. This is an instance of the ``ASGIRequest``.
|
A ``Request`` object representing the incoming request. This is an instance of the
|
||||||
|
``ConnexionRequest``.
|
||||||
|
|
||||||
.. dropdown:: View a detailed reference of the ``ASGIRequest`` class
|
.. dropdown:: View a detailed reference of the ``ConnexionRequest`` class
|
||||||
:icon: eye
|
:icon: eye
|
||||||
|
|
||||||
.. autoclass:: connexion.lifecycle.ASGIRequest
|
.. autoclass:: connexion.lifecycle.ConnexionRequest
|
||||||
:noindex:
|
:noindex:
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
|
|||||||
@@ -488,7 +488,7 @@ request.
|
|||||||
.. dropdown:: View a detailed reference of the ``connexion.request`` class
|
.. dropdown:: View a detailed reference of the ``connexion.request`` class
|
||||||
:icon: eye
|
:icon: eye
|
||||||
|
|
||||||
.. autoclass:: connexion.lifecycle.ASGIRequest
|
.. autoclass:: connexion.lifecycle.ConnexionRequest
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
:inherited-members:
|
:inherited-members:
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ operation:
|
|||||||
|
|
||||||
Note that :code:`HEAD` requests will be handled by the :code:`operationId` specified under the
|
Note that :code:`HEAD` requests will be handled by the :code:`operationId` specified under the
|
||||||
:code:`GET` operation in the specification. :code:`Connexion.request.method` can be used to
|
:code:`GET` operation in the specification. :code:`Connexion.request.method` can be used to
|
||||||
determine which request was made. See :class:`.ASGIRequest`.
|
determine which request was made. See :class:`.ConnexionRequest`.
|
||||||
|
|
||||||
Automatic routing
|
Automatic routing
|
||||||
-----------------
|
-----------------
|
||||||
|
|||||||
@@ -127,6 +127,8 @@ Smaller breaking changes
|
|||||||
has been added to work with Flask's ``MethodView`` specifically.
|
has been added to work with Flask's ``MethodView`` specifically.
|
||||||
* Built-in support for uWSGI has been removed. You can re-add this functionality using a custom middleware.
|
* Built-in support for uWSGI has been removed. You can re-add this functionality using a custom middleware.
|
||||||
* The request body is now passed through for ``GET``, ``HEAD``, ``DELETE``, ``CONNECT`` and ``OPTIONS`` methods as well.
|
* The request body is now passed through for ``GET``, ``HEAD``, ``DELETE``, ``CONNECT`` and ``OPTIONS`` methods as well.
|
||||||
|
* Error handlers registered on the on the underlying Flask app directly will be ignored. You
|
||||||
|
should register them on the Connexion app directly.
|
||||||
|
|
||||||
|
|
||||||
Non-breaking changes
|
Non-breaking changes
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from connexion.exceptions import (
|
|||||||
OAuthResponseProblem,
|
OAuthResponseProblem,
|
||||||
OAuthScopeProblem,
|
OAuthScopeProblem,
|
||||||
)
|
)
|
||||||
from connexion.lifecycle import ASGIRequest
|
from connexion.lifecycle import ConnexionRequest
|
||||||
from connexion.security import (
|
from connexion.security import (
|
||||||
NO_VALUE,
|
NO_VALUE,
|
||||||
ApiKeySecurityHandler,
|
ApiKeySecurityHandler,
|
||||||
@@ -61,7 +61,7 @@ def test_verify_oauth_missing_auth_header():
|
|||||||
somefunc, security_handler.validate_scope, ["admin"]
|
somefunc, security_handler.validate_scope, ["admin"]
|
||||||
)
|
)
|
||||||
|
|
||||||
request = ASGIRequest(scope={"type": "http", "headers": []})
|
request = ConnexionRequest(scope={"type": "http", "headers": []})
|
||||||
|
|
||||||
assert wrapped_func(request) is NO_VALUE
|
assert wrapped_func(request) is NO_VALUE
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ async def test_verify_oauth_scopes_remote(monkeypatch):
|
|||||||
token_info_func, security_handler.validate_scope, ["admin"]
|
token_info_func, security_handler.validate_scope, ["admin"]
|
||||||
)
|
)
|
||||||
|
|
||||||
request = ASGIRequest(
|
request = ConnexionRequest(
|
||||||
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ async def test_verify_oauth_invalid_local_token_response_none():
|
|||||||
somefunc, security_handler.validate_scope, ["admin"]
|
somefunc, security_handler.validate_scope, ["admin"]
|
||||||
)
|
)
|
||||||
|
|
||||||
request = ASGIRequest(
|
request = ConnexionRequest(
|
||||||
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -143,7 +143,7 @@ async def test_verify_oauth_scopes_local():
|
|||||||
token_info, security_handler.validate_scope, ["admin"]
|
token_info, security_handler.validate_scope, ["admin"]
|
||||||
)
|
)
|
||||||
|
|
||||||
request = ASGIRequest(
|
request = ConnexionRequest(
|
||||||
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ def test_verify_basic_missing_auth_header():
|
|||||||
security_handler = BasicSecurityHandler()
|
security_handler = BasicSecurityHandler()
|
||||||
wrapped_func = security_handler._get_verify_func(somefunc)
|
wrapped_func = security_handler._get_verify_func(somefunc)
|
||||||
|
|
||||||
request = ASGIRequest(
|
request = ConnexionRequest(
|
||||||
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -194,7 +194,7 @@ async def test_verify_basic():
|
|||||||
security_handler = BasicSecurityHandler()
|
security_handler = BasicSecurityHandler()
|
||||||
wrapped_func = security_handler._get_verify_func(basic_info)
|
wrapped_func = security_handler._get_verify_func(basic_info)
|
||||||
|
|
||||||
request = ASGIRequest(
|
request = ConnexionRequest(
|
||||||
scope={"type": "http", "headers": [[b"authorization", b"Basic Zm9vOmJhcg=="]]}
|
scope={"type": "http", "headers": [[b"authorization", b"Basic Zm9vOmJhcg=="]]}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -212,7 +212,7 @@ async def test_verify_apikey_query():
|
|||||||
apikey_info, "query", "auth"
|
apikey_info, "query", "auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
request = ASGIRequest(scope={"type": "http", "query_string": b"auth=foobar"})
|
request = ConnexionRequest(scope={"type": "http", "query_string": b"auth=foobar"})
|
||||||
|
|
||||||
assert await wrapped_func(request) is not None
|
assert await wrapped_func(request) is not None
|
||||||
|
|
||||||
@@ -228,7 +228,9 @@ async def test_verify_apikey_header():
|
|||||||
apikey_info, "header", "X-Auth"
|
apikey_info, "header", "X-Auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]})
|
request = ConnexionRequest(
|
||||||
|
scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]}
|
||||||
|
)
|
||||||
|
|
||||||
assert await wrapped_func(request) is not None
|
assert await wrapped_func(request) is not None
|
||||||
|
|
||||||
@@ -259,16 +261,20 @@ async def test_multiple_schemes():
|
|||||||
wrapped_func = security_handler_factory.verify_multiple_schemes(schemes)
|
wrapped_func = security_handler_factory.verify_multiple_schemes(schemes)
|
||||||
|
|
||||||
# Single key does not succeed
|
# Single key does not succeed
|
||||||
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]})
|
request = ConnexionRequest(
|
||||||
|
scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]}
|
||||||
|
)
|
||||||
|
|
||||||
assert await wrapped_func(request) is NO_VALUE
|
assert await wrapped_func(request) is NO_VALUE
|
||||||
|
|
||||||
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]})
|
request = ConnexionRequest(
|
||||||
|
scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]}
|
||||||
|
)
|
||||||
|
|
||||||
assert await wrapped_func(request) is NO_VALUE
|
assert await wrapped_func(request) is NO_VALUE
|
||||||
|
|
||||||
# Supplying both keys does succeed
|
# Supplying both keys does succeed
|
||||||
request = ASGIRequest(
|
request = ConnexionRequest(
|
||||||
scope={
|
scope={
|
||||||
"type": "http",
|
"type": "http",
|
||||||
"headers": [[b"x-auth-1", b"foobar"], [b"x-auth-2", b"bar"]],
|
"headers": [[b"x-auth-1", b"foobar"], [b"x-auth-2", b"bar"]],
|
||||||
@@ -287,7 +293,7 @@ async def test_verify_security_oauthproblem():
|
|||||||
security_handler_factory = SecurityHandlerFactory()
|
security_handler_factory = SecurityHandlerFactory()
|
||||||
security_func = security_handler_factory.verify_security([])
|
security_func = security_handler_factory.verify_security([])
|
||||||
|
|
||||||
request = MagicMock(spec_set=ASGIRequest)
|
request = MagicMock(spec_set=ConnexionRequest)
|
||||||
with pytest.raises(OAuthProblem) as exc_info:
|
with pytest.raises(OAuthProblem) as exc_info:
|
||||||
await security_func(request)
|
await security_func(request)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user