mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-09 20:37:46 +00:00
Add routing middleware (#1497)
* Add routing middleware Factor out starlette BaseHTTPMiddleware Fix exceptions for starlette < 0.19 Fix docstring formatting Rename middleware/base.py to abstract.py Rework routing middleware * Clean up abstract API docstrings * Move connexion context into extensions * Allow empty middleware list
This commit is contained in:
@@ -13,4 +13,5 @@ on the framework app.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from .abstract import AbstractAPI, AbstractSwaggerUIAPI # NOQA
|
from .abstract import (AbstractAPI, AbstractMinimalAPI, # NOQA
|
||||||
|
AbstractSwaggerUIAPI)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from ..exceptions import ResolverError
|
|||||||
from ..http_facts import METHODS
|
from ..http_facts import METHODS
|
||||||
from ..jsonifier import Jsonifier
|
from ..jsonifier import Jsonifier
|
||||||
from ..lifecycle import ConnexionResponse
|
from ..lifecycle import ConnexionResponse
|
||||||
from ..operations import make_operation
|
from ..operations import AbstractOperation, make_operation
|
||||||
from ..options import ConnexionOptions
|
from ..options import ConnexionOptions
|
||||||
from ..resolver import Resolver
|
from ..resolver import Resolver
|
||||||
from ..spec import Specification
|
from ..spec import Specification
|
||||||
@@ -43,7 +43,14 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
|
|||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""Base API class with only minimal behavior related to the specification."""
|
"""Base API class with only minimal behavior related to the specification.
|
||||||
|
|
||||||
|
:param specification: OpenAPI specification. Can be provided either as dict, or as path
|
||||||
|
to file.
|
||||||
|
:param base_path: Base path to host the API.
|
||||||
|
:param arguments: Jinja arguments to resolve in specification.
|
||||||
|
:param options: New style options dictionary.
|
||||||
|
"""
|
||||||
logger.debug('Loading specification: %s', specification,
|
logger.debug('Loading specification: %s', specification,
|
||||||
extra={'swagger_yaml': specification,
|
extra={'swagger_yaml': specification,
|
||||||
'base_path': base_path,
|
'base_path': base_path,
|
||||||
@@ -109,7 +116,105 @@ class AbstractSwaggerUIAPI(AbstractSpecAPI):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class AbstractAPI(AbstractSpecAPI):
|
class AbstractMinimalAPI(AbstractSpecAPI):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
resolver: t.Optional[Resolver] = None,
|
||||||
|
resolver_error_handler: t.Optional[t.Callable] = None,
|
||||||
|
debug: bool = False,
|
||||||
|
pass_context_arg_name: t.Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
"""Minimal interface of an API, with only functionality related to routing.
|
||||||
|
|
||||||
|
:param resolver: Callable that maps operationID to a function
|
||||||
|
:param resolver_error_handler: Callable that generates an Operation used for handling
|
||||||
|
ResolveErrors
|
||||||
|
:param debug: Flag to run in debug mode
|
||||||
|
:param pass_context_arg_name: If not None URL request handling functions with an argument
|
||||||
|
matching this name will be passed the framework's request context.
|
||||||
|
"""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.debug = debug
|
||||||
|
self.resolver_error_handler = resolver_error_handler
|
||||||
|
|
||||||
|
logger.debug('Security Definitions: %s', self.specification.security_definitions)
|
||||||
|
|
||||||
|
self.resolver = resolver or Resolver()
|
||||||
|
|
||||||
|
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
|
||||||
|
self.pass_context_arg_name = pass_context_arg_name
|
||||||
|
|
||||||
|
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
|
||||||
|
|
||||||
|
self.add_paths()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abc.abstractmethod
|
||||||
|
def make_security_handler_factory(pass_context_arg_name):
|
||||||
|
""" Create SecurityHandlerFactory to create all security check handlers """
|
||||||
|
|
||||||
|
def add_paths(self, paths: t.Optional[dict] = None) -> None:
|
||||||
|
"""
|
||||||
|
Adds the paths defined in the specification as endpoints
|
||||||
|
"""
|
||||||
|
paths = paths or self.specification.get('paths', dict())
|
||||||
|
for path, methods in paths.items():
|
||||||
|
logger.debug('Adding %s%s...', self.base_path, path)
|
||||||
|
|
||||||
|
for method in methods:
|
||||||
|
if method not in METHODS:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
self.add_operation(path, method)
|
||||||
|
except ResolverError as err:
|
||||||
|
# If we have an error handler for resolver errors, add it as an operation.
|
||||||
|
# Otherwise treat it as any other error.
|
||||||
|
if self.resolver_error_handler is not None:
|
||||||
|
self._add_resolver_error_handler(method, path, err)
|
||||||
|
else:
|
||||||
|
self._handle_add_operation_error(path, method, err.exc_info)
|
||||||
|
except Exception:
|
||||||
|
# All other relevant exceptions should be handled as well.
|
||||||
|
self._handle_add_operation_error(path, method, sys.exc_info())
|
||||||
|
|
||||||
|
def add_operation(self, path: str, method: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
|
||||||
|
"""
|
||||||
|
Adds the operation according to the user framework in use.
|
||||||
|
It will be used to register the operation on the user framework router.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _add_resolver_error_handler(self, method: str, path: str, err: ResolverError):
|
||||||
|
"""
|
||||||
|
Adds a handler for ResolverError for the given method and path.
|
||||||
|
"""
|
||||||
|
operation = self.resolver_error_handler(
|
||||||
|
err,
|
||||||
|
security=self.specification.security,
|
||||||
|
security_definitions=self.specification.security_definitions
|
||||||
|
)
|
||||||
|
self._add_operation_internal(method, path, operation)
|
||||||
|
|
||||||
|
def _handle_add_operation_error(self, path: str, method: str, exc_info: tuple):
|
||||||
|
url = f'{self.base_path}{path}'
|
||||||
|
error_msg = 'Failed to add operation for {method} {url}'.format(
|
||||||
|
method=method.upper(),
|
||||||
|
url=url)
|
||||||
|
if self.debug:
|
||||||
|
logger.exception(error_msg)
|
||||||
|
else:
|
||||||
|
logger.error(error_msg)
|
||||||
|
_type, value, traceback = exc_info
|
||||||
|
raise value.with_traceback(traceback)
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractAPI(AbstractMinimalAPI, metaclass=AbstractAPIMeta):
|
||||||
"""
|
"""
|
||||||
Defines an abstract interface for a Swagger API
|
Defines an abstract interface for a Swagger API
|
||||||
"""
|
"""
|
||||||
@@ -120,55 +225,17 @@ class AbstractAPI(AbstractSpecAPI):
|
|||||||
validator_map=None, pythonic_params=False, pass_context_arg_name=None, options=None,
|
validator_map=None, pythonic_params=False, pass_context_arg_name=None, options=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:type specification: pathlib.Path | dict
|
|
||||||
:type base_path: str | None
|
|
||||||
:type arguments: dict | None
|
|
||||||
:type validate_responses: bool
|
:type validate_responses: bool
|
||||||
:type strict_validation: bool
|
:type strict_validation: bool
|
||||||
:type auth_all_paths: bool
|
:type auth_all_paths: bool
|
||||||
:type debug: bool
|
|
||||||
:param validator_map: Custom validators for the types "parameter", "body" and "response".
|
:param validator_map: Custom validators for the types "parameter", "body" and "response".
|
||||||
:type validator_map: dict
|
:type validator_map: dict
|
||||||
:param resolver: Callable that maps operationID to a function
|
|
||||||
:param resolver_error_handler: If given, a callable that generates an
|
|
||||||
Operation used for handling ResolveErrors
|
|
||||||
:type resolver_error_handler: callable | None
|
:type resolver_error_handler: callable | None
|
||||||
:param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended
|
:param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended
|
||||||
to any shadowed built-ins
|
to any shadowed built-ins
|
||||||
:type pythonic_params: bool
|
:type pythonic_params: bool
|
||||||
:param options: New style options dictionary.
|
|
||||||
:type options: dict | None
|
|
||||||
:param pass_context_arg_name: If not None URL request handling functions with an argument matching this name
|
|
||||||
will be passed the framework's request context.
|
|
||||||
:type pass_context_arg_name: str | None
|
|
||||||
"""
|
"""
|
||||||
self.debug = debug
|
|
||||||
self.validator_map = validator_map
|
self.validator_map = validator_map
|
||||||
self.resolver_error_handler = resolver_error_handler
|
|
||||||
|
|
||||||
logger.debug('Loading specification: %s', specification,
|
|
||||||
extra={'swagger_yaml': specification,
|
|
||||||
'base_path': base_path,
|
|
||||||
'arguments': arguments,
|
|
||||||
'auth_all_paths': auth_all_paths})
|
|
||||||
|
|
||||||
# Avoid validator having ability to modify specification
|
|
||||||
self.specification = Specification.load(specification, arguments=arguments)
|
|
||||||
|
|
||||||
logger.debug('Read specification', extra={'spec': self.specification})
|
|
||||||
|
|
||||||
self.options = ConnexionOptions(options, oas_version=self.specification.version)
|
|
||||||
|
|
||||||
logger.debug('Options Loaded',
|
|
||||||
extra={'swagger_ui': self.options.openapi_console_ui_available,
|
|
||||||
'swagger_path': self.options.openapi_console_ui_from_dir,
|
|
||||||
'swagger_url': self.options.openapi_console_ui_path})
|
|
||||||
|
|
||||||
self._set_base_path(base_path)
|
|
||||||
|
|
||||||
logger.debug('Security Definitions: %s', self.specification.security_definitions)
|
|
||||||
|
|
||||||
self.resolver = resolver or Resolver()
|
|
||||||
|
|
||||||
logger.debug('Validate Responses: %s', str(validate_responses))
|
logger.debug('Validate Responses: %s', str(validate_responses))
|
||||||
self.validate_responses = validate_responses
|
self.validate_responses = validate_responses
|
||||||
@@ -179,14 +246,10 @@ class AbstractAPI(AbstractSpecAPI):
|
|||||||
logger.debug('Pythonic params: %s', str(pythonic_params))
|
logger.debug('Pythonic params: %s', str(pythonic_params))
|
||||||
self.pythonic_params = pythonic_params
|
self.pythonic_params = pythonic_params
|
||||||
|
|
||||||
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
|
super().__init__(specification, base_path=base_path, arguments=arguments,
|
||||||
self.pass_context_arg_name = pass_context_arg_name
|
resolver=resolver, auth_all_paths=auth_all_paths,
|
||||||
|
resolver_error_handler=resolver_error_handler,
|
||||||
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
|
debug=debug, pass_context_arg_name=pass_context_arg_name, options=options)
|
||||||
|
|
||||||
super().__init__(specification, base_path=base_path, arguments=arguments, options=options)
|
|
||||||
|
|
||||||
self.add_paths()
|
|
||||||
|
|
||||||
if auth_all_paths:
|
if auth_all_paths:
|
||||||
self.add_auth_on_not_found(
|
self.add_auth_on_not_found(
|
||||||
@@ -200,11 +263,6 @@ class AbstractAPI(AbstractSpecAPI):
|
|||||||
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
|
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abc.abstractmethod
|
|
||||||
def make_security_handler_factory(pass_context_arg_name):
|
|
||||||
""" Create SecurityHandlerFactory to create all security check handlers """
|
|
||||||
|
|
||||||
def add_operation(self, path, method):
|
def add_operation(self, path, method):
|
||||||
"""
|
"""
|
||||||
Adds one operation to the api.
|
Adds one operation to the api.
|
||||||
@@ -236,62 +294,6 @@ class AbstractAPI(AbstractSpecAPI):
|
|||||||
)
|
)
|
||||||
self._add_operation_internal(method, path, operation)
|
self._add_operation_internal(method, path, operation)
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _add_operation_internal(self, method, path, operation):
|
|
||||||
"""
|
|
||||||
Adds the operation according to the user framework in use.
|
|
||||||
It will be used to register the operation on the user framework router.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _add_resolver_error_handler(self, method, path, err):
|
|
||||||
"""
|
|
||||||
Adds a handler for ResolverError for the given method and path.
|
|
||||||
"""
|
|
||||||
operation = self.resolver_error_handler(
|
|
||||||
err,
|
|
||||||
security=self.specification.security,
|
|
||||||
security_definitions=self.specification.security_definitions
|
|
||||||
)
|
|
||||||
self._add_operation_internal(method, path, operation)
|
|
||||||
|
|
||||||
def add_paths(self, paths=None):
|
|
||||||
"""
|
|
||||||
Adds the paths defined in the specification as endpoints
|
|
||||||
|
|
||||||
:type paths: list
|
|
||||||
"""
|
|
||||||
paths = paths or self.specification.get('paths', dict())
|
|
||||||
for path, methods in paths.items():
|
|
||||||
logger.debug('Adding %s%s...', self.base_path, path)
|
|
||||||
|
|
||||||
for method in methods:
|
|
||||||
if method not in METHODS:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
self.add_operation(path, method)
|
|
||||||
except ResolverError as err:
|
|
||||||
# If we have an error handler for resolver errors, add it as an operation.
|
|
||||||
# Otherwise treat it as any other error.
|
|
||||||
if self.resolver_error_handler is not None:
|
|
||||||
self._add_resolver_error_handler(method, path, err)
|
|
||||||
else:
|
|
||||||
self._handle_add_operation_error(path, method, err.exc_info)
|
|
||||||
except Exception:
|
|
||||||
# All other relevant exceptions should be handled as well.
|
|
||||||
self._handle_add_operation_error(path, method, sys.exc_info())
|
|
||||||
|
|
||||||
def _handle_add_operation_error(self, path, method, exc_info):
|
|
||||||
url = f'{self.base_path}{path}'
|
|
||||||
error_msg = 'Failed to add operation for {method} {url}'.format(
|
|
||||||
method=method.upper(),
|
|
||||||
url=url)
|
|
||||||
if self.debug:
|
|
||||||
logger.exception(error_msg)
|
|
||||||
else:
|
|
||||||
logger.error(error_msg)
|
|
||||||
_type, value, traceback = exc_info
|
|
||||||
raise value.with_traceback(traceback)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_request(self, *args, **kwargs):
|
def get_request(self, *args, **kwargs):
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import abc
|
|||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
|
from ..middleware import ConnexionMiddleware
|
||||||
from ..options import ConnexionOptions
|
from ..options import ConnexionOptions
|
||||||
from ..resolver import Resolver
|
from ..resolver import Resolver
|
||||||
|
|
||||||
@@ -16,7 +17,7 @@ logger = logging.getLogger('connexion.app')
|
|||||||
class AbstractApp(metaclass=abc.ABCMeta):
|
class AbstractApp(metaclass=abc.ABCMeta):
|
||||||
def __init__(self, import_name, api_cls, port=None, specification_dir='',
|
def __init__(self, import_name, api_cls, port=None, specification_dir='',
|
||||||
host=None, server=None, server_args=None, arguments=None, auth_all_paths=False, debug=None,
|
host=None, server=None, server_args=None, arguments=None, auth_all_paths=False, debug=None,
|
||||||
resolver=None, options=None, skip_error_handlers=False):
|
resolver=None, options=None, skip_error_handlers=False, middlewares=None):
|
||||||
"""
|
"""
|
||||||
:param import_name: the name of the application package
|
:param import_name: the name of the application package
|
||||||
:type import_name: str
|
:type import_name: str
|
||||||
@@ -37,6 +38,8 @@ class AbstractApp(metaclass=abc.ABCMeta):
|
|||||||
:param debug: include debugging information
|
:param debug: include debugging information
|
||||||
:type debug: bool
|
:type debug: bool
|
||||||
:param resolver: Callable that maps operationID to a function
|
:param resolver: Callable that maps operationID to a function
|
||||||
|
:param middlewares: Callable that maps operationID to a function
|
||||||
|
:type middlewares: list | None
|
||||||
"""
|
"""
|
||||||
self.port = port
|
self.port = port
|
||||||
self.host = host
|
self.host = host
|
||||||
@@ -54,8 +57,12 @@ class AbstractApp(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
self.server = server
|
self.server = server
|
||||||
self.server_args = dict() if server_args is None else server_args
|
self.server_args = dict() if server_args is None else server_args
|
||||||
|
|
||||||
self.app = self.create_app()
|
self.app = self.create_app()
|
||||||
self.middleware = self._apply_middleware()
|
|
||||||
|
if middlewares is None:
|
||||||
|
middlewares = ConnexionMiddleware.default_middlewares
|
||||||
|
self.middleware = self._apply_middleware(middlewares)
|
||||||
|
|
||||||
# we get our application root path to avoid duplicating logic
|
# we get our application root path to avoid duplicating logic
|
||||||
self.root_path = self.get_root_path()
|
self.root_path = self.get_root_path()
|
||||||
@@ -80,7 +87,7 @@ class AbstractApp(metaclass=abc.ABCMeta):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _apply_middleware(self):
|
def _apply_middleware(self, middlewares):
|
||||||
"""
|
"""
|
||||||
Apply middleware to application
|
Apply middleware to application
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ logger = logging.getLogger('connexion.app')
|
|||||||
|
|
||||||
|
|
||||||
class FlaskApp(AbstractApp):
|
class FlaskApp(AbstractApp):
|
||||||
|
|
||||||
def __init__(self, import_name, server='flask', extra_files=None, **kwargs):
|
def __init__(self, import_name, server='flask', extra_files=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
:param extra_files: additional files to be watched by the reloader, defaults to the swagger specs of added apis
|
:param extra_files: additional files to be watched by the reloader, defaults to the swagger specs of added apis
|
||||||
@@ -41,8 +42,8 @@ class FlaskApp(AbstractApp):
|
|||||||
app.url_map.converters['int'] = IntegerConverter
|
app.url_map.converters['int'] = IntegerConverter
|
||||||
return app
|
return app
|
||||||
|
|
||||||
def _apply_middleware(self):
|
def _apply_middleware(self, middlewares):
|
||||||
middlewares = [*ConnexionMiddleware.default_middlewares,
|
middlewares = [*middlewares,
|
||||||
a2wsgi.WSGIMiddleware]
|
a2wsgi.WSGIMiddleware]
|
||||||
middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares)
|
middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares)
|
||||||
|
|
||||||
|
|||||||
@@ -96,6 +96,17 @@ class BadRequestProblem(ProblemException):
|
|||||||
super().__init__(status=400, title=title, detail=detail)
|
super().__init__(status=400, title=title, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundProblem(ProblemException):
|
||||||
|
|
||||||
|
description = (
|
||||||
|
'The requested URL was not found on the server. If you entered the URL manually please '
|
||||||
|
'check your spelling and try again.'
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, title="Not Found", detail=description):
|
||||||
|
super().__init__(status=404, title=title, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedMediaTypeProblem(ProblemException):
|
class UnsupportedMediaTypeProblem(ProblemException):
|
||||||
|
|
||||||
def __init__(self, title="Unsupported Media Type", detail=None):
|
def __init__(self, title="Unsupported Media Type", detail=None):
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
This module defines interfaces for requests and responses used in Connexion for authentication,
|
This module defines interfaces for requests and responses used in Connexion for authentication,
|
||||||
validation, serialization, etc.
|
validation, serialization, etc.
|
||||||
"""
|
"""
|
||||||
|
from starlette.requests import Request as StarletteRequest
|
||||||
|
from starlette.responses import StreamingResponse as StarletteStreamingResponse
|
||||||
|
|
||||||
|
|
||||||
class ConnexionRequest:
|
class ConnexionRequest:
|
||||||
@@ -52,3 +54,11 @@ class ConnexionResponse:
|
|||||||
self.body = body
|
self.body = body
|
||||||
self.headers = headers or {}
|
self.headers = headers or {}
|
||||||
self.is_streamed = is_streamed
|
self.is_streamed = is_streamed
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareRequest(StarletteRequest):
|
||||||
|
"""Wraps starlette Request so it can easily be extended."""
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareResponse(StarletteStreamingResponse):
|
||||||
|
"""Wraps starlette StreamingResponse so it can easily be extended."""
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
|
from .abstract import AppMiddleware # NOQA
|
||||||
from .main import ConnexionMiddleware # NOQA
|
from .main import ConnexionMiddleware # NOQA
|
||||||
|
from .routing import RoutingMiddleware # NOQA
|
||||||
from .swagger_ui import SwaggerUIMiddleware # NOQA
|
from .swagger_ui import SwaggerUIMiddleware # NOQA
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import typing as t
|
|||||||
|
|
||||||
|
|
||||||
class AppMiddleware(abc.ABC):
|
class AppMiddleware(abc.ABC):
|
||||||
|
"""Middlewares that need the APIs to be registered on them should inherit from this base
|
||||||
|
class"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None:
|
def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None:
|
||||||
33
connexion/middleware/exceptions.py
Normal file
33
connexion/middleware/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from starlette.exceptions import \
|
||||||
|
ExceptionMiddleware as StarletteExceptionMiddleware
|
||||||
|
from starlette.exceptions import HTTPException
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
from connexion.exceptions import problem
|
||||||
|
|
||||||
|
|
||||||
|
class ExceptionMiddleware(StarletteExceptionMiddleware):
|
||||||
|
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
|
||||||
|
existing connexion behavior."""
|
||||||
|
|
||||||
|
def http_exception(self, request: Request, exc: HTTPException) -> Response:
|
||||||
|
try:
|
||||||
|
headers = exc.headers
|
||||||
|
except AttributeError:
|
||||||
|
# Starlette < 0.19
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
connexion_response = problem(title=exc.detail,
|
||||||
|
detail=exc.detail,
|
||||||
|
status=exc.status_code,
|
||||||
|
headers=headers)
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=json.dumps(connexion_response.body),
|
||||||
|
status_code=connexion_response.status_code,
|
||||||
|
media_type=connexion_response.mimetype,
|
||||||
|
headers=connexion_response.headers
|
||||||
|
)
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
import pathlib
|
import pathlib
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from starlette.exceptions import ExceptionMiddleware
|
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
from connexion.middleware.base import AppMiddleware
|
from connexion.middleware.abstract import AppMiddleware
|
||||||
|
from connexion.middleware.exceptions import ExceptionMiddleware
|
||||||
|
from connexion.middleware.routing import RoutingMiddleware
|
||||||
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
||||||
|
|
||||||
|
|
||||||
@@ -13,6 +14,7 @@ class ConnexionMiddleware:
|
|||||||
default_middlewares = [
|
default_middlewares = [
|
||||||
ExceptionMiddleware,
|
ExceptionMiddleware,
|
||||||
SwaggerUIMiddleware,
|
SwaggerUIMiddleware,
|
||||||
|
RoutingMiddleware,
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
170
connexion/middleware/routing.py
Normal file
170
connexion/middleware/routing.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import pathlib
|
||||||
|
import typing as t
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from starlette.requests import Request as StarletteRequest
|
||||||
|
from starlette.routing import Router
|
||||||
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
from connexion.apis import AbstractMinimalAPI
|
||||||
|
from connexion.exceptions import NotFoundProblem
|
||||||
|
from connexion.middleware import AppMiddleware
|
||||||
|
from connexion.operations import AbstractOperation, make_operation
|
||||||
|
from connexion.resolver import Resolver
|
||||||
|
|
||||||
|
CONNEXION_CONTEXT = 'connexion.context'
|
||||||
|
|
||||||
|
|
||||||
|
_scope_receive_send: ContextVar[tuple] = ContextVar('SCOPE_RECEIVE_SEND')
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareResolver(Resolver):
|
||||||
|
|
||||||
|
def __init__(self, call_next: t.Callable) -> None:
|
||||||
|
"""Resolver that resolves each operation to the provided call_next function."""
|
||||||
|
super().__init__()
|
||||||
|
self.call_next = call_next
|
||||||
|
|
||||||
|
def resolve_function_from_operation_id(self, operation_id: str) -> t.Callable:
|
||||||
|
return self.call_next
|
||||||
|
|
||||||
|
|
||||||
|
class RoutingMiddleware(AppMiddleware):
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
"""Middleware that resolves the Operation for an incoming request and attaches it to the
|
||||||
|
scope.
|
||||||
|
|
||||||
|
:param app: app to wrap in middleware.
|
||||||
|
"""
|
||||||
|
self.app = app
|
||||||
|
# Pass unknown routes to next app
|
||||||
|
self.router = Router(default=self.default_fn)
|
||||||
|
|
||||||
|
def add_api(
|
||||||
|
self,
|
||||||
|
specification: t.Union[pathlib.Path, str, dict],
|
||||||
|
base_path: t.Optional[str] = None,
|
||||||
|
arguments: t.Optional[dict] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
"""Add an API to the router based on a OpenAPI spec.
|
||||||
|
|
||||||
|
:param specification: OpenAPI spec as dict or path to file.
|
||||||
|
:param base_path: Base path where to add this API.
|
||||||
|
:param arguments: Jinja arguments to replace in the spec.
|
||||||
|
"""
|
||||||
|
kwargs.pop("resolver", None)
|
||||||
|
resolver = MiddlewareResolver(self.create_call_next())
|
||||||
|
api = MiddlewareAPI(specification, base_path=base_path, arguments=arguments,
|
||||||
|
resolver=resolver, default=self.default_fn, **kwargs)
|
||||||
|
self.router.mount(api.base_path, app=api.router)
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
"""Route request to matching operation, and attach it to the scope before calling the
|
||||||
|
next app."""
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
_scope_receive_send.set((scope.copy(), receive, send))
|
||||||
|
|
||||||
|
# Needs to be set so starlette router throws exceptions instead of returning error responses
|
||||||
|
scope['app'] = self
|
||||||
|
try:
|
||||||
|
await self.router(scope, receive, send)
|
||||||
|
except ValueError:
|
||||||
|
raise NotFoundProblem
|
||||||
|
|
||||||
|
async def default_fn(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
"""Callback to call next app as default when no matching route is found."""
|
||||||
|
original_scope, *_ = _scope_receive_send.get()
|
||||||
|
|
||||||
|
api_base_path = scope.get('root_path', '')[len(original_scope.get('root_path', '')):]
|
||||||
|
|
||||||
|
extensions = original_scope.setdefault('extensions', {})
|
||||||
|
connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {})
|
||||||
|
connexion_context.update({
|
||||||
|
'api_base_path': api_base_path
|
||||||
|
})
|
||||||
|
await self.app(original_scope, receive, send)
|
||||||
|
|
||||||
|
def create_call_next(self):
|
||||||
|
|
||||||
|
async def call_next(
|
||||||
|
operation: AbstractOperation,
|
||||||
|
request: StarletteRequest = None
|
||||||
|
) -> None:
|
||||||
|
"""Attach operation to scope and pass it to the next app"""
|
||||||
|
scope, receive, send = _scope_receive_send.get()
|
||||||
|
|
||||||
|
api_base_path = request.scope.get('root_path', '')[len(scope.get('root_path', '')):]
|
||||||
|
|
||||||
|
extensions = scope.setdefault('extensions', {})
|
||||||
|
connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {})
|
||||||
|
connexion_context.update({
|
||||||
|
'api_base_path': api_base_path,
|
||||||
|
'operation_id': operation.operation_id
|
||||||
|
})
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
return call_next
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareAPI(AbstractMinimalAPI):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
specification: t.Union[pathlib.Path, str, dict],
|
||||||
|
base_path: t.Optional[str] = None,
|
||||||
|
arguments: t.Optional[dict] = None,
|
||||||
|
resolver: t.Optional[Resolver] = None,
|
||||||
|
default: ASGIApp = None,
|
||||||
|
resolver_error_handler: t.Optional[t.Callable] = None,
|
||||||
|
debug: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
"""API implementation on top of Starlette Router for Connexion middleware."""
|
||||||
|
self.router = Router(default=default)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
specification,
|
||||||
|
base_path=base_path,
|
||||||
|
arguments=arguments,
|
||||||
|
resolver=resolver,
|
||||||
|
resolver_error_handler=resolver_error_handler,
|
||||||
|
debug=debug
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_operation(self, path: str, method: str) -> None:
|
||||||
|
operation = make_operation(
|
||||||
|
self.specification,
|
||||||
|
self,
|
||||||
|
path,
|
||||||
|
method,
|
||||||
|
self.resolver
|
||||||
|
)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_operation_function():
|
||||||
|
"""Patch the operation function so no decorators are set in the middleware. This
|
||||||
|
should be cleaned up by separating the APIs and Operations between the App and
|
||||||
|
middleware"""
|
||||||
|
original_operation_function = AbstractOperation.function
|
||||||
|
AbstractOperation.function = operation._resolution.function
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
AbstractOperation.function = original_operation_function
|
||||||
|
|
||||||
|
with patch_operation_function():
|
||||||
|
self._add_operation_internal(method, path, operation)
|
||||||
|
|
||||||
|
def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
|
||||||
|
self.router.add_route(path, operation.function, methods=[method])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_security_handler_factory(pass_context_arg_name):
|
||||||
|
""" Create default SecurityHandlerFactory to create all security check handlers """
|
||||||
|
pass
|
||||||
@@ -13,10 +13,9 @@ from starlette.types import ASGIApp, Receive, Scope, Send
|
|||||||
|
|
||||||
from connexion.apis import AbstractSwaggerUIAPI
|
from connexion.apis import AbstractSwaggerUIAPI
|
||||||
from connexion.jsonifier import JSONEncoder, Jsonifier
|
from connexion.jsonifier import JSONEncoder, Jsonifier
|
||||||
|
from connexion.middleware import AppMiddleware
|
||||||
from connexion.utils import yamldumper
|
from connexion.utils import yamldumper
|
||||||
|
|
||||||
from .base import AppMiddleware
|
|
||||||
|
|
||||||
logger = logging.getLogger('connexion.middleware.swagger_ui')
|
logger = logging.getLogger('connexion.middleware.swagger_ui')
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ def test_errors(problem_app):
|
|||||||
error405 = json.loads(get_greeting.data.decode('utf-8', 'replace'))
|
error405 = json.loads(get_greeting.data.decode('utf-8', 'replace'))
|
||||||
assert error405['type'] == 'about:blank'
|
assert error405['type'] == 'about:blank'
|
||||||
assert error405['title'] == 'Method Not Allowed'
|
assert error405['title'] == 'Method Not Allowed'
|
||||||
assert error405['detail'] == 'The method is not allowed for the requested URL.'
|
|
||||||
assert error405['status'] == 405
|
assert error405['status'] == 405
|
||||||
assert 'instance' not in error405
|
assert 'instance' not in error405
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from connexion import App
|
from connexion import App
|
||||||
@@ -136,7 +135,7 @@ def json_datetime_dir():
|
|||||||
return FIXTURES_FOLDER / 'datetime_support'
|
return FIXTURES_FOLDER / 'datetime_support'
|
||||||
|
|
||||||
|
|
||||||
def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs):
|
def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', middlewares=None, **kwargs):
|
||||||
debug = True
|
debug = True
|
||||||
if 'debug' in kwargs:
|
if 'debug' in kwargs:
|
||||||
debug = kwargs['debug']
|
debug = kwargs['debug']
|
||||||
@@ -145,6 +144,7 @@ def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs):
|
|||||||
cnx_app = App(__name__,
|
cnx_app = App(__name__,
|
||||||
port=5001,
|
port=5001,
|
||||||
specification_dir=FIXTURES_FOLDER / api_spec_folder,
|
specification_dir=FIXTURES_FOLDER / api_spec_folder,
|
||||||
|
middlewares=middlewares,
|
||||||
debug=debug)
|
debug=debug)
|
||||||
|
|
||||||
cnx_app.add_api(spec_file, **kwargs)
|
cnx_app.add_api(spec_file, **kwargs)
|
||||||
@@ -254,4 +254,3 @@ def unordered_definition_app(request):
|
|||||||
def bad_operations_app(request):
|
def bad_operations_app(request):
|
||||||
return build_app_from_fixture('bad_operations', request.param,
|
return build_app_from_fixture('bad_operations', request.param,
|
||||||
resolver_error=501)
|
resolver_error=501)
|
||||||
|
|
||||||
|
|||||||
@@ -13,15 +13,3 @@ def fake_json_auth(token, required_scopes=None):
|
|||||||
return json.loads(token)
|
return json.loads(token)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def async_basic_auth(username, password, required_scopes=None, request=None):
|
|
||||||
return fake_basic_auth(username, password, required_scopes)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_json_auth(token, required_scopes=None, request=None):
|
|
||||||
return fake_json_auth(token, required_scopes)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_scope_validation(required_scopes, token_scopes, request):
|
|
||||||
return required_scopes == token_scopes
|
|
||||||
|
|||||||
44
tests/test_middleware.py
Normal file
44
tests/test_middleware.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import pytest
|
||||||
|
from connexion.middleware import ConnexionMiddleware
|
||||||
|
from connexion.middleware.routing import CONNEXION_CONTEXT
|
||||||
|
from starlette.datastructures import MutableHeaders
|
||||||
|
|
||||||
|
from conftest import SPECS, build_app_from_fixture
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddleware:
|
||||||
|
"""Middleware to check if operation is accessible on scope."""
|
||||||
|
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
operation_id = scope['extensions'][CONNEXION_CONTEXT]['operation_id']
|
||||||
|
|
||||||
|
async def patched_send(message):
|
||||||
|
if message["type"] != "http.response.start":
|
||||||
|
await send(message)
|
||||||
|
return
|
||||||
|
|
||||||
|
message.setdefault("headers", [])
|
||||||
|
headers = MutableHeaders(scope=message)
|
||||||
|
headers["operation_id"] = operation_id
|
||||||
|
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
await self.app(scope, receive, patched_send)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", params=SPECS)
|
||||||
|
def middleware_app(request):
|
||||||
|
middlewares = ConnexionMiddleware.default_middlewares + [TestMiddleware]
|
||||||
|
return build_app_from_fixture('simple', request.param, middlewares=middlewares)
|
||||||
|
|
||||||
|
|
||||||
|
def test_routing_middleware(middleware_app):
|
||||||
|
app_client = middleware_app.app.test_client()
|
||||||
|
|
||||||
|
response = app_client.post("/v1.0/greeting/robbe")
|
||||||
|
|
||||||
|
assert response.headers.get('operation_id') == 'fakeapi.hello.post_greeting', \
|
||||||
|
response.status_code
|
||||||
Reference in New Issue
Block a user