Centralize error handling in ExceptionMiddleware (#1754)

I was writing the documentation on exception handling, and I noticed
that it was very hard to explain our current behavior.

Error handlers can be registered either on the internal Flask app (not
the Starlette one) or on the Connexion app, which leads to some
undefined (actually just really hard to explain) behavior. Eg.
- Registering error handlers on a status code would capture
`starlette.HTTPException` errors on the Connexion app, and
`werkzeug.HTTPException` errors on the Flask App, which means that
registering an error handler on a status code doesn't catch all the
errors with that status code.
- Flask does some default error handling which leads to some exceptions
never reaching the error handlers registered on the Connexion app.

So I made the following changes:
- Replaced the default error handlers we registered on the Flask app
with a default handler on the `ExceptionMiddleware` that takes into
account other handlers registered on status codes.
- Configured Flask to propagate exceptions instead of catching them.
- Abstracted away the Starlette `Request` and `Response` types, so users
can and must now use `ConnexionRequest`
  and `ConnexionResponse` types in error handlers.
- Renamed the `ASGIRequest` class to `ConnexionRequest` since it is the
only Request class part of the high level
  Connexion interface.

We could also rename `ConnexionRequest` and `ConnexionResponse` to just
`Request` and `Response`. Wdyt?
This commit is contained in:
Robbe Sneyders
2023-10-29 09:37:54 +01:00
committed by GitHub
parent 1b72019b1b
commit b9ba13cde5
17 changed files with 178 additions and 109 deletions

View File

@@ -10,10 +10,12 @@ from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.jsonifier import Jsonifier from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware import ConnexionMiddleware, MiddlewarePosition, SpecMiddleware from connexion.middleware import ConnexionMiddleware, MiddlewarePosition, SpecMiddleware
from connexion.middleware.lifespan import Lifespan from connexion.middleware.lifespan import Lifespan
from connexion.options import SwaggerUIOptions from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser from connexion.uri_parsing import AbstractURIParser
@@ -250,14 +252,18 @@ class AbstractApp:
@abc.abstractmethod @abc.abstractmethod
def add_error_handler( def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None: ) -> None:
""" """
Register a callable to handle application errors. Register a callable to handle application errors.
:param code_or_exception: An exception class or the status code of HTTP exceptions to :param code_or_exception: An exception class or the status code of HTTP exceptions to
handle. handle.
:param function: Callable that will handle exception. :param function: Callable that will handle exception, may be async.
""" """
def test_client(self, **kwargs): def test_client(self, **kwargs):

View File

@@ -14,11 +14,13 @@ from starlette.types import Receive, Scope, Send
from connexion.apps.abstract import AbstractApp from connexion.apps.abstract import AbstractApp
from connexion.decorators import StarletteDecorator from connexion.decorators import StarletteDecorator
from connexion.jsonifier import Jsonifier from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.middleware.lifespan import Lifespan from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation from connexion.operations import AbstractOperation
from connexion.options import SwaggerUIOptions from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser from connexion.uri_parsing import AbstractURIParser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -88,7 +90,7 @@ class AsyncApi(RoutedAPI[AsyncOperation]):
) )
class AsyncMiddlewareApp(RoutedMiddleware[AsyncApi]): class AsyncASGIApp(RoutedMiddleware[AsyncApi]):
api_cls = AsyncApi api_cls = AsyncApi
@@ -176,7 +178,7 @@ class AsyncApp(AbstractApp):
:param security_map: A dictionary of security handlers to use. Defaults to :param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS` :obj:`security.SECURITY_HANDLERS`
""" """
self._middleware_app: AsyncMiddlewareApp = AsyncMiddlewareApp() self._middleware_app: AsyncASGIApp = AsyncASGIApp()
super().__init__( super().__init__(
import_name, import_name,
@@ -205,6 +207,10 @@ class AsyncApp(AbstractApp):
) )
def add_error_handler( def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None: ) -> None:
self.middleware.add_error_handler(code_or_exception, function) self.middleware.add_error_handler(code_or_exception, function)

View File

@@ -6,23 +6,22 @@ import pathlib
import typing as t import typing as t
import flask import flask
import werkzeug.exceptions
from a2wsgi import WSGIMiddleware from a2wsgi import WSGIMiddleware
from flask import Response as FlaskResponse from flask import Response as FlaskResponse
from flask import signals
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
from connexion.apps.abstract import AbstractApp from connexion.apps.abstract import AbstractApp
from connexion.decorators import FlaskDecorator from connexion.decorators import FlaskDecorator
from connexion.exceptions import InternalServerError, ProblemException, ResolverError from connexion.exceptions import ResolverError
from connexion.frameworks import flask as flask_utils from connexion.frameworks import flask as flask_utils
from connexion.jsonifier import Jsonifier from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
from connexion.middleware.lifespan import Lifespan from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation from connexion.operations import AbstractOperation
from connexion.options import SwaggerUIOptions from connexion.options import SwaggerUIOptions
from connexion.problem import problem
from connexion.resolver import Resolver from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser from connexion.uri_parsing import AbstractURIParser
@@ -117,44 +116,20 @@ class FlaskApi(AbstractRoutingAPI):
return self.blueprint.add_url_rule(rule, endpoint, view_func, **options) return self.blueprint.add_url_rule(rule, endpoint, view_func, **options)
class FlaskMiddlewareApp(SpecMiddleware): class FlaskASGIApp(SpecMiddleware):
def __init__(self, import_name, server_args: dict, **kwargs): def __init__(self, import_name, server_args: dict, **kwargs):
self.app = flask.Flask(import_name, **server_args) self.app = flask.Flask(import_name, **server_args)
self.app.json = flask_utils.FlaskJSONProvider(self.app) self.app.json = flask_utils.FlaskJSONProvider(self.app)
self.app.url_map.converters["float"] = flask_utils.NumberConverter self.app.url_map.converters["float"] = flask_utils.NumberConverter
self.app.url_map.converters["int"] = flask_utils.IntegerConverter self.app.url_map.converters["int"] = flask_utils.IntegerConverter
self.set_errors_handlers() # Propagate Errors so we can handle them in the middleware
self.app.config["PROPAGATE_EXCEPTIONS"] = True
self.app.config["TRAP_BAD_REQUEST_ERRORS"] = True
self.app.config["TRAP_HTTP_EXCEPTIONS"] = True
self.asgi_app = WSGIMiddleware(self.app.wsgi_app) self.asgi_app = WSGIMiddleware(self.app.wsgi_app)
def set_errors_handlers(self):
for error_code in werkzeug.exceptions.default_exceptions:
self.app.register_error_handler(error_code, self.common_error_handler)
self.app.register_error_handler(ProblemException, self.common_error_handler)
def common_error_handler(self, exception: Exception) -> FlaskResponse:
"""Default error handler."""
if isinstance(exception, ProblemException):
response = exception.to_problem()
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = InternalServerError()
response = problem(
title=exception.name,
detail=exception.description,
status=exception.code,
)
if response.status_code >= 500:
signals.got_request_exception.send(self.app, exception=exception)
return flask.make_response(
(response.body, response.status_code, response.headers)
)
def add_api(self, specification, *, name: str = None, **kwargs): def add_api(self, specification, *, name: str = None, **kwargs):
api = FlaskApi(specification, **kwargs) api = FlaskApi(specification, **kwargs)
@@ -177,7 +152,7 @@ class FlaskMiddlewareApp(SpecMiddleware):
class FlaskApp(AbstractApp): class FlaskApp(AbstractApp):
"""Connexion Application based on ConnexionMiddleware wrapping a Flask application.""" """Connexion Application based on ConnexionMiddleware wrapping a Flask application."""
_middleware_app: FlaskMiddlewareApp _middleware_app: FlaskASGIApp
def __init__( def __init__(
self, self,
@@ -237,7 +212,7 @@ class FlaskApp(AbstractApp):
:param security_map: A dictionary of security handlers to use. Defaults to :param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS` :obj:`security.SECURITY_HANDLERS`
""" """
self._middleware_app = FlaskMiddlewareApp(import_name, server_args or {}) self._middleware_app = FlaskASGIApp(import_name, server_args or {})
self.app = self._middleware_app.app self.app = self._middleware_app.app
super().__init__( super().__init__(
import_name, import_name,
@@ -266,6 +241,10 @@ class FlaskApp(AbstractApp):
) )
def add_error_handler( def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None: ) -> None:
self.app.register_error_handler(code_or_exception, function) self.app.register_error_handler(code_or_exception, function)

View File

@@ -3,7 +3,7 @@ from contextvars import ContextVar
from starlette.types import Receive, Scope from starlette.types import Receive, Scope
from werkzeug.local import LocalProxy from werkzeug.local import LocalProxy
from connexion.lifecycle import ASGIRequest from connexion.lifecycle import ConnexionRequest
from connexion.operations import AbstractOperation from connexion.operations import AbstractOperation
UNBOUND_MESSAGE = ( UNBOUND_MESSAGE = (
@@ -25,5 +25,5 @@ _scope: ContextVar[Scope] = ContextVar("SCOPE")
scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE) scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE)
request = LocalProxy( request = LocalProxy(
lambda: ASGIRequest(scope, receive), unbound_message=UNBOUND_MESSAGE lambda: ConnexionRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
) )

View File

@@ -16,7 +16,7 @@ import inflection
from connexion.context import context, operation from connexion.context import context, operation
from connexion.frameworks.abstract import Framework from connexion.frameworks.abstract import Framework
from connexion.http_facts import FORM_CONTENT_TYPES from connexion.http_facts import FORM_CONTENT_TYPES
from connexion.lifecycle import ASGIRequest, WSGIRequest from connexion.lifecycle import ConnexionRequest, WSGIRequest
from connexion.operations import AbstractOperation, Swagger2Operation from connexion.operations import AbstractOperation, Swagger2Operation
from connexion.utils import ( from connexion.utils import (
deep_merge, deep_merge,
@@ -43,7 +43,7 @@ class BaseParameterDecorator:
def _maybe_get_body( def _maybe_get_body(
self, self,
request: t.Union[WSGIRequest, ASGIRequest], request: t.Union[WSGIRequest, ConnexionRequest],
*, *,
arguments: t.List[str], arguments: t.List[str],
has_kwargs: bool, has_kwargs: bool,
@@ -95,7 +95,7 @@ class AsyncParameterDecorator(BaseParameterDecorator):
arguments, has_kwargs = inspect_function_arguments(unwrapped_function) arguments, has_kwargs = inspect_function_arguments(unwrapped_function)
@functools.wraps(function) @functools.wraps(function)
async def wrapper(request: ASGIRequest) -> t.Any: async def wrapper(request: ConnexionRequest) -> t.Any:
request_body = self._maybe_get_body( request_body = self._maybe_get_body(
request, arguments=arguments, has_kwargs=has_kwargs request, arguments=arguments, has_kwargs=has_kwargs
) )
@@ -118,7 +118,7 @@ class AsyncParameterDecorator(BaseParameterDecorator):
def prep_kwargs( def prep_kwargs(
request: t.Union[WSGIRequest, ASGIRequest], request: t.Union[WSGIRequest, ConnexionRequest],
*, *,
request_body: t.Any, request_body: t.Any,
files: t.Dict[str, t.Any], files: t.Dict[str, t.Any],

View File

@@ -8,7 +8,7 @@ from starlette.responses import Response as StarletteResponse
from starlette.types import Receive, Scope from starlette.types import Receive, Scope
from connexion.frameworks.abstract import Framework from connexion.frameworks.abstract import Framework
from connexion.lifecycle import ASGIRequest from connexion.lifecycle import ConnexionRequest
from connexion.uri_parsing import AbstractURIParser from connexion.uri_parsing import AbstractURIParser
@@ -48,8 +48,8 @@ class Starlette(Framework):
) )
@staticmethod @staticmethod
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ASGIRequest: # type: ignore def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ConnexionRequest: # type: ignore
return ASGIRequest(scope, receive, uri_parser=uri_parser) return ConnexionRequest(scope, receive, uri_parser=uri_parser)
PATH_PARAMETER = re.compile(r"\{([^}]*)\}") PATH_PARAMETER = re.compile(r"\{([^}]*)\}")

View File

@@ -130,7 +130,7 @@ class WSGIRequest(_RequestInterface):
return getattr(self._werkzeug_request, item) return getattr(self._werkzeug_request, item)
class ASGIRequest(_RequestInterface): class ConnexionRequest(_RequestInterface):
""" """
Implementation of the Connexion :code:`_RequestInterface` representing an ASGI request. Implementation of the Connexion :code:`_RequestInterface` representing an ASGI request.
@@ -142,6 +142,8 @@ class ASGIRequest(_RequestInterface):
""" """
def __init__(self, *args, uri_parser=None, **kwargs): def __init__(self, *args, uri_parser=None, **kwargs):
# Might be set in `from_starlette_request` class method
if not hasattr(self, "_starlette_request"):
self._starlette_request = StarletteRequest(*args, **kwargs) self._starlette_request = StarletteRequest(*args, **kwargs)
self.uri_parser = uri_parser self.uri_parser = uri_parser
@@ -152,6 +154,16 @@ class ASGIRequest(_RequestInterface):
self._form = None self._form = None
self._files = None self._files = None
@classmethod
def from_starlette_request(
cls, request: StarletteRequest, uri_parser=None
) -> "ConnexionRequest":
# Instantiate the class, and set the `_starlette_request` property before initializing.
self = cls.__new__(cls)
self._starlette_request = request
self.__init__(uri_parser=uri_parser) # type: ignore
return self
@property @property
def context(self): def context(self):
if self._context is None: if self._context is None:
@@ -226,6 +238,7 @@ class ASGIRequest(_RequestInterface):
return await self.body() or None return await self.body() or None
def __getattr__(self, item): def __getattr__(self, item):
if self.__getattribute__("_starlette_request"):
return getattr(self._starlette_request, item) return getattr(self._starlette_request, item)

View File

@@ -1,68 +1,114 @@
import asyncio
import logging import logging
import typing as t
import werkzeug.exceptions
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.middleware.exceptions import ( from starlette.middleware.exceptions import (
ExceptionMiddleware as StarletteExceptionMiddleware, ExceptionMiddleware as StarletteExceptionMiddleware,
) )
from starlette.requests import Request as StarletteRequest from starlette.requests import Request as StarletteRequest
from starlette.responses import Response from starlette.responses import Response as StarletteResponse
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.exceptions import InternalServerError, ProblemException, problem from connexion.exceptions import InternalServerError, ProblemException, problem
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.types import MaybeAwaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def connexion_wrapper(
handler: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
]
) -> t.Callable[[StarletteRequest, Exception], t.Awaitable[StarletteResponse]]:
"""Wrapper that translates Starlette requests to Connexion requests before passing
them to the error handler, and translates the returned Connexion responses to
Starlette responses."""
async def wrapper(request: StarletteRequest, exc: Exception) -> StarletteResponse:
request = ConnexionRequest.from_starlette_request(request)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc) # type: ignore
else:
response = await run_in_threadpool(handler, request, exc)
return StarletteResponse(
content=response.body,
status_code=response.status_code,
media_type=response.mimetype,
headers=response.headers,
)
return wrapper
class ExceptionMiddleware(StarletteExceptionMiddleware): class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to """Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior.""" existing connexion behavior."""
def __init__(self, next_app: ASGIApp): def __init__(self, next_app: ASGIApp):
super().__init__(next_app) super().__init__(next_app)
self.add_exception_handler(ProblemException, self.problem_handler) self.add_exception_handler(ProblemException, self.problem_handler) # type: ignore
self.add_exception_handler(
werkzeug.exceptions.HTTPException, self.flask_error_handler
)
self.add_exception_handler(Exception, self.common_error_handler) self.add_exception_handler(Exception, self.common_error_handler)
def add_exception_handler(
self,
exc_class_or_status_code: t.Union[int, t.Type[Exception]],
handler: t.Callable[[ConnexionRequest, Exception], StarletteResponse],
) -> None:
super().add_exception_handler(
exc_class_or_status_code, handler=connexion_wrapper(handler)
)
@staticmethod @staticmethod
def problem_handler(_request: StarletteRequest, exc: ProblemException): def problem_handler(_request: ConnexionRequest, exc: ProblemException):
"""Default handler for Connexion ProblemExceptions"""
logger.error("%r", exc) logger.error("%r", exc)
return exc.to_problem()
response = exc.to_problem()
return Response(
content=response.body,
status_code=response.status_code,
media_type=response.mimetype,
headers=response.headers,
)
@staticmethod @staticmethod
def http_exception(_request: StarletteRequest, exc: HTTPException) -> Response: @connexion_wrapper
def http_exception(
_request: StarletteRequest, exc: HTTPException, **kwargs
) -> StarletteResponse:
"""Default handler for Starlette HTTPException"""
logger.error("%r", exc) logger.error("%r", exc)
return problem(
headers = exc.headers title=exc.detail,
detail=exc.detail,
connexion_response = problem( status=exc.status_code,
title=exc.detail, detail=exc.detail, status=exc.status_code, headers=headers headers=exc.headers,
)
return Response(
content=connexion_response.body,
status_code=connexion_response.status_code,
media_type=connexion_response.mimetype,
headers=connexion_response.headers,
) )
@staticmethod @staticmethod
def common_error_handler(_request: StarletteRequest, exc: Exception) -> Response: def common_error_handler(
_request: StarletteRequest, exc: Exception
) -> ConnexionResponse:
"""Default handler for any unhandled Exception"""
logger.error("%r", exc, exc_info=exc) logger.error("%r", exc, exc_info=exc)
return InternalServerError().to_problem()
response = InternalServerError().to_problem() def flask_error_handler(
self, request: StarletteRequest, exc: werkzeug.exceptions.HTTPException
) -> ConnexionResponse:
"""Default handler for Flask / werkzeug HTTPException"""
# If a handler is registered for the received status_code, call it instead.
# This is only done automatically for Starlette HTTPExceptions
if handler := self._status_handlers.get(exc.code):
starlette_exception = HTTPException(exc.code, detail=exc.description)
return handler(request, starlette_exception)
return Response( return problem(
content=response.body, title=exc.name,
status_code=response.status_code, detail=exc.description,
media_type=response.mimetype, status=exc.code,
headers=response.headers,
) )
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

View File

@@ -12,6 +12,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from connexion import utils from connexion import utils
from connexion.handlers import ResolverErrorHandler from connexion.handlers import ResolverErrorHandler
from connexion.jsonifier import Jsonifier from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import SpecMiddleware from connexion.middleware.abstract import SpecMiddleware
from connexion.middleware.context import ContextMiddleware from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware from connexion.middleware.exceptions import ExceptionMiddleware
@@ -23,6 +24,7 @@ from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.options import SwaggerUIOptions from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser from connexion.uri_parsing import AbstractURIParser
from connexion.utils import inspect_function_arguments from connexion.utils import inspect_function_arguments
@@ -419,14 +421,18 @@ class ConnexionMiddleware:
self.apis.append(api) self.apis.append(api)
def add_error_handler( def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None: ) -> None:
""" """
Register a callable to handle application errors. Register a callable to handle application errors.
:param code_or_exception: An exception class or the status code of HTTP exceptions to :param code_or_exception: An exception class or the status code of HTTP exceptions to
handle. handle.
:param function: Callable that will handle exception. :param function: Callable that will handle exception, may be async.
""" """
if self.middleware_stack is not None: if self.middleware_stack is not None:
raise RuntimeError( raise RuntimeError(

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.exceptions import ProblemException from connexion.exceptions import ProblemException
from connexion.lifecycle import ASGIRequest from connexion.lifecycle import ConnexionRequest
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.operations import AbstractOperation from connexion.operations import AbstractOperation
from connexion.security import SecurityHandlerFactory from connexion.security import SecurityHandlerFactory
@@ -95,7 +95,7 @@ class SecurityOperation:
await self.next_app(scope, receive, send) await self.next_app(scope, receive, send)
return return
request = ASGIRequest(scope) request = ConnexionRequest(scope)
await self.verification_fn(request) await self.verification_fn(request)
await self.next_app(scope, receive, send) await self.next_app(scope, receive, send)

View File

@@ -54,7 +54,7 @@ import httpx
from connexion.decorators.parameter import inspect_function_arguments from connexion.decorators.parameter import inspect_function_arguments
from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem
from connexion.lifecycle import ASGIRequest from connexion.lifecycle import ConnexionRequest
from connexion.utils import get_function_from_name from connexion.utils import get_function_from_name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -248,7 +248,7 @@ class ApiKeySecurityHandler(AbstractSecurityHandler):
def _get_verify_func(self, api_key_info_func, loc, name): def _get_verify_func(self, api_key_info_func, loc, name):
check_api_key_func = self.check_api_key(api_key_info_func) check_api_key_func = self.check_api_key(api_key_info_func)
def wrapper(request: ASGIRequest): def wrapper(request: ConnexionRequest):
if loc == "query": if loc == "query":
api_key = request.query_params.get(name) api_key = request.query_params.get(name)
elif loc == "header": elif loc == "header":

4
connexion/types.py Normal file
View File

@@ -0,0 +1,4 @@
import typing as t
ReturnType = t.TypeVar("ReturnType")
MaybeAwaitable = t.Union[t.Awaitable[ReturnType], ReturnType]

View File

@@ -24,12 +24,13 @@ See below for an explanation of the different variables.
request request
------- -------
A ``Request`` object representing the incoming request. This is an instance of the ``ASGIRequest``. A ``Request`` object representing the incoming request. This is an instance of the
``ConnexionRequest``.
.. dropdown:: View a detailed reference of the ``ASGIRequest`` class .. dropdown:: View a detailed reference of the ``ConnexionRequest`` class
:icon: eye :icon: eye
.. autoclass:: connexion.lifecycle.ASGIRequest .. autoclass:: connexion.lifecycle.ConnexionRequest
:noindex: :noindex:
:members: :members:
:undoc-members: :undoc-members:

View File

@@ -488,7 +488,7 @@ request.
.. dropdown:: View a detailed reference of the ``connexion.request`` class .. dropdown:: View a detailed reference of the ``connexion.request`` class
:icon: eye :icon: eye
.. autoclass:: connexion.lifecycle.ASGIRequest .. autoclass:: connexion.lifecycle.ConnexionRequest
:members: :members:
:undoc-members: :undoc-members:
:inherited-members: :inherited-members:

View File

@@ -59,7 +59,7 @@ operation:
Note that :code:`HEAD` requests will be handled by the :code:`operationId` specified under the Note that :code:`HEAD` requests will be handled by the :code:`operationId` specified under the
:code:`GET` operation in the specification. :code:`Connexion.request.method` can be used to :code:`GET` operation in the specification. :code:`Connexion.request.method` can be used to
determine which request was made. See :class:`.ASGIRequest`. determine which request was made. See :class:`.ConnexionRequest`.
Automatic routing Automatic routing
----------------- -----------------

View File

@@ -127,6 +127,8 @@ Smaller breaking changes
has been added to work with Flask's ``MethodView`` specifically. has been added to work with Flask's ``MethodView`` specifically.
* Built-in support for uWSGI has been removed. You can re-add this functionality using a custom middleware. * Built-in support for uWSGI has been removed. You can re-add this functionality using a custom middleware.
* The request body is now passed through for ``GET``, ``HEAD``, ``DELETE``, ``CONNECT`` and ``OPTIONS`` methods as well. * The request body is now passed through for ``GET``, ``HEAD``, ``DELETE``, ``CONNECT`` and ``OPTIONS`` methods as well.
* Error handlers registered on the on the underlying Flask app directly will be ignored. You
should register them on the Connexion app directly.
Non-breaking changes Non-breaking changes

View File

@@ -10,7 +10,7 @@ from connexion.exceptions import (
OAuthResponseProblem, OAuthResponseProblem,
OAuthScopeProblem, OAuthScopeProblem,
) )
from connexion.lifecycle import ASGIRequest from connexion.lifecycle import ConnexionRequest
from connexion.security import ( from connexion.security import (
NO_VALUE, NO_VALUE,
ApiKeySecurityHandler, ApiKeySecurityHandler,
@@ -61,7 +61,7 @@ def test_verify_oauth_missing_auth_header():
somefunc, security_handler.validate_scope, ["admin"] somefunc, security_handler.validate_scope, ["admin"]
) )
request = ASGIRequest(scope={"type": "http", "headers": []}) request = ConnexionRequest(scope={"type": "http", "headers": []})
assert wrapped_func(request) is NO_VALUE assert wrapped_func(request) is NO_VALUE
@@ -83,7 +83,7 @@ async def test_verify_oauth_scopes_remote(monkeypatch):
token_info_func, security_handler.validate_scope, ["admin"] token_info_func, security_handler.validate_scope, ["admin"]
) )
request = ASGIRequest( request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]} scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
) )
@@ -124,7 +124,7 @@ async def test_verify_oauth_invalid_local_token_response_none():
somefunc, security_handler.validate_scope, ["admin"] somefunc, security_handler.validate_scope, ["admin"]
) )
request = ASGIRequest( request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]} scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
) )
@@ -143,7 +143,7 @@ async def test_verify_oauth_scopes_local():
token_info, security_handler.validate_scope, ["admin"] token_info, security_handler.validate_scope, ["admin"]
) )
request = ASGIRequest( request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]} scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
) )
@@ -178,7 +178,7 @@ def test_verify_basic_missing_auth_header():
security_handler = BasicSecurityHandler() security_handler = BasicSecurityHandler()
wrapped_func = security_handler._get_verify_func(somefunc) wrapped_func = security_handler._get_verify_func(somefunc)
request = ASGIRequest( request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]} scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
) )
@@ -194,7 +194,7 @@ async def test_verify_basic():
security_handler = BasicSecurityHandler() security_handler = BasicSecurityHandler()
wrapped_func = security_handler._get_verify_func(basic_info) wrapped_func = security_handler._get_verify_func(basic_info)
request = ASGIRequest( request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Basic Zm9vOmJhcg=="]]} scope={"type": "http", "headers": [[b"authorization", b"Basic Zm9vOmJhcg=="]]}
) )
@@ -212,7 +212,7 @@ async def test_verify_apikey_query():
apikey_info, "query", "auth" apikey_info, "query", "auth"
) )
request = ASGIRequest(scope={"type": "http", "query_string": b"auth=foobar"}) request = ConnexionRequest(scope={"type": "http", "query_string": b"auth=foobar"})
assert await wrapped_func(request) is not None assert await wrapped_func(request) is not None
@@ -228,7 +228,9 @@ async def test_verify_apikey_header():
apikey_info, "header", "X-Auth" apikey_info, "header", "X-Auth"
) )
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]}) request = ConnexionRequest(
scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]}
)
assert await wrapped_func(request) is not None assert await wrapped_func(request) is not None
@@ -259,16 +261,20 @@ async def test_multiple_schemes():
wrapped_func = security_handler_factory.verify_multiple_schemes(schemes) wrapped_func = security_handler_factory.verify_multiple_schemes(schemes)
# Single key does not succeed # Single key does not succeed
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]}) request = ConnexionRequest(
scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]}
)
assert await wrapped_func(request) is NO_VALUE assert await wrapped_func(request) is NO_VALUE
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]}) request = ConnexionRequest(
scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]}
)
assert await wrapped_func(request) is NO_VALUE assert await wrapped_func(request) is NO_VALUE
# Supplying both keys does succeed # Supplying both keys does succeed
request = ASGIRequest( request = ConnexionRequest(
scope={ scope={
"type": "http", "type": "http",
"headers": [[b"x-auth-1", b"foobar"], [b"x-auth-2", b"bar"]], "headers": [[b"x-auth-1", b"foobar"], [b"x-auth-2", b"bar"]],
@@ -287,7 +293,7 @@ async def test_verify_security_oauthproblem():
security_handler_factory = SecurityHandlerFactory() security_handler_factory = SecurityHandlerFactory()
security_func = security_handler_factory.verify_security([]) security_func = security_handler_factory.verify_security([])
request = MagicMock(spec_set=ASGIRequest) request = MagicMock(spec_set=ConnexionRequest)
with pytest.raises(OAuthProblem) as exc_info: with pytest.raises(OAuthProblem) as exc_info:
await security_func(request) await security_func(request)