diff --git a/connexion/apis/__init__.py b/connexion/apis/__init__.py index e7d0778..da5a4bf 100644 --- a/connexion/apis/__init__.py +++ b/connexion/apis/__init__.py @@ -13,4 +13,5 @@ on the framework app. """ -from .abstract import AbstractAPI, AbstractSwaggerUIAPI # NOQA +from .abstract import (AbstractAPI, AbstractMinimalAPI, # NOQA + AbstractSwaggerUIAPI) diff --git a/connexion/apis/abstract.py b/connexion/apis/abstract.py index 1864563..acc3651 100644 --- a/connexion/apis/abstract.py +++ b/connexion/apis/abstract.py @@ -14,7 +14,7 @@ from ..exceptions import ResolverError from ..http_facts import METHODS from ..jsonifier import Jsonifier from ..lifecycle import ConnexionResponse -from ..operations import make_operation +from ..operations import AbstractOperation, make_operation from ..options import ConnexionOptions from ..resolver import Resolver from ..spec import Specification @@ -43,7 +43,14 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta): *args, **kwargs ): - """Base API class with only minimal behavior related to the specification.""" + """Base API class with only minimal behavior related to the specification. + + :param specification: OpenAPI specification. Can be provided either as dict, or as path + to file. + :param base_path: Base path to host the API. + :param arguments: Jinja arguments to resolve in specification. + :param options: New style options dictionary. + """ logger.debug('Loading specification: %s', specification, extra={'swagger_yaml': specification, 'base_path': base_path, @@ -109,7 +116,105 @@ class AbstractSwaggerUIAPI(AbstractSpecAPI): """ -class AbstractAPI(AbstractSpecAPI): +class AbstractMinimalAPI(AbstractSpecAPI): + + def __init__( + self, + *args, + resolver: t.Optional[Resolver] = None, + resolver_error_handler: t.Optional[t.Callable] = None, + debug: bool = False, + pass_context_arg_name: t.Optional[str] = None, + **kwargs + ) -> None: + """Minimal interface of an API, with only functionality related to routing. + + :param resolver: Callable that maps operationID to a function + :param resolver_error_handler: Callable that generates an Operation used for handling + ResolveErrors + :param debug: Flag to run in debug mode + :param pass_context_arg_name: If not None URL request handling functions with an argument + matching this name will be passed the framework's request context. + """ + super().__init__(*args, **kwargs) + self.debug = debug + self.resolver_error_handler = resolver_error_handler + + logger.debug('Security Definitions: %s', self.specification.security_definitions) + + self.resolver = resolver or Resolver() + + logger.debug('pass_context_arg_name: %s', pass_context_arg_name) + self.pass_context_arg_name = pass_context_arg_name + + self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name) + + self.add_paths() + + @staticmethod + @abc.abstractmethod + def make_security_handler_factory(pass_context_arg_name): + """ Create SecurityHandlerFactory to create all security check handlers """ + + def add_paths(self, paths: t.Optional[dict] = None) -> None: + """ + Adds the paths defined in the specification as endpoints + """ + paths = paths or self.specification.get('paths', dict()) + for path, methods in paths.items(): + logger.debug('Adding %s%s...', self.base_path, path) + + for method in methods: + if method not in METHODS: + continue + try: + self.add_operation(path, method) + except ResolverError as err: + # If we have an error handler for resolver errors, add it as an operation. + # Otherwise treat it as any other error. + if self.resolver_error_handler is not None: + self._add_resolver_error_handler(method, path, err) + else: + self._handle_add_operation_error(path, method, err.exc_info) + except Exception: + # All other relevant exceptions should be handled as well. + self._handle_add_operation_error(path, method, sys.exc_info()) + + def add_operation(self, path: str, method: str) -> None: + raise NotImplementedError + + @abc.abstractmethod + def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None: + """ + Adds the operation according to the user framework in use. + It will be used to register the operation on the user framework router. + """ + + def _add_resolver_error_handler(self, method: str, path: str, err: ResolverError): + """ + Adds a handler for ResolverError for the given method and path. + """ + operation = self.resolver_error_handler( + err, + security=self.specification.security, + security_definitions=self.specification.security_definitions + ) + self._add_operation_internal(method, path, operation) + + def _handle_add_operation_error(self, path: str, method: str, exc_info: tuple): + url = f'{self.base_path}{path}' + error_msg = 'Failed to add operation for {method} {url}'.format( + method=method.upper(), + url=url) + if self.debug: + logger.exception(error_msg) + else: + logger.error(error_msg) + _type, value, traceback = exc_info + raise value.with_traceback(traceback) + + +class AbstractAPI(AbstractMinimalAPI, metaclass=AbstractAPIMeta): """ Defines an abstract interface for a Swagger API """ @@ -120,55 +225,17 @@ class AbstractAPI(AbstractSpecAPI): validator_map=None, pythonic_params=False, pass_context_arg_name=None, options=None, ): """ - :type specification: pathlib.Path | dict - :type base_path: str | None - :type arguments: dict | None :type validate_responses: bool :type strict_validation: bool :type auth_all_paths: bool - :type debug: bool :param validator_map: Custom validators for the types "parameter", "body" and "response". :type validator_map: dict - :param resolver: Callable that maps operationID to a function - :param resolver_error_handler: If given, a callable that generates an - Operation used for handling ResolveErrors :type resolver_error_handler: callable | None :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to any shadowed built-ins :type pythonic_params: bool - :param options: New style options dictionary. - :type options: dict | None - :param pass_context_arg_name: If not None URL request handling functions with an argument matching this name - will be passed the framework's request context. - :type pass_context_arg_name: str | None """ - self.debug = debug self.validator_map = validator_map - self.resolver_error_handler = resolver_error_handler - - logger.debug('Loading specification: %s', specification, - extra={'swagger_yaml': specification, - 'base_path': base_path, - 'arguments': arguments, - 'auth_all_paths': auth_all_paths}) - - # Avoid validator having ability to modify specification - self.specification = Specification.load(specification, arguments=arguments) - - logger.debug('Read specification', extra={'spec': self.specification}) - - self.options = ConnexionOptions(options, oas_version=self.specification.version) - - logger.debug('Options Loaded', - extra={'swagger_ui': self.options.openapi_console_ui_available, - 'swagger_path': self.options.openapi_console_ui_from_dir, - 'swagger_url': self.options.openapi_console_ui_path}) - - self._set_base_path(base_path) - - logger.debug('Security Definitions: %s', self.specification.security_definitions) - - self.resolver = resolver or Resolver() logger.debug('Validate Responses: %s', str(validate_responses)) self.validate_responses = validate_responses @@ -179,14 +246,10 @@ class AbstractAPI(AbstractSpecAPI): logger.debug('Pythonic params: %s', str(pythonic_params)) self.pythonic_params = pythonic_params - logger.debug('pass_context_arg_name: %s', pass_context_arg_name) - self.pass_context_arg_name = pass_context_arg_name - - self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name) - - super().__init__(specification, base_path=base_path, arguments=arguments, options=options) - - self.add_paths() + super().__init__(specification, base_path=base_path, arguments=arguments, + resolver=resolver, auth_all_paths=auth_all_paths, + resolver_error_handler=resolver_error_handler, + debug=debug, pass_context_arg_name=pass_context_arg_name, options=options) if auth_all_paths: self.add_auth_on_not_found( @@ -200,11 +263,6 @@ class AbstractAPI(AbstractSpecAPI): Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass. """ - @staticmethod - @abc.abstractmethod - def make_security_handler_factory(pass_context_arg_name): - """ Create SecurityHandlerFactory to create all security check handlers """ - def add_operation(self, path, method): """ Adds one operation to the api. @@ -236,62 +294,6 @@ class AbstractAPI(AbstractSpecAPI): ) self._add_operation_internal(method, path, operation) - @abc.abstractmethod - def _add_operation_internal(self, method, path, operation): - """ - Adds the operation according to the user framework in use. - It will be used to register the operation on the user framework router. - """ - - def _add_resolver_error_handler(self, method, path, err): - """ - Adds a handler for ResolverError for the given method and path. - """ - operation = self.resolver_error_handler( - err, - security=self.specification.security, - security_definitions=self.specification.security_definitions - ) - self._add_operation_internal(method, path, operation) - - def add_paths(self, paths=None): - """ - Adds the paths defined in the specification as endpoints - - :type paths: list - """ - paths = paths or self.specification.get('paths', dict()) - for path, methods in paths.items(): - logger.debug('Adding %s%s...', self.base_path, path) - - for method in methods: - if method not in METHODS: - continue - try: - self.add_operation(path, method) - except ResolverError as err: - # If we have an error handler for resolver errors, add it as an operation. - # Otherwise treat it as any other error. - if self.resolver_error_handler is not None: - self._add_resolver_error_handler(method, path, err) - else: - self._handle_add_operation_error(path, method, err.exc_info) - except Exception: - # All other relevant exceptions should be handled as well. - self._handle_add_operation_error(path, method, sys.exc_info()) - - def _handle_add_operation_error(self, path, method, exc_info): - url = f'{self.base_path}{path}' - error_msg = 'Failed to add operation for {method} {url}'.format( - method=method.upper(), - url=url) - if self.debug: - logger.exception(error_msg) - else: - logger.error(error_msg) - _type, value, traceback = exc_info - raise value.with_traceback(traceback) - @classmethod @abc.abstractmethod def get_request(self, *args, **kwargs): diff --git a/connexion/apps/abstract.py b/connexion/apps/abstract.py index 1a454e4..d6fe6d5 100644 --- a/connexion/apps/abstract.py +++ b/connexion/apps/abstract.py @@ -7,6 +7,7 @@ import abc import logging import pathlib +from ..middleware import ConnexionMiddleware from ..options import ConnexionOptions from ..resolver import Resolver @@ -16,7 +17,7 @@ logger = logging.getLogger('connexion.app') class AbstractApp(metaclass=abc.ABCMeta): def __init__(self, import_name, api_cls, port=None, specification_dir='', host=None, server=None, server_args=None, arguments=None, auth_all_paths=False, debug=None, - resolver=None, options=None, skip_error_handlers=False): + resolver=None, options=None, skip_error_handlers=False, middlewares=None): """ :param import_name: the name of the application package :type import_name: str @@ -37,6 +38,8 @@ class AbstractApp(metaclass=abc.ABCMeta): :param debug: include debugging information :type debug: bool :param resolver: Callable that maps operationID to a function + :param middlewares: Callable that maps operationID to a function + :type middlewares: list | None """ self.port = port self.host = host @@ -54,8 +57,12 @@ class AbstractApp(metaclass=abc.ABCMeta): self.server = server self.server_args = dict() if server_args is None else server_args + self.app = self.create_app() - self.middleware = self._apply_middleware() + + if middlewares is None: + middlewares = ConnexionMiddleware.default_middlewares + self.middleware = self._apply_middleware(middlewares) # we get our application root path to avoid duplicating logic self.root_path = self.get_root_path() @@ -80,7 +87,7 @@ class AbstractApp(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def _apply_middleware(self): + def _apply_middleware(self, middlewares): """ Apply middleware to application """ diff --git a/connexion/apps/flask_app.py b/connexion/apps/flask_app.py index e886173..c8da4d2 100644 --- a/connexion/apps/flask_app.py +++ b/connexion/apps/flask_app.py @@ -23,6 +23,7 @@ logger = logging.getLogger('connexion.app') class FlaskApp(AbstractApp): + def __init__(self, import_name, server='flask', extra_files=None, **kwargs): """ :param extra_files: additional files to be watched by the reloader, defaults to the swagger specs of added apis @@ -41,8 +42,8 @@ class FlaskApp(AbstractApp): app.url_map.converters['int'] = IntegerConverter return app - def _apply_middleware(self): - middlewares = [*ConnexionMiddleware.default_middlewares, + def _apply_middleware(self, middlewares): + middlewares = [*middlewares, a2wsgi.WSGIMiddleware] middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares) diff --git a/connexion/exceptions.py b/connexion/exceptions.py index 0436b0c..e1311f3 100644 --- a/connexion/exceptions.py +++ b/connexion/exceptions.py @@ -96,6 +96,17 @@ class BadRequestProblem(ProblemException): super().__init__(status=400, title=title, detail=detail) +class NotFoundProblem(ProblemException): + + description = ( + 'The requested URL was not found on the server. If you entered the URL manually please ' + 'check your spelling and try again.' + ) + + def __init__(self, title="Not Found", detail=description): + super().__init__(status=404, title=title, detail=detail) + + class UnsupportedMediaTypeProblem(ProblemException): def __init__(self, title="Unsupported Media Type", detail=None): diff --git a/connexion/lifecycle.py b/connexion/lifecycle.py index bfe741a..231d9d6 100644 --- a/connexion/lifecycle.py +++ b/connexion/lifecycle.py @@ -2,6 +2,8 @@ This module defines interfaces for requests and responses used in Connexion for authentication, validation, serialization, etc. """ +from starlette.requests import Request as StarletteRequest +from starlette.responses import StreamingResponse as StarletteStreamingResponse class ConnexionRequest: @@ -52,3 +54,11 @@ class ConnexionResponse: self.body = body self.headers = headers or {} self.is_streamed = is_streamed + + +class MiddlewareRequest(StarletteRequest): + """Wraps starlette Request so it can easily be extended.""" + + +class MiddlewareResponse(StarletteStreamingResponse): + """Wraps starlette StreamingResponse so it can easily be extended.""" diff --git a/connexion/middleware/__init__.py b/connexion/middleware/__init__.py index 136930c..302bc67 100644 --- a/connexion/middleware/__init__.py +++ b/connexion/middleware/__init__.py @@ -1,2 +1,4 @@ +from .abstract import AppMiddleware # NOQA from .main import ConnexionMiddleware # NOQA +from .routing import RoutingMiddleware # NOQA from .swagger_ui import SwaggerUIMiddleware # NOQA diff --git a/connexion/middleware/base.py b/connexion/middleware/abstract.py similarity index 65% rename from connexion/middleware/base.py rename to connexion/middleware/abstract.py index 1e35c0e..4afbc24 100644 --- a/connexion/middleware/base.py +++ b/connexion/middleware/abstract.py @@ -4,6 +4,8 @@ import typing as t class AppMiddleware(abc.ABC): + """Middlewares that need the APIs to be registered on them should inherit from this base + class""" @abc.abstractmethod def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None: diff --git a/connexion/middleware/exceptions.py b/connexion/middleware/exceptions.py new file mode 100644 index 0000000..3e07578 --- /dev/null +++ b/connexion/middleware/exceptions.py @@ -0,0 +1,33 @@ +import json + +from starlette.exceptions import \ + ExceptionMiddleware as StarletteExceptionMiddleware +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import Response + +from connexion.exceptions import problem + + +class ExceptionMiddleware(StarletteExceptionMiddleware): + """Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to + existing connexion behavior.""" + + def http_exception(self, request: Request, exc: HTTPException) -> Response: + try: + headers = exc.headers + except AttributeError: + # Starlette < 0.19 + headers = {} + + connexion_response = problem(title=exc.detail, + detail=exc.detail, + status=exc.status_code, + headers=headers) + + return Response( + content=json.dumps(connexion_response.body), + status_code=connexion_response.status_code, + media_type=connexion_response.mimetype, + headers=connexion_response.headers + ) diff --git a/connexion/middleware/main.py b/connexion/middleware/main.py index 0e4a26b..863abf4 100644 --- a/connexion/middleware/main.py +++ b/connexion/middleware/main.py @@ -1,10 +1,11 @@ import pathlib import typing as t -from starlette.exceptions import ExceptionMiddleware from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.middleware.base import AppMiddleware +from connexion.middleware.abstract import AppMiddleware +from connexion.middleware.exceptions import ExceptionMiddleware +from connexion.middleware.routing import RoutingMiddleware from connexion.middleware.swagger_ui import SwaggerUIMiddleware @@ -13,6 +14,7 @@ class ConnexionMiddleware: default_middlewares = [ ExceptionMiddleware, SwaggerUIMiddleware, + RoutingMiddleware, ] def __init__( diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py new file mode 100644 index 0000000..08e528c --- /dev/null +++ b/connexion/middleware/routing.py @@ -0,0 +1,170 @@ +import pathlib +import typing as t +from contextlib import contextmanager +from contextvars import ContextVar + +from starlette.requests import Request as StarletteRequest +from starlette.routing import Router +from starlette.types import ASGIApp, Receive, Scope, Send + +from connexion.apis import AbstractMinimalAPI +from connexion.exceptions import NotFoundProblem +from connexion.middleware import AppMiddleware +from connexion.operations import AbstractOperation, make_operation +from connexion.resolver import Resolver + +CONNEXION_CONTEXT = 'connexion.context' + + +_scope_receive_send: ContextVar[tuple] = ContextVar('SCOPE_RECEIVE_SEND') + + +class MiddlewareResolver(Resolver): + + def __init__(self, call_next: t.Callable) -> None: + """Resolver that resolves each operation to the provided call_next function.""" + super().__init__() + self.call_next = call_next + + def resolve_function_from_operation_id(self, operation_id: str) -> t.Callable: + return self.call_next + + +class RoutingMiddleware(AppMiddleware): + + def __init__(self, app: ASGIApp) -> None: + """Middleware that resolves the Operation for an incoming request and attaches it to the + scope. + + :param app: app to wrap in middleware. + """ + self.app = app + # Pass unknown routes to next app + self.router = Router(default=self.default_fn) + + def add_api( + self, + specification: t.Union[pathlib.Path, str, dict], + base_path: t.Optional[str] = None, + arguments: t.Optional[dict] = None, + **kwargs + ) -> None: + """Add an API to the router based on a OpenAPI spec. + + :param specification: OpenAPI spec as dict or path to file. + :param base_path: Base path where to add this API. + :param arguments: Jinja arguments to replace in the spec. + """ + kwargs.pop("resolver", None) + resolver = MiddlewareResolver(self.create_call_next()) + api = MiddlewareAPI(specification, base_path=base_path, arguments=arguments, + resolver=resolver, default=self.default_fn, **kwargs) + self.router.mount(api.base_path, app=api.router) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Route request to matching operation, and attach it to the scope before calling the + next app.""" + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + _scope_receive_send.set((scope.copy(), receive, send)) + + # Needs to be set so starlette router throws exceptions instead of returning error responses + scope['app'] = self + try: + await self.router(scope, receive, send) + except ValueError: + raise NotFoundProblem + + async def default_fn(self, scope: Scope, receive: Receive, send: Send) -> None: + """Callback to call next app as default when no matching route is found.""" + original_scope, *_ = _scope_receive_send.get() + + api_base_path = scope.get('root_path', '')[len(original_scope.get('root_path', '')):] + + extensions = original_scope.setdefault('extensions', {}) + connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {}) + connexion_context.update({ + 'api_base_path': api_base_path + }) + await self.app(original_scope, receive, send) + + def create_call_next(self): + + async def call_next( + operation: AbstractOperation, + request: StarletteRequest = None + ) -> None: + """Attach operation to scope and pass it to the next app""" + scope, receive, send = _scope_receive_send.get() + + api_base_path = request.scope.get('root_path', '')[len(scope.get('root_path', '')):] + + extensions = scope.setdefault('extensions', {}) + connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {}) + connexion_context.update({ + 'api_base_path': api_base_path, + 'operation_id': operation.operation_id + }) + return await self.app(scope, receive, send) + + return call_next + + +class MiddlewareAPI(AbstractMinimalAPI): + + def __init__( + self, + specification: t.Union[pathlib.Path, str, dict], + base_path: t.Optional[str] = None, + arguments: t.Optional[dict] = None, + resolver: t.Optional[Resolver] = None, + default: ASGIApp = None, + resolver_error_handler: t.Optional[t.Callable] = None, + debug: bool = False, + **kwargs + ) -> None: + """API implementation on top of Starlette Router for Connexion middleware.""" + self.router = Router(default=default) + + super().__init__( + specification, + base_path=base_path, + arguments=arguments, + resolver=resolver, + resolver_error_handler=resolver_error_handler, + debug=debug + ) + + def add_operation(self, path: str, method: str) -> None: + operation = make_operation( + self.specification, + self, + path, + method, + self.resolver + ) + + @contextmanager + def patch_operation_function(): + """Patch the operation function so no decorators are set in the middleware. This + should be cleaned up by separating the APIs and Operations between the App and + middleware""" + original_operation_function = AbstractOperation.function + AbstractOperation.function = operation._resolution.function + try: + yield + finally: + AbstractOperation.function = original_operation_function + + with patch_operation_function(): + self._add_operation_internal(method, path, operation) + + def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None: + self.router.add_route(path, operation.function, methods=[method]) + + @staticmethod + def make_security_handler_factory(pass_context_arg_name): + """ Create default SecurityHandlerFactory to create all security check handlers """ + pass diff --git a/connexion/middleware/swagger_ui.py b/connexion/middleware/swagger_ui.py index efc3316..b7543de 100644 --- a/connexion/middleware/swagger_ui.py +++ b/connexion/middleware/swagger_ui.py @@ -13,10 +13,9 @@ from starlette.types import ASGIApp, Receive, Scope, Send from connexion.apis import AbstractSwaggerUIAPI from connexion.jsonifier import JSONEncoder, Jsonifier +from connexion.middleware import AppMiddleware from connexion.utils import yamldumper -from .base import AppMiddleware - logger = logging.getLogger('connexion.middleware.swagger_ui') diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py index 6798c10..6c43191 100644 --- a/tests/api/test_errors.py +++ b/tests/api/test_errors.py @@ -27,7 +27,6 @@ def test_errors(problem_app): error405 = json.loads(get_greeting.data.decode('utf-8', 'replace')) assert error405['type'] == 'about:blank' assert error405['title'] == 'Method Not Allowed' - assert error405['detail'] == 'The method is not allowed for the requested URL.' assert error405['status'] == 405 assert 'instance' not in error405 diff --git a/tests/conftest.py b/tests/conftest.py index 8144c76..b87672a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import json import logging import pathlib -import sys import pytest from connexion import App @@ -136,7 +135,7 @@ def json_datetime_dir(): return FIXTURES_FOLDER / 'datetime_support' -def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs): +def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', middlewares=None, **kwargs): debug = True if 'debug' in kwargs: debug = kwargs['debug'] @@ -145,6 +144,7 @@ def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs): cnx_app = App(__name__, port=5001, specification_dir=FIXTURES_FOLDER / api_spec_folder, + middlewares=middlewares, debug=debug) cnx_app.add_api(spec_file, **kwargs) @@ -254,4 +254,3 @@ def unordered_definition_app(request): def bad_operations_app(request): return build_app_from_fixture('bad_operations', request.param, resolver_error=501) - diff --git a/tests/fakeapi/auth.py b/tests/fakeapi/auth.py index 396ee51..2aa7bff 100644 --- a/tests/fakeapi/auth.py +++ b/tests/fakeapi/auth.py @@ -13,15 +13,3 @@ def fake_json_auth(token, required_scopes=None): return json.loads(token) except ValueError: return None - - -async def async_basic_auth(username, password, required_scopes=None, request=None): - return fake_basic_auth(username, password, required_scopes) - - -async def async_json_auth(token, required_scopes=None, request=None): - return fake_json_auth(token, required_scopes) - - -async def async_scope_validation(required_scopes, token_scopes, request): - return required_scopes == token_scopes diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..28a0a01 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,44 @@ +import pytest +from connexion.middleware import ConnexionMiddleware +from connexion.middleware.routing import CONNEXION_CONTEXT +from starlette.datastructures import MutableHeaders + +from conftest import SPECS, build_app_from_fixture + + +class TestMiddleware: + """Middleware to check if operation is accessible on scope.""" + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + operation_id = scope['extensions'][CONNEXION_CONTEXT]['operation_id'] + + async def patched_send(message): + if message["type"] != "http.response.start": + await send(message) + return + + message.setdefault("headers", []) + headers = MutableHeaders(scope=message) + headers["operation_id"] = operation_id + + await send(message) + + await self.app(scope, receive, patched_send) + + +@pytest.fixture(scope="session", params=SPECS) +def middleware_app(request): + middlewares = ConnexionMiddleware.default_middlewares + [TestMiddleware] + return build_app_from_fixture('simple', request.param, middlewares=middlewares) + + +def test_routing_middleware(middleware_app): + app_client = middleware_app.app.test_client() + + response = app_client.post("/v1.0/greeting/robbe") + + assert response.headers.get('operation_id') == 'fakeapi.hello.post_greeting', \ + response.status_code