Add routing middleware (#1497)

* Add routing middleware

Factor out starlette BaseHTTPMiddleware

Fix exceptions for starlette < 0.19

Fix docstring formatting

Rename middleware/base.py to abstract.py

Rework routing middleware

* Clean up abstract API docstrings

* Move connexion context into extensions

* Allow empty middleware list
This commit is contained in:
Robbe Sneyders
2022-04-19 22:55:20 +02:00
committed by GitHub
parent 7f2931037e
commit 84e33e5897
16 changed files with 406 additions and 136 deletions

View File

@@ -13,4 +13,5 @@ on the framework app.
""" """
from .abstract import AbstractAPI, AbstractSwaggerUIAPI # NOQA from .abstract import (AbstractAPI, AbstractMinimalAPI, # NOQA
AbstractSwaggerUIAPI)

View File

@@ -14,7 +14,7 @@ from ..exceptions import ResolverError
from ..http_facts import METHODS from ..http_facts import METHODS
from ..jsonifier import Jsonifier from ..jsonifier import Jsonifier
from ..lifecycle import ConnexionResponse from ..lifecycle import ConnexionResponse
from ..operations import make_operation from ..operations import AbstractOperation, make_operation
from ..options import ConnexionOptions from ..options import ConnexionOptions
from ..resolver import Resolver from ..resolver import Resolver
from ..spec import Specification from ..spec import Specification
@@ -43,7 +43,14 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
*args, *args,
**kwargs **kwargs
): ):
"""Base API class with only minimal behavior related to the specification.""" """Base API class with only minimal behavior related to the specification.
:param specification: OpenAPI specification. Can be provided either as dict, or as path
to file.
:param base_path: Base path to host the API.
:param arguments: Jinja arguments to resolve in specification.
:param options: New style options dictionary.
"""
logger.debug('Loading specification: %s', specification, logger.debug('Loading specification: %s', specification,
extra={'swagger_yaml': specification, extra={'swagger_yaml': specification,
'base_path': base_path, 'base_path': base_path,
@@ -109,7 +116,105 @@ class AbstractSwaggerUIAPI(AbstractSpecAPI):
""" """
class AbstractAPI(AbstractSpecAPI): class AbstractMinimalAPI(AbstractSpecAPI):
def __init__(
self,
*args,
resolver: t.Optional[Resolver] = None,
resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False,
pass_context_arg_name: t.Optional[str] = None,
**kwargs
) -> None:
"""Minimal interface of an API, with only functionality related to routing.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param debug: Flag to run in debug mode
:param pass_context_arg_name: If not None URL request handling functions with an argument
matching this name will be passed the framework's request context.
"""
super().__init__(*args, **kwargs)
self.debug = debug
self.resolver_error_handler = resolver_error_handler
logger.debug('Security Definitions: %s', self.specification.security_definitions)
self.resolver = resolver or Resolver()
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
self.pass_context_arg_name = pass_context_arg_name
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
self.add_paths()
@staticmethod
@abc.abstractmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create SecurityHandlerFactory to create all security check handlers """
def add_paths(self, paths: t.Optional[dict] = None) -> None:
"""
Adds the paths defined in the specification as endpoints
"""
paths = paths or self.specification.get('paths', dict())
for path, methods in paths.items():
logger.debug('Adding %s%s...', self.base_path, path)
for method in methods:
if method not in METHODS:
continue
try:
self.add_operation(path, method)
except ResolverError as err:
# If we have an error handler for resolver errors, add it as an operation.
# Otherwise treat it as any other error.
if self.resolver_error_handler is not None:
self._add_resolver_error_handler(method, path, err)
else:
self._handle_add_operation_error(path, method, err.exc_info)
except Exception:
# All other relevant exceptions should be handled as well.
self._handle_add_operation_error(path, method, sys.exc_info())
def add_operation(self, path: str, method: str) -> None:
raise NotImplementedError
@abc.abstractmethod
def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
"""
Adds the operation according to the user framework in use.
It will be used to register the operation on the user framework router.
"""
def _add_resolver_error_handler(self, method: str, path: str, err: ResolverError):
"""
Adds a handler for ResolverError for the given method and path.
"""
operation = self.resolver_error_handler(
err,
security=self.specification.security,
security_definitions=self.specification.security_definitions
)
self._add_operation_internal(method, path, operation)
def _handle_add_operation_error(self, path: str, method: str, exc_info: tuple):
url = f'{self.base_path}{path}'
error_msg = 'Failed to add operation for {method} {url}'.format(
method=method.upper(),
url=url)
if self.debug:
logger.exception(error_msg)
else:
logger.error(error_msg)
_type, value, traceback = exc_info
raise value.with_traceback(traceback)
class AbstractAPI(AbstractMinimalAPI, metaclass=AbstractAPIMeta):
""" """
Defines an abstract interface for a Swagger API Defines an abstract interface for a Swagger API
""" """
@@ -120,55 +225,17 @@ class AbstractAPI(AbstractSpecAPI):
validator_map=None, pythonic_params=False, pass_context_arg_name=None, options=None, validator_map=None, pythonic_params=False, pass_context_arg_name=None, options=None,
): ):
""" """
:type specification: pathlib.Path | dict
:type base_path: str | None
:type arguments: dict | None
:type validate_responses: bool :type validate_responses: bool
:type strict_validation: bool :type strict_validation: bool
:type auth_all_paths: bool :type auth_all_paths: bool
:type debug: bool
:param validator_map: Custom validators for the types "parameter", "body" and "response". :param validator_map: Custom validators for the types "parameter", "body" and "response".
:type validator_map: dict :type validator_map: dict
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: If given, a callable that generates an
Operation used for handling ResolveErrors
:type resolver_error_handler: callable | None :type resolver_error_handler: callable | None
:param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended
to any shadowed built-ins to any shadowed built-ins
:type pythonic_params: bool :type pythonic_params: bool
:param options: New style options dictionary.
:type options: dict | None
:param pass_context_arg_name: If not None URL request handling functions with an argument matching this name
will be passed the framework's request context.
:type pass_context_arg_name: str | None
""" """
self.debug = debug
self.validator_map = validator_map self.validator_map = validator_map
self.resolver_error_handler = resolver_error_handler
logger.debug('Loading specification: %s', specification,
extra={'swagger_yaml': specification,
'base_path': base_path,
'arguments': arguments,
'auth_all_paths': auth_all_paths})
# Avoid validator having ability to modify specification
self.specification = Specification.load(specification, arguments=arguments)
logger.debug('Read specification', extra={'spec': self.specification})
self.options = ConnexionOptions(options, oas_version=self.specification.version)
logger.debug('Options Loaded',
extra={'swagger_ui': self.options.openapi_console_ui_available,
'swagger_path': self.options.openapi_console_ui_from_dir,
'swagger_url': self.options.openapi_console_ui_path})
self._set_base_path(base_path)
logger.debug('Security Definitions: %s', self.specification.security_definitions)
self.resolver = resolver or Resolver()
logger.debug('Validate Responses: %s', str(validate_responses)) logger.debug('Validate Responses: %s', str(validate_responses))
self.validate_responses = validate_responses self.validate_responses = validate_responses
@@ -179,14 +246,10 @@ class AbstractAPI(AbstractSpecAPI):
logger.debug('Pythonic params: %s', str(pythonic_params)) logger.debug('Pythonic params: %s', str(pythonic_params))
self.pythonic_params = pythonic_params self.pythonic_params = pythonic_params
logger.debug('pass_context_arg_name: %s', pass_context_arg_name) super().__init__(specification, base_path=base_path, arguments=arguments,
self.pass_context_arg_name = pass_context_arg_name resolver=resolver, auth_all_paths=auth_all_paths,
resolver_error_handler=resolver_error_handler,
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name) debug=debug, pass_context_arg_name=pass_context_arg_name, options=options)
super().__init__(specification, base_path=base_path, arguments=arguments, options=options)
self.add_paths()
if auth_all_paths: if auth_all_paths:
self.add_auth_on_not_found( self.add_auth_on_not_found(
@@ -200,11 +263,6 @@ class AbstractAPI(AbstractSpecAPI):
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass. Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
""" """
@staticmethod
@abc.abstractmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create SecurityHandlerFactory to create all security check handlers """
def add_operation(self, path, method): def add_operation(self, path, method):
""" """
Adds one operation to the api. Adds one operation to the api.
@@ -236,62 +294,6 @@ class AbstractAPI(AbstractSpecAPI):
) )
self._add_operation_internal(method, path, operation) self._add_operation_internal(method, path, operation)
@abc.abstractmethod
def _add_operation_internal(self, method, path, operation):
"""
Adds the operation according to the user framework in use.
It will be used to register the operation on the user framework router.
"""
def _add_resolver_error_handler(self, method, path, err):
"""
Adds a handler for ResolverError for the given method and path.
"""
operation = self.resolver_error_handler(
err,
security=self.specification.security,
security_definitions=self.specification.security_definitions
)
self._add_operation_internal(method, path, operation)
def add_paths(self, paths=None):
"""
Adds the paths defined in the specification as endpoints
:type paths: list
"""
paths = paths or self.specification.get('paths', dict())
for path, methods in paths.items():
logger.debug('Adding %s%s...', self.base_path, path)
for method in methods:
if method not in METHODS:
continue
try:
self.add_operation(path, method)
except ResolverError as err:
# If we have an error handler for resolver errors, add it as an operation.
# Otherwise treat it as any other error.
if self.resolver_error_handler is not None:
self._add_resolver_error_handler(method, path, err)
else:
self._handle_add_operation_error(path, method, err.exc_info)
except Exception:
# All other relevant exceptions should be handled as well.
self._handle_add_operation_error(path, method, sys.exc_info())
def _handle_add_operation_error(self, path, method, exc_info):
url = f'{self.base_path}{path}'
error_msg = 'Failed to add operation for {method} {url}'.format(
method=method.upper(),
url=url)
if self.debug:
logger.exception(error_msg)
else:
logger.error(error_msg)
_type, value, traceback = exc_info
raise value.with_traceback(traceback)
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def get_request(self, *args, **kwargs): def get_request(self, *args, **kwargs):

View File

@@ -7,6 +7,7 @@ import abc
import logging import logging
import pathlib import pathlib
from ..middleware import ConnexionMiddleware
from ..options import ConnexionOptions from ..options import ConnexionOptions
from ..resolver import Resolver from ..resolver import Resolver
@@ -16,7 +17,7 @@ logger = logging.getLogger('connexion.app')
class AbstractApp(metaclass=abc.ABCMeta): class AbstractApp(metaclass=abc.ABCMeta):
def __init__(self, import_name, api_cls, port=None, specification_dir='', def __init__(self, import_name, api_cls, port=None, specification_dir='',
host=None, server=None, server_args=None, arguments=None, auth_all_paths=False, debug=None, host=None, server=None, server_args=None, arguments=None, auth_all_paths=False, debug=None,
resolver=None, options=None, skip_error_handlers=False): resolver=None, options=None, skip_error_handlers=False, middlewares=None):
""" """
:param import_name: the name of the application package :param import_name: the name of the application package
:type import_name: str :type import_name: str
@@ -37,6 +38,8 @@ class AbstractApp(metaclass=abc.ABCMeta):
:param debug: include debugging information :param debug: include debugging information
:type debug: bool :type debug: bool
:param resolver: Callable that maps operationID to a function :param resolver: Callable that maps operationID to a function
:param middlewares: Callable that maps operationID to a function
:type middlewares: list | None
""" """
self.port = port self.port = port
self.host = host self.host = host
@@ -54,8 +57,12 @@ class AbstractApp(metaclass=abc.ABCMeta):
self.server = server self.server = server
self.server_args = dict() if server_args is None else server_args self.server_args = dict() if server_args is None else server_args
self.app = self.create_app() self.app = self.create_app()
self.middleware = self._apply_middleware()
if middlewares is None:
middlewares = ConnexionMiddleware.default_middlewares
self.middleware = self._apply_middleware(middlewares)
# we get our application root path to avoid duplicating logic # we get our application root path to avoid duplicating logic
self.root_path = self.get_root_path() self.root_path = self.get_root_path()
@@ -80,7 +87,7 @@ class AbstractApp(metaclass=abc.ABCMeta):
""" """
@abc.abstractmethod @abc.abstractmethod
def _apply_middleware(self): def _apply_middleware(self, middlewares):
""" """
Apply middleware to application Apply middleware to application
""" """

View File

@@ -23,6 +23,7 @@ logger = logging.getLogger('connexion.app')
class FlaskApp(AbstractApp): class FlaskApp(AbstractApp):
def __init__(self, import_name, server='flask', extra_files=None, **kwargs): def __init__(self, import_name, server='flask', extra_files=None, **kwargs):
""" """
:param extra_files: additional files to be watched by the reloader, defaults to the swagger specs of added apis :param extra_files: additional files to be watched by the reloader, defaults to the swagger specs of added apis
@@ -41,8 +42,8 @@ class FlaskApp(AbstractApp):
app.url_map.converters['int'] = IntegerConverter app.url_map.converters['int'] = IntegerConverter
return app return app
def _apply_middleware(self): def _apply_middleware(self, middlewares):
middlewares = [*ConnexionMiddleware.default_middlewares, middlewares = [*middlewares,
a2wsgi.WSGIMiddleware] a2wsgi.WSGIMiddleware]
middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares) middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares)

View File

@@ -96,6 +96,17 @@ class BadRequestProblem(ProblemException):
super().__init__(status=400, title=title, detail=detail) super().__init__(status=400, title=title, detail=detail)
class NotFoundProblem(ProblemException):
description = (
'The requested URL was not found on the server. If you entered the URL manually please '
'check your spelling and try again.'
)
def __init__(self, title="Not Found", detail=description):
super().__init__(status=404, title=title, detail=detail)
class UnsupportedMediaTypeProblem(ProblemException): class UnsupportedMediaTypeProblem(ProblemException):
def __init__(self, title="Unsupported Media Type", detail=None): def __init__(self, title="Unsupported Media Type", detail=None):

View File

@@ -2,6 +2,8 @@
This module defines interfaces for requests and responses used in Connexion for authentication, This module defines interfaces for requests and responses used in Connexion for authentication,
validation, serialization, etc. validation, serialization, etc.
""" """
from starlette.requests import Request as StarletteRequest
from starlette.responses import StreamingResponse as StarletteStreamingResponse
class ConnexionRequest: class ConnexionRequest:
@@ -52,3 +54,11 @@ class ConnexionResponse:
self.body = body self.body = body
self.headers = headers or {} self.headers = headers or {}
self.is_streamed = is_streamed self.is_streamed = is_streamed
class MiddlewareRequest(StarletteRequest):
"""Wraps starlette Request so it can easily be extended."""
class MiddlewareResponse(StarletteStreamingResponse):
"""Wraps starlette StreamingResponse so it can easily be extended."""

View File

@@ -1,2 +1,4 @@
from .abstract import AppMiddleware # NOQA
from .main import ConnexionMiddleware # NOQA from .main import ConnexionMiddleware # NOQA
from .routing import RoutingMiddleware # NOQA
from .swagger_ui import SwaggerUIMiddleware # NOQA from .swagger_ui import SwaggerUIMiddleware # NOQA

View File

@@ -4,6 +4,8 @@ import typing as t
class AppMiddleware(abc.ABC): class AppMiddleware(abc.ABC):
"""Middlewares that need the APIs to be registered on them should inherit from this base
class"""
@abc.abstractmethod @abc.abstractmethod
def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None: def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None:

View File

@@ -0,0 +1,33 @@
import json
from starlette.exceptions import \
ExceptionMiddleware as StarletteExceptionMiddleware
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from connexion.exceptions import problem
class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior."""
def http_exception(self, request: Request, exc: HTTPException) -> Response:
try:
headers = exc.headers
except AttributeError:
# Starlette < 0.19
headers = {}
connexion_response = problem(title=exc.detail,
detail=exc.detail,
status=exc.status_code,
headers=headers)
return Response(
content=json.dumps(connexion_response.body),
status_code=connexion_response.status_code,
media_type=connexion_response.mimetype,
headers=connexion_response.headers
)

View File

@@ -1,10 +1,11 @@
import pathlib import pathlib
import typing as t import typing as t
from starlette.exceptions import ExceptionMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.middleware.base import AppMiddleware from connexion.middleware.abstract import AppMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware from connexion.middleware.swagger_ui import SwaggerUIMiddleware
@@ -13,6 +14,7 @@ class ConnexionMiddleware:
default_middlewares = [ default_middlewares = [
ExceptionMiddleware, ExceptionMiddleware,
SwaggerUIMiddleware, SwaggerUIMiddleware,
RoutingMiddleware,
] ]
def __init__( def __init__(

View File

@@ -0,0 +1,170 @@
import pathlib
import typing as t
from contextlib import contextmanager
from contextvars import ContextVar
from starlette.requests import Request as StarletteRequest
from starlette.routing import Router
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis import AbstractMinimalAPI
from connexion.exceptions import NotFoundProblem
from connexion.middleware import AppMiddleware
from connexion.operations import AbstractOperation, make_operation
from connexion.resolver import Resolver
CONNEXION_CONTEXT = 'connexion.context'
_scope_receive_send: ContextVar[tuple] = ContextVar('SCOPE_RECEIVE_SEND')
class MiddlewareResolver(Resolver):
def __init__(self, call_next: t.Callable) -> None:
"""Resolver that resolves each operation to the provided call_next function."""
super().__init__()
self.call_next = call_next
def resolve_function_from_operation_id(self, operation_id: str) -> t.Callable:
return self.call_next
class RoutingMiddleware(AppMiddleware):
def __init__(self, app: ASGIApp) -> None:
"""Middleware that resolves the Operation for an incoming request and attaches it to the
scope.
:param app: app to wrap in middleware.
"""
self.app = app
# Pass unknown routes to next app
self.router = Router(default=self.default_fn)
def add_api(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
**kwargs
) -> None:
"""Add an API to the router based on a OpenAPI spec.
:param specification: OpenAPI spec as dict or path to file.
:param base_path: Base path where to add this API.
:param arguments: Jinja arguments to replace in the spec.
"""
kwargs.pop("resolver", None)
resolver = MiddlewareResolver(self.create_call_next())
api = MiddlewareAPI(specification, base_path=base_path, arguments=arguments,
resolver=resolver, default=self.default_fn, **kwargs)
self.router.mount(api.base_path, app=api.router)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Route request to matching operation, and attach it to the scope before calling the
next app."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
_scope_receive_send.set((scope.copy(), receive, send))
# Needs to be set so starlette router throws exceptions instead of returning error responses
scope['app'] = self
try:
await self.router(scope, receive, send)
except ValueError:
raise NotFoundProblem
async def default_fn(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Callback to call next app as default when no matching route is found."""
original_scope, *_ = _scope_receive_send.get()
api_base_path = scope.get('root_path', '')[len(original_scope.get('root_path', '')):]
extensions = original_scope.setdefault('extensions', {})
connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {})
connexion_context.update({
'api_base_path': api_base_path
})
await self.app(original_scope, receive, send)
def create_call_next(self):
async def call_next(
operation: AbstractOperation,
request: StarletteRequest = None
) -> None:
"""Attach operation to scope and pass it to the next app"""
scope, receive, send = _scope_receive_send.get()
api_base_path = request.scope.get('root_path', '')[len(scope.get('root_path', '')):]
extensions = scope.setdefault('extensions', {})
connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {})
connexion_context.update({
'api_base_path': api_base_path,
'operation_id': operation.operation_id
})
return await self.app(scope, receive, send)
return call_next
class MiddlewareAPI(AbstractMinimalAPI):
def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
resolver: t.Optional[Resolver] = None,
default: ASGIApp = None,
resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False,
**kwargs
) -> None:
"""API implementation on top of Starlette Router for Connexion middleware."""
self.router = Router(default=default)
super().__init__(
specification,
base_path=base_path,
arguments=arguments,
resolver=resolver,
resolver_error_handler=resolver_error_handler,
debug=debug
)
def add_operation(self, path: str, method: str) -> None:
operation = make_operation(
self.specification,
self,
path,
method,
self.resolver
)
@contextmanager
def patch_operation_function():
"""Patch the operation function so no decorators are set in the middleware. This
should be cleaned up by separating the APIs and Operations between the App and
middleware"""
original_operation_function = AbstractOperation.function
AbstractOperation.function = operation._resolution.function
try:
yield
finally:
AbstractOperation.function = original_operation_function
with patch_operation_function():
self._add_operation_internal(method, path, operation)
def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
self.router.add_route(path, operation.function, methods=[method])
@staticmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create default SecurityHandlerFactory to create all security check handlers """
pass

View File

@@ -13,10 +13,9 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis import AbstractSwaggerUIAPI from connexion.apis import AbstractSwaggerUIAPI
from connexion.jsonifier import JSONEncoder, Jsonifier from connexion.jsonifier import JSONEncoder, Jsonifier
from connexion.middleware import AppMiddleware
from connexion.utils import yamldumper from connexion.utils import yamldumper
from .base import AppMiddleware
logger = logging.getLogger('connexion.middleware.swagger_ui') logger = logging.getLogger('connexion.middleware.swagger_ui')

View File

@@ -27,7 +27,6 @@ def test_errors(problem_app):
error405 = json.loads(get_greeting.data.decode('utf-8', 'replace')) error405 = json.loads(get_greeting.data.decode('utf-8', 'replace'))
assert error405['type'] == 'about:blank' assert error405['type'] == 'about:blank'
assert error405['title'] == 'Method Not Allowed' assert error405['title'] == 'Method Not Allowed'
assert error405['detail'] == 'The method is not allowed for the requested URL.'
assert error405['status'] == 405 assert error405['status'] == 405
assert 'instance' not in error405 assert 'instance' not in error405

View File

@@ -1,7 +1,6 @@
import json import json
import logging import logging
import pathlib import pathlib
import sys
import pytest import pytest
from connexion import App from connexion import App
@@ -136,7 +135,7 @@ def json_datetime_dir():
return FIXTURES_FOLDER / 'datetime_support' return FIXTURES_FOLDER / 'datetime_support'
def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs): def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', middlewares=None, **kwargs):
debug = True debug = True
if 'debug' in kwargs: if 'debug' in kwargs:
debug = kwargs['debug'] debug = kwargs['debug']
@@ -145,6 +144,7 @@ def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs):
cnx_app = App(__name__, cnx_app = App(__name__,
port=5001, port=5001,
specification_dir=FIXTURES_FOLDER / api_spec_folder, specification_dir=FIXTURES_FOLDER / api_spec_folder,
middlewares=middlewares,
debug=debug) debug=debug)
cnx_app.add_api(spec_file, **kwargs) cnx_app.add_api(spec_file, **kwargs)
@@ -254,4 +254,3 @@ def unordered_definition_app(request):
def bad_operations_app(request): def bad_operations_app(request):
return build_app_from_fixture('bad_operations', request.param, return build_app_from_fixture('bad_operations', request.param,
resolver_error=501) resolver_error=501)

View File

@@ -13,15 +13,3 @@ def fake_json_auth(token, required_scopes=None):
return json.loads(token) return json.loads(token)
except ValueError: except ValueError:
return None return None
async def async_basic_auth(username, password, required_scopes=None, request=None):
return fake_basic_auth(username, password, required_scopes)
async def async_json_auth(token, required_scopes=None, request=None):
return fake_json_auth(token, required_scopes)
async def async_scope_validation(required_scopes, token_scopes, request):
return required_scopes == token_scopes

44
tests/test_middleware.py Normal file
View File

@@ -0,0 +1,44 @@
import pytest
from connexion.middleware import ConnexionMiddleware
from connexion.middleware.routing import CONNEXION_CONTEXT
from starlette.datastructures import MutableHeaders
from conftest import SPECS, build_app_from_fixture
class TestMiddleware:
"""Middleware to check if operation is accessible on scope."""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
operation_id = scope['extensions'][CONNEXION_CONTEXT]['operation_id']
async def patched_send(message):
if message["type"] != "http.response.start":
await send(message)
return
message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
headers["operation_id"] = operation_id
await send(message)
await self.app(scope, receive, patched_send)
@pytest.fixture(scope="session", params=SPECS)
def middleware_app(request):
middlewares = ConnexionMiddleware.default_middlewares + [TestMiddleware]
return build_app_from_fixture('simple', request.param, middlewares=middlewares)
def test_routing_middleware(middleware_app):
app_client = middleware_app.app.test_client()
response = app_client.post("/v1.0/greeting/robbe")
assert response.headers.get('operation_id') == 'fakeapi.hello.post_greeting', \
response.status_code