Centralize error handling in ExceptionMiddleware (#1754)

I was writing the documentation on exception handling, and I noticed
that it was very hard to explain our current behavior.

Error handlers can be registered either on the internal Flask app (not
the Starlette one) or on the Connexion app, which leads to some
undefined (actually just really hard to explain) behavior. Eg.
- Registering error handlers on a status code would capture
`starlette.HTTPException` errors on the Connexion app, and
`werkzeug.HTTPException` errors on the Flask App, which means that
registering an error handler on a status code doesn't catch all the
errors with that status code.
- Flask does some default error handling which leads to some exceptions
never reaching the error handlers registered on the Connexion app.

So I made the following changes:
- Replaced the default error handlers we registered on the Flask app
with a default handler on the `ExceptionMiddleware` that takes into
account other handlers registered on status codes.
- Configured Flask to propagate exceptions instead of catching them.
- Abstracted away the Starlette `Request` and `Response` types, so users
can and must now use `ConnexionRequest`
  and `ConnexionResponse` types in error handlers.
- Renamed the `ASGIRequest` class to `ConnexionRequest` since it is the
only Request class part of the high level
  Connexion interface.

We could also rename `ConnexionRequest` and `ConnexionResponse` to just
`Request` and `Response`. Wdyt?
This commit is contained in:
Robbe Sneyders
2023-10-29 09:37:54 +01:00
committed by GitHub
parent 1b72019b1b
commit b9ba13cde5
17 changed files with 178 additions and 109 deletions

View File

@@ -10,10 +10,12 @@ from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware import ConnexionMiddleware, MiddlewarePosition, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser
@@ -250,14 +252,18 @@ class AbstractApp:
@abc.abstractmethod
def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
"""
Register a callable to handle application errors.
:param code_or_exception: An exception class or the status code of HTTP exceptions to
handle.
:param function: Callable that will handle exception.
:param function: Callable that will handle exception, may be async.
"""
def test_client(self, **kwargs):

View File

@@ -14,11 +14,13 @@ from starlette.types import Receive, Scope, Send
from connexion.apps.abstract import AbstractApp
from connexion.decorators import StarletteDecorator
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser
logger = logging.getLogger(__name__)
@@ -88,7 +90,7 @@ class AsyncApi(RoutedAPI[AsyncOperation]):
)
class AsyncMiddlewareApp(RoutedMiddleware[AsyncApi]):
class AsyncASGIApp(RoutedMiddleware[AsyncApi]):
api_cls = AsyncApi
@@ -176,7 +178,7 @@ class AsyncApp(AbstractApp):
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self._middleware_app: AsyncMiddlewareApp = AsyncMiddlewareApp()
self._middleware_app: AsyncASGIApp = AsyncASGIApp()
super().__init__(
import_name,
@@ -205,6 +207,10 @@ class AsyncApp(AbstractApp):
)
def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
self.middleware.add_error_handler(code_or_exception, function)

View File

@@ -6,23 +6,22 @@ import pathlib
import typing as t
import flask
import werkzeug.exceptions
from a2wsgi import WSGIMiddleware
from flask import Response as FlaskResponse
from flask import signals
from starlette.types import Receive, Scope, Send
from connexion.apps.abstract import AbstractApp
from connexion.decorators import FlaskDecorator
from connexion.exceptions import InternalServerError, ProblemException, ResolverError
from connexion.exceptions import ResolverError
from connexion.frameworks import flask as flask_utils
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.options import SwaggerUIOptions
from connexion.problem import problem
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser
@@ -117,44 +116,20 @@ class FlaskApi(AbstractRoutingAPI):
return self.blueprint.add_url_rule(rule, endpoint, view_func, **options)
class FlaskMiddlewareApp(SpecMiddleware):
class FlaskASGIApp(SpecMiddleware):
def __init__(self, import_name, server_args: dict, **kwargs):
self.app = flask.Flask(import_name, **server_args)
self.app.json = flask_utils.FlaskJSONProvider(self.app)
self.app.url_map.converters["float"] = flask_utils.NumberConverter
self.app.url_map.converters["int"] = flask_utils.IntegerConverter
self.set_errors_handlers()
# Propagate Errors so we can handle them in the middleware
self.app.config["PROPAGATE_EXCEPTIONS"] = True
self.app.config["TRAP_BAD_REQUEST_ERRORS"] = True
self.app.config["TRAP_HTTP_EXCEPTIONS"] = True
self.asgi_app = WSGIMiddleware(self.app.wsgi_app)
def set_errors_handlers(self):
for error_code in werkzeug.exceptions.default_exceptions:
self.app.register_error_handler(error_code, self.common_error_handler)
self.app.register_error_handler(ProblemException, self.common_error_handler)
def common_error_handler(self, exception: Exception) -> FlaskResponse:
"""Default error handler."""
if isinstance(exception, ProblemException):
response = exception.to_problem()
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = InternalServerError()
response = problem(
title=exception.name,
detail=exception.description,
status=exception.code,
)
if response.status_code >= 500:
signals.got_request_exception.send(self.app, exception=exception)
return flask.make_response(
(response.body, response.status_code, response.headers)
)
def add_api(self, specification, *, name: str = None, **kwargs):
api = FlaskApi(specification, **kwargs)
@@ -177,7 +152,7 @@ class FlaskMiddlewareApp(SpecMiddleware):
class FlaskApp(AbstractApp):
"""Connexion Application based on ConnexionMiddleware wrapping a Flask application."""
_middleware_app: FlaskMiddlewareApp
_middleware_app: FlaskASGIApp
def __init__(
self,
@@ -237,7 +212,7 @@ class FlaskApp(AbstractApp):
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self._middleware_app = FlaskMiddlewareApp(import_name, server_args or {})
self._middleware_app = FlaskASGIApp(import_name, server_args or {})
self.app = self._middleware_app.app
super().__init__(
import_name,
@@ -266,6 +241,10 @@ class FlaskApp(AbstractApp):
)
def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
self.app.register_error_handler(code_or_exception, function)

View File

@@ -3,7 +3,7 @@ from contextvars import ContextVar
from starlette.types import Receive, Scope
from werkzeug.local import LocalProxy
from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.operations import AbstractOperation
UNBOUND_MESSAGE = (
@@ -25,5 +25,5 @@ _scope: ContextVar[Scope] = ContextVar("SCOPE")
scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE)
request = LocalProxy(
lambda: ASGIRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
lambda: ConnexionRequest(scope, receive), unbound_message=UNBOUND_MESSAGE
)

View File

@@ -16,7 +16,7 @@ import inflection
from connexion.context import context, operation
from connexion.frameworks.abstract import Framework
from connexion.http_facts import FORM_CONTENT_TYPES
from connexion.lifecycle import ASGIRequest, WSGIRequest
from connexion.lifecycle import ConnexionRequest, WSGIRequest
from connexion.operations import AbstractOperation, Swagger2Operation
from connexion.utils import (
deep_merge,
@@ -43,7 +43,7 @@ class BaseParameterDecorator:
def _maybe_get_body(
self,
request: t.Union[WSGIRequest, ASGIRequest],
request: t.Union[WSGIRequest, ConnexionRequest],
*,
arguments: t.List[str],
has_kwargs: bool,
@@ -95,7 +95,7 @@ class AsyncParameterDecorator(BaseParameterDecorator):
arguments, has_kwargs = inspect_function_arguments(unwrapped_function)
@functools.wraps(function)
async def wrapper(request: ASGIRequest) -> t.Any:
async def wrapper(request: ConnexionRequest) -> t.Any:
request_body = self._maybe_get_body(
request, arguments=arguments, has_kwargs=has_kwargs
)
@@ -118,7 +118,7 @@ class AsyncParameterDecorator(BaseParameterDecorator):
def prep_kwargs(
request: t.Union[WSGIRequest, ASGIRequest],
request: t.Union[WSGIRequest, ConnexionRequest],
*,
request_body: t.Any,
files: t.Dict[str, t.Any],

View File

@@ -8,7 +8,7 @@ from starlette.responses import Response as StarletteResponse
from starlette.types import Receive, Scope
from connexion.frameworks.abstract import Framework
from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.uri_parsing import AbstractURIParser
@@ -48,8 +48,8 @@ class Starlette(Framework):
)
@staticmethod
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ASGIRequest: # type: ignore
return ASGIRequest(scope, receive, uri_parser=uri_parser)
def get_request(*, scope: Scope, receive: Receive, uri_parser: AbstractURIParser, **kwargs) -> ConnexionRequest: # type: ignore
return ConnexionRequest(scope, receive, uri_parser=uri_parser)
PATH_PARAMETER = re.compile(r"\{([^}]*)\}")

View File

@@ -130,7 +130,7 @@ class WSGIRequest(_RequestInterface):
return getattr(self._werkzeug_request, item)
class ASGIRequest(_RequestInterface):
class ConnexionRequest(_RequestInterface):
"""
Implementation of the Connexion :code:`_RequestInterface` representing an ASGI request.
@@ -142,6 +142,8 @@ class ASGIRequest(_RequestInterface):
"""
def __init__(self, *args, uri_parser=None, **kwargs):
# Might be set in `from_starlette_request` class method
if not hasattr(self, "_starlette_request"):
self._starlette_request = StarletteRequest(*args, **kwargs)
self.uri_parser = uri_parser
@@ -152,6 +154,16 @@ class ASGIRequest(_RequestInterface):
self._form = None
self._files = None
@classmethod
def from_starlette_request(
cls, request: StarletteRequest, uri_parser=None
) -> "ConnexionRequest":
# Instantiate the class, and set the `_starlette_request` property before initializing.
self = cls.__new__(cls)
self._starlette_request = request
self.__init__(uri_parser=uri_parser) # type: ignore
return self
@property
def context(self):
if self._context is None:
@@ -226,6 +238,7 @@ class ASGIRequest(_RequestInterface):
return await self.body() or None
def __getattr__(self, item):
if self.__getattribute__("_starlette_request"):
return getattr(self._starlette_request, item)

View File

@@ -1,68 +1,114 @@
import asyncio
import logging
import typing as t
import werkzeug.exceptions
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.middleware.exceptions import (
ExceptionMiddleware as StarletteExceptionMiddleware,
)
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.responses import Response as StarletteResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.exceptions import InternalServerError, ProblemException, problem
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.types import MaybeAwaitable
logger = logging.getLogger(__name__)
def connexion_wrapper(
handler: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
]
) -> t.Callable[[StarletteRequest, Exception], t.Awaitable[StarletteResponse]]:
"""Wrapper that translates Starlette requests to Connexion requests before passing
them to the error handler, and translates the returned Connexion responses to
Starlette responses."""
async def wrapper(request: StarletteRequest, exc: Exception) -> StarletteResponse:
request = ConnexionRequest.from_starlette_request(request)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc) # type: ignore
else:
response = await run_in_threadpool(handler, request, exc)
return StarletteResponse(
content=response.body,
status_code=response.status_code,
media_type=response.mimetype,
headers=response.headers,
)
return wrapper
class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior."""
def __init__(self, next_app: ASGIApp):
super().__init__(next_app)
self.add_exception_handler(ProblemException, self.problem_handler)
self.add_exception_handler(ProblemException, self.problem_handler) # type: ignore
self.add_exception_handler(
werkzeug.exceptions.HTTPException, self.flask_error_handler
)
self.add_exception_handler(Exception, self.common_error_handler)
def add_exception_handler(
self,
exc_class_or_status_code: t.Union[int, t.Type[Exception]],
handler: t.Callable[[ConnexionRequest, Exception], StarletteResponse],
) -> None:
super().add_exception_handler(
exc_class_or_status_code, handler=connexion_wrapper(handler)
)
@staticmethod
def problem_handler(_request: StarletteRequest, exc: ProblemException):
def problem_handler(_request: ConnexionRequest, exc: ProblemException):
"""Default handler for Connexion ProblemExceptions"""
logger.error("%r", exc)
response = exc.to_problem()
return Response(
content=response.body,
status_code=response.status_code,
media_type=response.mimetype,
headers=response.headers,
)
return exc.to_problem()
@staticmethod
def http_exception(_request: StarletteRequest, exc: HTTPException) -> Response:
@connexion_wrapper
def http_exception(
_request: StarletteRequest, exc: HTTPException, **kwargs
) -> StarletteResponse:
"""Default handler for Starlette HTTPException"""
logger.error("%r", exc)
headers = exc.headers
connexion_response = problem(
title=exc.detail, detail=exc.detail, status=exc.status_code, headers=headers
)
return Response(
content=connexion_response.body,
status_code=connexion_response.status_code,
media_type=connexion_response.mimetype,
headers=connexion_response.headers,
return problem(
title=exc.detail,
detail=exc.detail,
status=exc.status_code,
headers=exc.headers,
)
@staticmethod
def common_error_handler(_request: StarletteRequest, exc: Exception) -> Response:
def common_error_handler(
_request: StarletteRequest, exc: Exception
) -> ConnexionResponse:
"""Default handler for any unhandled Exception"""
logger.error("%r", exc, exc_info=exc)
return InternalServerError().to_problem()
response = InternalServerError().to_problem()
def flask_error_handler(
self, request: StarletteRequest, exc: werkzeug.exceptions.HTTPException
) -> ConnexionResponse:
"""Default handler for Flask / werkzeug HTTPException"""
# If a handler is registered for the received status_code, call it instead.
# This is only done automatically for Starlette HTTPExceptions
if handler := self._status_handlers.get(exc.code):
starlette_exception = HTTPException(exc.code, detail=exc.description)
return handler(request, starlette_exception)
return Response(
content=response.body,
status_code=response.status_code,
media_type=response.mimetype,
headers=response.headers,
return problem(
title=exc.name,
detail=exc.description,
status=exc.code,
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

View File

@@ -12,6 +12,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from connexion import utils
from connexion.handlers import ResolverErrorHandler
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.middleware.abstract import SpecMiddleware
from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
@@ -23,6 +24,7 @@ from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from connexion.types import MaybeAwaitable
from connexion.uri_parsing import AbstractURIParser
from connexion.utils import inspect_function_arguments
@@ -419,14 +421,18 @@ class ConnexionMiddleware:
self.apis.append(api)
def add_error_handler(
self, code_or_exception: t.Union[int, t.Type[Exception]], function: t.Callable
self,
code_or_exception: t.Union[int, t.Type[Exception]],
function: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
],
) -> None:
"""
Register a callable to handle application errors.
:param code_or_exception: An exception class or the status code of HTTP exceptions to
handle.
:param function: Callable that will handle exception.
:param function: Callable that will handle exception, may be async.
"""
if self.middleware_stack is not None:
raise RuntimeError(

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.exceptions import ProblemException
from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.operations import AbstractOperation
from connexion.security import SecurityHandlerFactory
@@ -95,7 +95,7 @@ class SecurityOperation:
await self.next_app(scope, receive, send)
return
request = ASGIRequest(scope)
request = ConnexionRequest(scope)
await self.verification_fn(request)
await self.next_app(scope, receive, send)

View File

@@ -54,7 +54,7 @@ import httpx
from connexion.decorators.parameter import inspect_function_arguments
from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem
from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.utils import get_function_from_name
logger = logging.getLogger(__name__)
@@ -248,7 +248,7 @@ class ApiKeySecurityHandler(AbstractSecurityHandler):
def _get_verify_func(self, api_key_info_func, loc, name):
check_api_key_func = self.check_api_key(api_key_info_func)
def wrapper(request: ASGIRequest):
def wrapper(request: ConnexionRequest):
if loc == "query":
api_key = request.query_params.get(name)
elif loc == "header":

4
connexion/types.py Normal file
View File

@@ -0,0 +1,4 @@
import typing as t
ReturnType = t.TypeVar("ReturnType")
MaybeAwaitable = t.Union[t.Awaitable[ReturnType], ReturnType]

View File

@@ -24,12 +24,13 @@ See below for an explanation of the different variables.
request
-------
A ``Request`` object representing the incoming request. This is an instance of the ``ASGIRequest``.
A ``Request`` object representing the incoming request. This is an instance of the
``ConnexionRequest``.
.. dropdown:: View a detailed reference of the ``ASGIRequest`` class
.. dropdown:: View a detailed reference of the ``ConnexionRequest`` class
:icon: eye
.. autoclass:: connexion.lifecycle.ASGIRequest
.. autoclass:: connexion.lifecycle.ConnexionRequest
:noindex:
:members:
:undoc-members:

View File

@@ -488,7 +488,7 @@ request.
.. dropdown:: View a detailed reference of the ``connexion.request`` class
:icon: eye
.. autoclass:: connexion.lifecycle.ASGIRequest
.. autoclass:: connexion.lifecycle.ConnexionRequest
:members:
:undoc-members:
:inherited-members:

View File

@@ -59,7 +59,7 @@ operation:
Note that :code:`HEAD` requests will be handled by the :code:`operationId` specified under the
:code:`GET` operation in the specification. :code:`Connexion.request.method` can be used to
determine which request was made. See :class:`.ASGIRequest`.
determine which request was made. See :class:`.ConnexionRequest`.
Automatic routing
-----------------

View File

@@ -127,6 +127,8 @@ Smaller breaking changes
has been added to work with Flask's ``MethodView`` specifically.
* Built-in support for uWSGI has been removed. You can re-add this functionality using a custom middleware.
* The request body is now passed through for ``GET``, ``HEAD``, ``DELETE``, ``CONNECT`` and ``OPTIONS`` methods as well.
* Error handlers registered on the on the underlying Flask app directly will be ignored. You
should register them on the Connexion app directly.
Non-breaking changes

View File

@@ -10,7 +10,7 @@ from connexion.exceptions import (
OAuthResponseProblem,
OAuthScopeProblem,
)
from connexion.lifecycle import ASGIRequest
from connexion.lifecycle import ConnexionRequest
from connexion.security import (
NO_VALUE,
ApiKeySecurityHandler,
@@ -61,7 +61,7 @@ def test_verify_oauth_missing_auth_header():
somefunc, security_handler.validate_scope, ["admin"]
)
request = ASGIRequest(scope={"type": "http", "headers": []})
request = ConnexionRequest(scope={"type": "http", "headers": []})
assert wrapped_func(request) is NO_VALUE
@@ -83,7 +83,7 @@ async def test_verify_oauth_scopes_remote(monkeypatch):
token_info_func, security_handler.validate_scope, ["admin"]
)
request = ASGIRequest(
request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
)
@@ -124,7 +124,7 @@ async def test_verify_oauth_invalid_local_token_response_none():
somefunc, security_handler.validate_scope, ["admin"]
)
request = ASGIRequest(
request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
)
@@ -143,7 +143,7 @@ async def test_verify_oauth_scopes_local():
token_info, security_handler.validate_scope, ["admin"]
)
request = ASGIRequest(
request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
)
@@ -178,7 +178,7 @@ def test_verify_basic_missing_auth_header():
security_handler = BasicSecurityHandler()
wrapped_func = security_handler._get_verify_func(somefunc)
request = ASGIRequest(
request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
)
@@ -194,7 +194,7 @@ async def test_verify_basic():
security_handler = BasicSecurityHandler()
wrapped_func = security_handler._get_verify_func(basic_info)
request = ASGIRequest(
request = ConnexionRequest(
scope={"type": "http", "headers": [[b"authorization", b"Basic Zm9vOmJhcg=="]]}
)
@@ -212,7 +212,7 @@ async def test_verify_apikey_query():
apikey_info, "query", "auth"
)
request = ASGIRequest(scope={"type": "http", "query_string": b"auth=foobar"})
request = ConnexionRequest(scope={"type": "http", "query_string": b"auth=foobar"})
assert await wrapped_func(request) is not None
@@ -228,7 +228,9 @@ async def test_verify_apikey_header():
apikey_info, "header", "X-Auth"
)
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]})
request = ConnexionRequest(
scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]}
)
assert await wrapped_func(request) is not None
@@ -259,16 +261,20 @@ async def test_multiple_schemes():
wrapped_func = security_handler_factory.verify_multiple_schemes(schemes)
# Single key does not succeed
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]})
request = ConnexionRequest(
scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]}
)
assert await wrapped_func(request) is NO_VALUE
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]})
request = ConnexionRequest(
scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]}
)
assert await wrapped_func(request) is NO_VALUE
# Supplying both keys does succeed
request = ASGIRequest(
request = ConnexionRequest(
scope={
"type": "http",
"headers": [[b"x-auth-1", b"foobar"], [b"x-auth-2", b"bar"]],
@@ -287,7 +293,7 @@ async def test_verify_security_oauthproblem():
security_handler_factory = SecurityHandlerFactory()
security_func = security_handler_factory.verify_security([])
request = MagicMock(spec_set=ASGIRequest)
request = MagicMock(spec_set=ConnexionRequest)
with pytest.raises(OAuthProblem) as exc_info:
await security_func(request)