Add routing middleware (#1497)

* Add routing middleware

Factor out starlette BaseHTTPMiddleware

Fix exceptions for starlette < 0.19

Fix docstring formatting

Rename middleware/base.py to abstract.py

Rework routing middleware

* Clean up abstract API docstrings

* Move connexion context into extensions

* Allow empty middleware list
This commit is contained in:
Robbe Sneyders
2022-04-19 22:55:20 +02:00
committed by GitHub
parent 7f2931037e
commit 84e33e5897
16 changed files with 406 additions and 136 deletions

View File

@@ -13,4 +13,5 @@ on the framework app.
"""
from .abstract import AbstractAPI, AbstractSwaggerUIAPI # NOQA
from .abstract import (AbstractAPI, AbstractMinimalAPI, # NOQA
AbstractSwaggerUIAPI)

View File

@@ -14,7 +14,7 @@ from ..exceptions import ResolverError
from ..http_facts import METHODS
from ..jsonifier import Jsonifier
from ..lifecycle import ConnexionResponse
from ..operations import make_operation
from ..operations import AbstractOperation, make_operation
from ..options import ConnexionOptions
from ..resolver import Resolver
from ..spec import Specification
@@ -43,7 +43,14 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
*args,
**kwargs
):
"""Base API class with only minimal behavior related to the specification."""
"""Base API class with only minimal behavior related to the specification.
:param specification: OpenAPI specification. Can be provided either as dict, or as path
to file.
:param base_path: Base path to host the API.
:param arguments: Jinja arguments to resolve in specification.
:param options: New style options dictionary.
"""
logger.debug('Loading specification: %s', specification,
extra={'swagger_yaml': specification,
'base_path': base_path,
@@ -109,7 +116,105 @@ class AbstractSwaggerUIAPI(AbstractSpecAPI):
"""
class AbstractAPI(AbstractSpecAPI):
class AbstractMinimalAPI(AbstractSpecAPI):
def __init__(
self,
*args,
resolver: t.Optional[Resolver] = None,
resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False,
pass_context_arg_name: t.Optional[str] = None,
**kwargs
) -> None:
"""Minimal interface of an API, with only functionality related to routing.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param debug: Flag to run in debug mode
:param pass_context_arg_name: If not None URL request handling functions with an argument
matching this name will be passed the framework's request context.
"""
super().__init__(*args, **kwargs)
self.debug = debug
self.resolver_error_handler = resolver_error_handler
logger.debug('Security Definitions: %s', self.specification.security_definitions)
self.resolver = resolver or Resolver()
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
self.pass_context_arg_name = pass_context_arg_name
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
self.add_paths()
@staticmethod
@abc.abstractmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create SecurityHandlerFactory to create all security check handlers """
def add_paths(self, paths: t.Optional[dict] = None) -> None:
"""
Adds the paths defined in the specification as endpoints
"""
paths = paths or self.specification.get('paths', dict())
for path, methods in paths.items():
logger.debug('Adding %s%s...', self.base_path, path)
for method in methods:
if method not in METHODS:
continue
try:
self.add_operation(path, method)
except ResolverError as err:
# If we have an error handler for resolver errors, add it as an operation.
# Otherwise treat it as any other error.
if self.resolver_error_handler is not None:
self._add_resolver_error_handler(method, path, err)
else:
self._handle_add_operation_error(path, method, err.exc_info)
except Exception:
# All other relevant exceptions should be handled as well.
self._handle_add_operation_error(path, method, sys.exc_info())
def add_operation(self, path: str, method: str) -> None:
raise NotImplementedError
@abc.abstractmethod
def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
"""
Adds the operation according to the user framework in use.
It will be used to register the operation on the user framework router.
"""
def _add_resolver_error_handler(self, method: str, path: str, err: ResolverError):
"""
Adds a handler for ResolverError for the given method and path.
"""
operation = self.resolver_error_handler(
err,
security=self.specification.security,
security_definitions=self.specification.security_definitions
)
self._add_operation_internal(method, path, operation)
def _handle_add_operation_error(self, path: str, method: str, exc_info: tuple):
url = f'{self.base_path}{path}'
error_msg = 'Failed to add operation for {method} {url}'.format(
method=method.upper(),
url=url)
if self.debug:
logger.exception(error_msg)
else:
logger.error(error_msg)
_type, value, traceback = exc_info
raise value.with_traceback(traceback)
class AbstractAPI(AbstractMinimalAPI, metaclass=AbstractAPIMeta):
"""
Defines an abstract interface for a Swagger API
"""
@@ -120,55 +225,17 @@ class AbstractAPI(AbstractSpecAPI):
validator_map=None, pythonic_params=False, pass_context_arg_name=None, options=None,
):
"""
:type specification: pathlib.Path | dict
:type base_path: str | None
:type arguments: dict | None
:type validate_responses: bool
:type strict_validation: bool
:type auth_all_paths: bool
:type debug: bool
:param validator_map: Custom validators for the types "parameter", "body" and "response".
:type validator_map: dict
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: If given, a callable that generates an
Operation used for handling ResolveErrors
:type resolver_error_handler: callable | None
:param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended
to any shadowed built-ins
:type pythonic_params: bool
:param options: New style options dictionary.
:type options: dict | None
:param pass_context_arg_name: If not None URL request handling functions with an argument matching this name
will be passed the framework's request context.
:type pass_context_arg_name: str | None
"""
self.debug = debug
self.validator_map = validator_map
self.resolver_error_handler = resolver_error_handler
logger.debug('Loading specification: %s', specification,
extra={'swagger_yaml': specification,
'base_path': base_path,
'arguments': arguments,
'auth_all_paths': auth_all_paths})
# Avoid validator having ability to modify specification
self.specification = Specification.load(specification, arguments=arguments)
logger.debug('Read specification', extra={'spec': self.specification})
self.options = ConnexionOptions(options, oas_version=self.specification.version)
logger.debug('Options Loaded',
extra={'swagger_ui': self.options.openapi_console_ui_available,
'swagger_path': self.options.openapi_console_ui_from_dir,
'swagger_url': self.options.openapi_console_ui_path})
self._set_base_path(base_path)
logger.debug('Security Definitions: %s', self.specification.security_definitions)
self.resolver = resolver or Resolver()
logger.debug('Validate Responses: %s', str(validate_responses))
self.validate_responses = validate_responses
@@ -179,14 +246,10 @@ class AbstractAPI(AbstractSpecAPI):
logger.debug('Pythonic params: %s', str(pythonic_params))
self.pythonic_params = pythonic_params
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
self.pass_context_arg_name = pass_context_arg_name
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
super().__init__(specification, base_path=base_path, arguments=arguments, options=options)
self.add_paths()
super().__init__(specification, base_path=base_path, arguments=arguments,
resolver=resolver, auth_all_paths=auth_all_paths,
resolver_error_handler=resolver_error_handler,
debug=debug, pass_context_arg_name=pass_context_arg_name, options=options)
if auth_all_paths:
self.add_auth_on_not_found(
@@ -200,11 +263,6 @@ class AbstractAPI(AbstractSpecAPI):
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
"""
@staticmethod
@abc.abstractmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create SecurityHandlerFactory to create all security check handlers """
def add_operation(self, path, method):
"""
Adds one operation to the api.
@@ -236,62 +294,6 @@ class AbstractAPI(AbstractSpecAPI):
)
self._add_operation_internal(method, path, operation)
@abc.abstractmethod
def _add_operation_internal(self, method, path, operation):
"""
Adds the operation according to the user framework in use.
It will be used to register the operation on the user framework router.
"""
def _add_resolver_error_handler(self, method, path, err):
"""
Adds a handler for ResolverError for the given method and path.
"""
operation = self.resolver_error_handler(
err,
security=self.specification.security,
security_definitions=self.specification.security_definitions
)
self._add_operation_internal(method, path, operation)
def add_paths(self, paths=None):
"""
Adds the paths defined in the specification as endpoints
:type paths: list
"""
paths = paths or self.specification.get('paths', dict())
for path, methods in paths.items():
logger.debug('Adding %s%s...', self.base_path, path)
for method in methods:
if method not in METHODS:
continue
try:
self.add_operation(path, method)
except ResolverError as err:
# If we have an error handler for resolver errors, add it as an operation.
# Otherwise treat it as any other error.
if self.resolver_error_handler is not None:
self._add_resolver_error_handler(method, path, err)
else:
self._handle_add_operation_error(path, method, err.exc_info)
except Exception:
# All other relevant exceptions should be handled as well.
self._handle_add_operation_error(path, method, sys.exc_info())
def _handle_add_operation_error(self, path, method, exc_info):
url = f'{self.base_path}{path}'
error_msg = 'Failed to add operation for {method} {url}'.format(
method=method.upper(),
url=url)
if self.debug:
logger.exception(error_msg)
else:
logger.error(error_msg)
_type, value, traceback = exc_info
raise value.with_traceback(traceback)
@classmethod
@abc.abstractmethod
def get_request(self, *args, **kwargs):

View File

@@ -7,6 +7,7 @@ import abc
import logging
import pathlib
from ..middleware import ConnexionMiddleware
from ..options import ConnexionOptions
from ..resolver import Resolver
@@ -16,7 +17,7 @@ logger = logging.getLogger('connexion.app')
class AbstractApp(metaclass=abc.ABCMeta):
def __init__(self, import_name, api_cls, port=None, specification_dir='',
host=None, server=None, server_args=None, arguments=None, auth_all_paths=False, debug=None,
resolver=None, options=None, skip_error_handlers=False):
resolver=None, options=None, skip_error_handlers=False, middlewares=None):
"""
:param import_name: the name of the application package
:type import_name: str
@@ -37,6 +38,8 @@ class AbstractApp(metaclass=abc.ABCMeta):
:param debug: include debugging information
:type debug: bool
:param resolver: Callable that maps operationID to a function
:param middlewares: Callable that maps operationID to a function
:type middlewares: list | None
"""
self.port = port
self.host = host
@@ -54,8 +57,12 @@ class AbstractApp(metaclass=abc.ABCMeta):
self.server = server
self.server_args = dict() if server_args is None else server_args
self.app = self.create_app()
self.middleware = self._apply_middleware()
if middlewares is None:
middlewares = ConnexionMiddleware.default_middlewares
self.middleware = self._apply_middleware(middlewares)
# we get our application root path to avoid duplicating logic
self.root_path = self.get_root_path()
@@ -80,7 +87,7 @@ class AbstractApp(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def _apply_middleware(self):
def _apply_middleware(self, middlewares):
"""
Apply middleware to application
"""

View File

@@ -23,6 +23,7 @@ logger = logging.getLogger('connexion.app')
class FlaskApp(AbstractApp):
def __init__(self, import_name, server='flask', extra_files=None, **kwargs):
"""
:param extra_files: additional files to be watched by the reloader, defaults to the swagger specs of added apis
@@ -41,8 +42,8 @@ class FlaskApp(AbstractApp):
app.url_map.converters['int'] = IntegerConverter
return app
def _apply_middleware(self):
middlewares = [*ConnexionMiddleware.default_middlewares,
def _apply_middleware(self, middlewares):
middlewares = [*middlewares,
a2wsgi.WSGIMiddleware]
middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares)

View File

@@ -96,6 +96,17 @@ class BadRequestProblem(ProblemException):
super().__init__(status=400, title=title, detail=detail)
class NotFoundProblem(ProblemException):
description = (
'The requested URL was not found on the server. If you entered the URL manually please '
'check your spelling and try again.'
)
def __init__(self, title="Not Found", detail=description):
super().__init__(status=404, title=title, detail=detail)
class UnsupportedMediaTypeProblem(ProblemException):
def __init__(self, title="Unsupported Media Type", detail=None):

View File

@@ -2,6 +2,8 @@
This module defines interfaces for requests and responses used in Connexion for authentication,
validation, serialization, etc.
"""
from starlette.requests import Request as StarletteRequest
from starlette.responses import StreamingResponse as StarletteStreamingResponse
class ConnexionRequest:
@@ -52,3 +54,11 @@ class ConnexionResponse:
self.body = body
self.headers = headers or {}
self.is_streamed = is_streamed
class MiddlewareRequest(StarletteRequest):
"""Wraps starlette Request so it can easily be extended."""
class MiddlewareResponse(StarletteStreamingResponse):
"""Wraps starlette StreamingResponse so it can easily be extended."""

View File

@@ -1,2 +1,4 @@
from .abstract import AppMiddleware # NOQA
from .main import ConnexionMiddleware # NOQA
from .routing import RoutingMiddleware # NOQA
from .swagger_ui import SwaggerUIMiddleware # NOQA

View File

@@ -4,6 +4,8 @@ import typing as t
class AppMiddleware(abc.ABC):
"""Middlewares that need the APIs to be registered on them should inherit from this base
class"""
@abc.abstractmethod
def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None:

View File

@@ -0,0 +1,33 @@
import json
from starlette.exceptions import \
ExceptionMiddleware as StarletteExceptionMiddleware
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from connexion.exceptions import problem
class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior."""
def http_exception(self, request: Request, exc: HTTPException) -> Response:
try:
headers = exc.headers
except AttributeError:
# Starlette < 0.19
headers = {}
connexion_response = problem(title=exc.detail,
detail=exc.detail,
status=exc.status_code,
headers=headers)
return Response(
content=json.dumps(connexion_response.body),
status_code=connexion_response.status_code,
media_type=connexion_response.mimetype,
headers=connexion_response.headers
)

View File

@@ -1,10 +1,11 @@
import pathlib
import typing as t
from starlette.exceptions import ExceptionMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.middleware.base import AppMiddleware
from connexion.middleware.abstract import AppMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
@@ -13,6 +14,7 @@ class ConnexionMiddleware:
default_middlewares = [
ExceptionMiddleware,
SwaggerUIMiddleware,
RoutingMiddleware,
]
def __init__(

View File

@@ -0,0 +1,170 @@
import pathlib
import typing as t
from contextlib import contextmanager
from contextvars import ContextVar
from starlette.requests import Request as StarletteRequest
from starlette.routing import Router
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis import AbstractMinimalAPI
from connexion.exceptions import NotFoundProblem
from connexion.middleware import AppMiddleware
from connexion.operations import AbstractOperation, make_operation
from connexion.resolver import Resolver
CONNEXION_CONTEXT = 'connexion.context'
_scope_receive_send: ContextVar[tuple] = ContextVar('SCOPE_RECEIVE_SEND')
class MiddlewareResolver(Resolver):
def __init__(self, call_next: t.Callable) -> None:
"""Resolver that resolves each operation to the provided call_next function."""
super().__init__()
self.call_next = call_next
def resolve_function_from_operation_id(self, operation_id: str) -> t.Callable:
return self.call_next
class RoutingMiddleware(AppMiddleware):
def __init__(self, app: ASGIApp) -> None:
"""Middleware that resolves the Operation for an incoming request and attaches it to the
scope.
:param app: app to wrap in middleware.
"""
self.app = app
# Pass unknown routes to next app
self.router = Router(default=self.default_fn)
def add_api(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
**kwargs
) -> None:
"""Add an API to the router based on a OpenAPI spec.
:param specification: OpenAPI spec as dict or path to file.
:param base_path: Base path where to add this API.
:param arguments: Jinja arguments to replace in the spec.
"""
kwargs.pop("resolver", None)
resolver = MiddlewareResolver(self.create_call_next())
api = MiddlewareAPI(specification, base_path=base_path, arguments=arguments,
resolver=resolver, default=self.default_fn, **kwargs)
self.router.mount(api.base_path, app=api.router)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Route request to matching operation, and attach it to the scope before calling the
next app."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
_scope_receive_send.set((scope.copy(), receive, send))
# Needs to be set so starlette router throws exceptions instead of returning error responses
scope['app'] = self
try:
await self.router(scope, receive, send)
except ValueError:
raise NotFoundProblem
async def default_fn(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Callback to call next app as default when no matching route is found."""
original_scope, *_ = _scope_receive_send.get()
api_base_path = scope.get('root_path', '')[len(original_scope.get('root_path', '')):]
extensions = original_scope.setdefault('extensions', {})
connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {})
connexion_context.update({
'api_base_path': api_base_path
})
await self.app(original_scope, receive, send)
def create_call_next(self):
async def call_next(
operation: AbstractOperation,
request: StarletteRequest = None
) -> None:
"""Attach operation to scope and pass it to the next app"""
scope, receive, send = _scope_receive_send.get()
api_base_path = request.scope.get('root_path', '')[len(scope.get('root_path', '')):]
extensions = scope.setdefault('extensions', {})
connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {})
connexion_context.update({
'api_base_path': api_base_path,
'operation_id': operation.operation_id
})
return await self.app(scope, receive, send)
return call_next
class MiddlewareAPI(AbstractMinimalAPI):
def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
resolver: t.Optional[Resolver] = None,
default: ASGIApp = None,
resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False,
**kwargs
) -> None:
"""API implementation on top of Starlette Router for Connexion middleware."""
self.router = Router(default=default)
super().__init__(
specification,
base_path=base_path,
arguments=arguments,
resolver=resolver,
resolver_error_handler=resolver_error_handler,
debug=debug
)
def add_operation(self, path: str, method: str) -> None:
operation = make_operation(
self.specification,
self,
path,
method,
self.resolver
)
@contextmanager
def patch_operation_function():
"""Patch the operation function so no decorators are set in the middleware. This
should be cleaned up by separating the APIs and Operations between the App and
middleware"""
original_operation_function = AbstractOperation.function
AbstractOperation.function = operation._resolution.function
try:
yield
finally:
AbstractOperation.function = original_operation_function
with patch_operation_function():
self._add_operation_internal(method, path, operation)
def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
self.router.add_route(path, operation.function, methods=[method])
@staticmethod
def make_security_handler_factory(pass_context_arg_name):
""" Create default SecurityHandlerFactory to create all security check handlers """
pass

View File

@@ -13,10 +13,9 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis import AbstractSwaggerUIAPI
from connexion.jsonifier import JSONEncoder, Jsonifier
from connexion.middleware import AppMiddleware
from connexion.utils import yamldumper
from .base import AppMiddleware
logger = logging.getLogger('connexion.middleware.swagger_ui')

View File

@@ -27,7 +27,6 @@ def test_errors(problem_app):
error405 = json.loads(get_greeting.data.decode('utf-8', 'replace'))
assert error405['type'] == 'about:blank'
assert error405['title'] == 'Method Not Allowed'
assert error405['detail'] == 'The method is not allowed for the requested URL.'
assert error405['status'] == 405
assert 'instance' not in error405

View File

@@ -1,7 +1,6 @@
import json
import logging
import pathlib
import sys
import pytest
from connexion import App
@@ -136,7 +135,7 @@ def json_datetime_dir():
return FIXTURES_FOLDER / 'datetime_support'
def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs):
def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', middlewares=None, **kwargs):
debug = True
if 'debug' in kwargs:
debug = kwargs['debug']
@@ -145,6 +144,7 @@ def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs):
cnx_app = App(__name__,
port=5001,
specification_dir=FIXTURES_FOLDER / api_spec_folder,
middlewares=middlewares,
debug=debug)
cnx_app.add_api(spec_file, **kwargs)
@@ -254,4 +254,3 @@ def unordered_definition_app(request):
def bad_operations_app(request):
return build_app_from_fixture('bad_operations', request.param,
resolver_error=501)

View File

@@ -13,15 +13,3 @@ def fake_json_auth(token, required_scopes=None):
return json.loads(token)
except ValueError:
return None
async def async_basic_auth(username, password, required_scopes=None, request=None):
return fake_basic_auth(username, password, required_scopes)
async def async_json_auth(token, required_scopes=None, request=None):
return fake_json_auth(token, required_scopes)
async def async_scope_validation(required_scopes, token_scopes, request):
return required_scopes == token_scopes

44
tests/test_middleware.py Normal file
View File

@@ -0,0 +1,44 @@
import pytest
from connexion.middleware import ConnexionMiddleware
from connexion.middleware.routing import CONNEXION_CONTEXT
from starlette.datastructures import MutableHeaders
from conftest import SPECS, build_app_from_fixture
class TestMiddleware:
"""Middleware to check if operation is accessible on scope."""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
operation_id = scope['extensions'][CONNEXION_CONTEXT]['operation_id']
async def patched_send(message):
if message["type"] != "http.response.start":
await send(message)
return
message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
headers["operation_id"] = operation_id
await send(message)
await self.app(scope, receive, patched_send)
@pytest.fixture(scope="session", params=SPECS)
def middleware_app(request):
middlewares = ConnexionMiddleware.default_middlewares + [TestMiddleware]
return build_app_from_fixture('simple', request.param, middlewares=middlewares)
def test_routing_middleware(middleware_app):
app_client = middleware_app.app.test_client()
response = app_client.post("/v1.0/greeting/robbe")
assert response.headers.get('operation_id') == 'fakeapi.hello.post_greeting', \
response.status_code