mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 04:19:26 +00:00
Add routing middleware (#1497)
* Add routing middleware Factor out starlette BaseHTTPMiddleware Fix exceptions for starlette < 0.19 Fix docstring formatting Rename middleware/base.py to abstract.py Rework routing middleware * Clean up abstract API docstrings * Move connexion context into extensions * Allow empty middleware list
This commit is contained in:
@@ -13,4 +13,5 @@ on the framework app.
|
||||
"""
|
||||
|
||||
|
||||
from .abstract import AbstractAPI, AbstractSwaggerUIAPI # NOQA
|
||||
from .abstract import (AbstractAPI, AbstractMinimalAPI, # NOQA
|
||||
AbstractSwaggerUIAPI)
|
||||
|
||||
@@ -14,7 +14,7 @@ from ..exceptions import ResolverError
|
||||
from ..http_facts import METHODS
|
||||
from ..jsonifier import Jsonifier
|
||||
from ..lifecycle import ConnexionResponse
|
||||
from ..operations import make_operation
|
||||
from ..operations import AbstractOperation, make_operation
|
||||
from ..options import ConnexionOptions
|
||||
from ..resolver import Resolver
|
||||
from ..spec import Specification
|
||||
@@ -43,7 +43,14 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
"""Base API class with only minimal behavior related to the specification."""
|
||||
"""Base API class with only minimal behavior related to the specification.
|
||||
|
||||
:param specification: OpenAPI specification. Can be provided either as dict, or as path
|
||||
to file.
|
||||
:param base_path: Base path to host the API.
|
||||
:param arguments: Jinja arguments to resolve in specification.
|
||||
:param options: New style options dictionary.
|
||||
"""
|
||||
logger.debug('Loading specification: %s', specification,
|
||||
extra={'swagger_yaml': specification,
|
||||
'base_path': base_path,
|
||||
@@ -109,7 +116,105 @@ class AbstractSwaggerUIAPI(AbstractSpecAPI):
|
||||
"""
|
||||
|
||||
|
||||
class AbstractAPI(AbstractSpecAPI):
|
||||
class AbstractMinimalAPI(AbstractSpecAPI):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
resolver: t.Optional[Resolver] = None,
|
||||
resolver_error_handler: t.Optional[t.Callable] = None,
|
||||
debug: bool = False,
|
||||
pass_context_arg_name: t.Optional[str] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""Minimal interface of an API, with only functionality related to routing.
|
||||
|
||||
:param resolver: Callable that maps operationID to a function
|
||||
:param resolver_error_handler: Callable that generates an Operation used for handling
|
||||
ResolveErrors
|
||||
:param debug: Flag to run in debug mode
|
||||
:param pass_context_arg_name: If not None URL request handling functions with an argument
|
||||
matching this name will be passed the framework's request context.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.debug = debug
|
||||
self.resolver_error_handler = resolver_error_handler
|
||||
|
||||
logger.debug('Security Definitions: %s', self.specification.security_definitions)
|
||||
|
||||
self.resolver = resolver or Resolver()
|
||||
|
||||
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
|
||||
self.pass_context_arg_name = pass_context_arg_name
|
||||
|
||||
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
|
||||
|
||||
self.add_paths()
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def make_security_handler_factory(pass_context_arg_name):
|
||||
""" Create SecurityHandlerFactory to create all security check handlers """
|
||||
|
||||
def add_paths(self, paths: t.Optional[dict] = None) -> None:
|
||||
"""
|
||||
Adds the paths defined in the specification as endpoints
|
||||
"""
|
||||
paths = paths or self.specification.get('paths', dict())
|
||||
for path, methods in paths.items():
|
||||
logger.debug('Adding %s%s...', self.base_path, path)
|
||||
|
||||
for method in methods:
|
||||
if method not in METHODS:
|
||||
continue
|
||||
try:
|
||||
self.add_operation(path, method)
|
||||
except ResolverError as err:
|
||||
# If we have an error handler for resolver errors, add it as an operation.
|
||||
# Otherwise treat it as any other error.
|
||||
if self.resolver_error_handler is not None:
|
||||
self._add_resolver_error_handler(method, path, err)
|
||||
else:
|
||||
self._handle_add_operation_error(path, method, err.exc_info)
|
||||
except Exception:
|
||||
# All other relevant exceptions should be handled as well.
|
||||
self._handle_add_operation_error(path, method, sys.exc_info())
|
||||
|
||||
def add_operation(self, path: str, method: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
|
||||
"""
|
||||
Adds the operation according to the user framework in use.
|
||||
It will be used to register the operation on the user framework router.
|
||||
"""
|
||||
|
||||
def _add_resolver_error_handler(self, method: str, path: str, err: ResolverError):
|
||||
"""
|
||||
Adds a handler for ResolverError for the given method and path.
|
||||
"""
|
||||
operation = self.resolver_error_handler(
|
||||
err,
|
||||
security=self.specification.security,
|
||||
security_definitions=self.specification.security_definitions
|
||||
)
|
||||
self._add_operation_internal(method, path, operation)
|
||||
|
||||
def _handle_add_operation_error(self, path: str, method: str, exc_info: tuple):
|
||||
url = f'{self.base_path}{path}'
|
||||
error_msg = 'Failed to add operation for {method} {url}'.format(
|
||||
method=method.upper(),
|
||||
url=url)
|
||||
if self.debug:
|
||||
logger.exception(error_msg)
|
||||
else:
|
||||
logger.error(error_msg)
|
||||
_type, value, traceback = exc_info
|
||||
raise value.with_traceback(traceback)
|
||||
|
||||
|
||||
class AbstractAPI(AbstractMinimalAPI, metaclass=AbstractAPIMeta):
|
||||
"""
|
||||
Defines an abstract interface for a Swagger API
|
||||
"""
|
||||
@@ -120,55 +225,17 @@ class AbstractAPI(AbstractSpecAPI):
|
||||
validator_map=None, pythonic_params=False, pass_context_arg_name=None, options=None,
|
||||
):
|
||||
"""
|
||||
:type specification: pathlib.Path | dict
|
||||
:type base_path: str | None
|
||||
:type arguments: dict | None
|
||||
:type validate_responses: bool
|
||||
:type strict_validation: bool
|
||||
:type auth_all_paths: bool
|
||||
:type debug: bool
|
||||
:param validator_map: Custom validators for the types "parameter", "body" and "response".
|
||||
:type validator_map: dict
|
||||
:param resolver: Callable that maps operationID to a function
|
||||
:param resolver_error_handler: If given, a callable that generates an
|
||||
Operation used for handling ResolveErrors
|
||||
:type resolver_error_handler: callable | None
|
||||
:param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended
|
||||
to any shadowed built-ins
|
||||
:type pythonic_params: bool
|
||||
:param options: New style options dictionary.
|
||||
:type options: dict | None
|
||||
:param pass_context_arg_name: If not None URL request handling functions with an argument matching this name
|
||||
will be passed the framework's request context.
|
||||
:type pass_context_arg_name: str | None
|
||||
"""
|
||||
self.debug = debug
|
||||
self.validator_map = validator_map
|
||||
self.resolver_error_handler = resolver_error_handler
|
||||
|
||||
logger.debug('Loading specification: %s', specification,
|
||||
extra={'swagger_yaml': specification,
|
||||
'base_path': base_path,
|
||||
'arguments': arguments,
|
||||
'auth_all_paths': auth_all_paths})
|
||||
|
||||
# Avoid validator having ability to modify specification
|
||||
self.specification = Specification.load(specification, arguments=arguments)
|
||||
|
||||
logger.debug('Read specification', extra={'spec': self.specification})
|
||||
|
||||
self.options = ConnexionOptions(options, oas_version=self.specification.version)
|
||||
|
||||
logger.debug('Options Loaded',
|
||||
extra={'swagger_ui': self.options.openapi_console_ui_available,
|
||||
'swagger_path': self.options.openapi_console_ui_from_dir,
|
||||
'swagger_url': self.options.openapi_console_ui_path})
|
||||
|
||||
self._set_base_path(base_path)
|
||||
|
||||
logger.debug('Security Definitions: %s', self.specification.security_definitions)
|
||||
|
||||
self.resolver = resolver or Resolver()
|
||||
|
||||
logger.debug('Validate Responses: %s', str(validate_responses))
|
||||
self.validate_responses = validate_responses
|
||||
@@ -179,14 +246,10 @@ class AbstractAPI(AbstractSpecAPI):
|
||||
logger.debug('Pythonic params: %s', str(pythonic_params))
|
||||
self.pythonic_params = pythonic_params
|
||||
|
||||
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
|
||||
self.pass_context_arg_name = pass_context_arg_name
|
||||
|
||||
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
|
||||
|
||||
super().__init__(specification, base_path=base_path, arguments=arguments, options=options)
|
||||
|
||||
self.add_paths()
|
||||
super().__init__(specification, base_path=base_path, arguments=arguments,
|
||||
resolver=resolver, auth_all_paths=auth_all_paths,
|
||||
resolver_error_handler=resolver_error_handler,
|
||||
debug=debug, pass_context_arg_name=pass_context_arg_name, options=options)
|
||||
|
||||
if auth_all_paths:
|
||||
self.add_auth_on_not_found(
|
||||
@@ -200,11 +263,6 @@ class AbstractAPI(AbstractSpecAPI):
|
||||
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def make_security_handler_factory(pass_context_arg_name):
|
||||
""" Create SecurityHandlerFactory to create all security check handlers """
|
||||
|
||||
def add_operation(self, path, method):
|
||||
"""
|
||||
Adds one operation to the api.
|
||||
@@ -236,62 +294,6 @@ class AbstractAPI(AbstractSpecAPI):
|
||||
)
|
||||
self._add_operation_internal(method, path, operation)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _add_operation_internal(self, method, path, operation):
|
||||
"""
|
||||
Adds the operation according to the user framework in use.
|
||||
It will be used to register the operation on the user framework router.
|
||||
"""
|
||||
|
||||
def _add_resolver_error_handler(self, method, path, err):
|
||||
"""
|
||||
Adds a handler for ResolverError for the given method and path.
|
||||
"""
|
||||
operation = self.resolver_error_handler(
|
||||
err,
|
||||
security=self.specification.security,
|
||||
security_definitions=self.specification.security_definitions
|
||||
)
|
||||
self._add_operation_internal(method, path, operation)
|
||||
|
||||
def add_paths(self, paths=None):
|
||||
"""
|
||||
Adds the paths defined in the specification as endpoints
|
||||
|
||||
:type paths: list
|
||||
"""
|
||||
paths = paths or self.specification.get('paths', dict())
|
||||
for path, methods in paths.items():
|
||||
logger.debug('Adding %s%s...', self.base_path, path)
|
||||
|
||||
for method in methods:
|
||||
if method not in METHODS:
|
||||
continue
|
||||
try:
|
||||
self.add_operation(path, method)
|
||||
except ResolverError as err:
|
||||
# If we have an error handler for resolver errors, add it as an operation.
|
||||
# Otherwise treat it as any other error.
|
||||
if self.resolver_error_handler is not None:
|
||||
self._add_resolver_error_handler(method, path, err)
|
||||
else:
|
||||
self._handle_add_operation_error(path, method, err.exc_info)
|
||||
except Exception:
|
||||
# All other relevant exceptions should be handled as well.
|
||||
self._handle_add_operation_error(path, method, sys.exc_info())
|
||||
|
||||
def _handle_add_operation_error(self, path, method, exc_info):
|
||||
url = f'{self.base_path}{path}'
|
||||
error_msg = 'Failed to add operation for {method} {url}'.format(
|
||||
method=method.upper(),
|
||||
url=url)
|
||||
if self.debug:
|
||||
logger.exception(error_msg)
|
||||
else:
|
||||
logger.error(error_msg)
|
||||
_type, value, traceback = exc_info
|
||||
raise value.with_traceback(traceback)
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_request(self, *args, **kwargs):
|
||||
|
||||
@@ -7,6 +7,7 @@ import abc
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
from ..middleware import ConnexionMiddleware
|
||||
from ..options import ConnexionOptions
|
||||
from ..resolver import Resolver
|
||||
|
||||
@@ -16,7 +17,7 @@ logger = logging.getLogger('connexion.app')
|
||||
class AbstractApp(metaclass=abc.ABCMeta):
|
||||
def __init__(self, import_name, api_cls, port=None, specification_dir='',
|
||||
host=None, server=None, server_args=None, arguments=None, auth_all_paths=False, debug=None,
|
||||
resolver=None, options=None, skip_error_handlers=False):
|
||||
resolver=None, options=None, skip_error_handlers=False, middlewares=None):
|
||||
"""
|
||||
:param import_name: the name of the application package
|
||||
:type import_name: str
|
||||
@@ -37,6 +38,8 @@ class AbstractApp(metaclass=abc.ABCMeta):
|
||||
:param debug: include debugging information
|
||||
:type debug: bool
|
||||
:param resolver: Callable that maps operationID to a function
|
||||
:param middlewares: Callable that maps operationID to a function
|
||||
:type middlewares: list | None
|
||||
"""
|
||||
self.port = port
|
||||
self.host = host
|
||||
@@ -54,8 +57,12 @@ class AbstractApp(metaclass=abc.ABCMeta):
|
||||
|
||||
self.server = server
|
||||
self.server_args = dict() if server_args is None else server_args
|
||||
|
||||
self.app = self.create_app()
|
||||
self.middleware = self._apply_middleware()
|
||||
|
||||
if middlewares is None:
|
||||
middlewares = ConnexionMiddleware.default_middlewares
|
||||
self.middleware = self._apply_middleware(middlewares)
|
||||
|
||||
# we get our application root path to avoid duplicating logic
|
||||
self.root_path = self.get_root_path()
|
||||
@@ -80,7 +87,7 @@ class AbstractApp(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def _apply_middleware(self):
|
||||
def _apply_middleware(self, middlewares):
|
||||
"""
|
||||
Apply middleware to application
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,7 @@ logger = logging.getLogger('connexion.app')
|
||||
|
||||
|
||||
class FlaskApp(AbstractApp):
|
||||
|
||||
def __init__(self, import_name, server='flask', extra_files=None, **kwargs):
|
||||
"""
|
||||
:param extra_files: additional files to be watched by the reloader, defaults to the swagger specs of added apis
|
||||
@@ -41,8 +42,8 @@ class FlaskApp(AbstractApp):
|
||||
app.url_map.converters['int'] = IntegerConverter
|
||||
return app
|
||||
|
||||
def _apply_middleware(self):
|
||||
middlewares = [*ConnexionMiddleware.default_middlewares,
|
||||
def _apply_middleware(self, middlewares):
|
||||
middlewares = [*middlewares,
|
||||
a2wsgi.WSGIMiddleware]
|
||||
middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares)
|
||||
|
||||
|
||||
@@ -96,6 +96,17 @@ class BadRequestProblem(ProblemException):
|
||||
super().__init__(status=400, title=title, detail=detail)
|
||||
|
||||
|
||||
class NotFoundProblem(ProblemException):
|
||||
|
||||
description = (
|
||||
'The requested URL was not found on the server. If you entered the URL manually please '
|
||||
'check your spelling and try again.'
|
||||
)
|
||||
|
||||
def __init__(self, title="Not Found", detail=description):
|
||||
super().__init__(status=404, title=title, detail=detail)
|
||||
|
||||
|
||||
class UnsupportedMediaTypeProblem(ProblemException):
|
||||
|
||||
def __init__(self, title="Unsupported Media Type", detail=None):
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
This module defines interfaces for requests and responses used in Connexion for authentication,
|
||||
validation, serialization, etc.
|
||||
"""
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import StreamingResponse as StarletteStreamingResponse
|
||||
|
||||
|
||||
class ConnexionRequest:
|
||||
@@ -52,3 +54,11 @@ class ConnexionResponse:
|
||||
self.body = body
|
||||
self.headers = headers or {}
|
||||
self.is_streamed = is_streamed
|
||||
|
||||
|
||||
class MiddlewareRequest(StarletteRequest):
|
||||
"""Wraps starlette Request so it can easily be extended."""
|
||||
|
||||
|
||||
class MiddlewareResponse(StarletteStreamingResponse):
|
||||
"""Wraps starlette StreamingResponse so it can easily be extended."""
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
from .abstract import AppMiddleware # NOQA
|
||||
from .main import ConnexionMiddleware # NOQA
|
||||
from .routing import RoutingMiddleware # NOQA
|
||||
from .swagger_ui import SwaggerUIMiddleware # NOQA
|
||||
|
||||
@@ -4,6 +4,8 @@ import typing as t
|
||||
|
||||
|
||||
class AppMiddleware(abc.ABC):
|
||||
"""Middlewares that need the APIs to be registered on them should inherit from this base
|
||||
class"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None:
|
||||
33
connexion/middleware/exceptions.py
Normal file
33
connexion/middleware/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import json
|
||||
|
||||
from starlette.exceptions import \
|
||||
ExceptionMiddleware as StarletteExceptionMiddleware
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from connexion.exceptions import problem
|
||||
|
||||
|
||||
class ExceptionMiddleware(StarletteExceptionMiddleware):
|
||||
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
|
||||
existing connexion behavior."""
|
||||
|
||||
def http_exception(self, request: Request, exc: HTTPException) -> Response:
|
||||
try:
|
||||
headers = exc.headers
|
||||
except AttributeError:
|
||||
# Starlette < 0.19
|
||||
headers = {}
|
||||
|
||||
connexion_response = problem(title=exc.detail,
|
||||
detail=exc.detail,
|
||||
status=exc.status_code,
|
||||
headers=headers)
|
||||
|
||||
return Response(
|
||||
content=json.dumps(connexion_response.body),
|
||||
status_code=connexion_response.status_code,
|
||||
media_type=connexion_response.mimetype,
|
||||
headers=connexion_response.headers
|
||||
)
|
||||
@@ -1,10 +1,11 @@
|
||||
import pathlib
|
||||
import typing as t
|
||||
|
||||
from starlette.exceptions import ExceptionMiddleware
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.middleware.base import AppMiddleware
|
||||
from connexion.middleware.abstract import AppMiddleware
|
||||
from connexion.middleware.exceptions import ExceptionMiddleware
|
||||
from connexion.middleware.routing import RoutingMiddleware
|
||||
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
||||
|
||||
|
||||
@@ -13,6 +14,7 @@ class ConnexionMiddleware:
|
||||
default_middlewares = [
|
||||
ExceptionMiddleware,
|
||||
SwaggerUIMiddleware,
|
||||
RoutingMiddleware,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
||||
170
connexion/middleware/routing.py
Normal file
170
connexion/middleware/routing.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import pathlib
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.routing import Router
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.apis import AbstractMinimalAPI
|
||||
from connexion.exceptions import NotFoundProblem
|
||||
from connexion.middleware import AppMiddleware
|
||||
from connexion.operations import AbstractOperation, make_operation
|
||||
from connexion.resolver import Resolver
|
||||
|
||||
CONNEXION_CONTEXT = 'connexion.context'
|
||||
|
||||
|
||||
_scope_receive_send: ContextVar[tuple] = ContextVar('SCOPE_RECEIVE_SEND')
|
||||
|
||||
|
||||
class MiddlewareResolver(Resolver):
|
||||
|
||||
def __init__(self, call_next: t.Callable) -> None:
|
||||
"""Resolver that resolves each operation to the provided call_next function."""
|
||||
super().__init__()
|
||||
self.call_next = call_next
|
||||
|
||||
def resolve_function_from_operation_id(self, operation_id: str) -> t.Callable:
|
||||
return self.call_next
|
||||
|
||||
|
||||
class RoutingMiddleware(AppMiddleware):
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
"""Middleware that resolves the Operation for an incoming request and attaches it to the
|
||||
scope.
|
||||
|
||||
:param app: app to wrap in middleware.
|
||||
"""
|
||||
self.app = app
|
||||
# Pass unknown routes to next app
|
||||
self.router = Router(default=self.default_fn)
|
||||
|
||||
def add_api(
|
||||
self,
|
||||
specification: t.Union[pathlib.Path, str, dict],
|
||||
base_path: t.Optional[str] = None,
|
||||
arguments: t.Optional[dict] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""Add an API to the router based on a OpenAPI spec.
|
||||
|
||||
:param specification: OpenAPI spec as dict or path to file.
|
||||
:param base_path: Base path where to add this API.
|
||||
:param arguments: Jinja arguments to replace in the spec.
|
||||
"""
|
||||
kwargs.pop("resolver", None)
|
||||
resolver = MiddlewareResolver(self.create_call_next())
|
||||
api = MiddlewareAPI(specification, base_path=base_path, arguments=arguments,
|
||||
resolver=resolver, default=self.default_fn, **kwargs)
|
||||
self.router.mount(api.base_path, app=api.router)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Route request to matching operation, and attach it to the scope before calling the
|
||||
next app."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
_scope_receive_send.set((scope.copy(), receive, send))
|
||||
|
||||
# Needs to be set so starlette router throws exceptions instead of returning error responses
|
||||
scope['app'] = self
|
||||
try:
|
||||
await self.router(scope, receive, send)
|
||||
except ValueError:
|
||||
raise NotFoundProblem
|
||||
|
||||
async def default_fn(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Callback to call next app as default when no matching route is found."""
|
||||
original_scope, *_ = _scope_receive_send.get()
|
||||
|
||||
api_base_path = scope.get('root_path', '')[len(original_scope.get('root_path', '')):]
|
||||
|
||||
extensions = original_scope.setdefault('extensions', {})
|
||||
connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {})
|
||||
connexion_context.update({
|
||||
'api_base_path': api_base_path
|
||||
})
|
||||
await self.app(original_scope, receive, send)
|
||||
|
||||
def create_call_next(self):
|
||||
|
||||
async def call_next(
|
||||
operation: AbstractOperation,
|
||||
request: StarletteRequest = None
|
||||
) -> None:
|
||||
"""Attach operation to scope and pass it to the next app"""
|
||||
scope, receive, send = _scope_receive_send.get()
|
||||
|
||||
api_base_path = request.scope.get('root_path', '')[len(scope.get('root_path', '')):]
|
||||
|
||||
extensions = scope.setdefault('extensions', {})
|
||||
connexion_context = extensions.setdefault(CONNEXION_CONTEXT, {})
|
||||
connexion_context.update({
|
||||
'api_base_path': api_base_path,
|
||||
'operation_id': operation.operation_id
|
||||
})
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
return call_next
|
||||
|
||||
|
||||
class MiddlewareAPI(AbstractMinimalAPI):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specification: t.Union[pathlib.Path, str, dict],
|
||||
base_path: t.Optional[str] = None,
|
||||
arguments: t.Optional[dict] = None,
|
||||
resolver: t.Optional[Resolver] = None,
|
||||
default: ASGIApp = None,
|
||||
resolver_error_handler: t.Optional[t.Callable] = None,
|
||||
debug: bool = False,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""API implementation on top of Starlette Router for Connexion middleware."""
|
||||
self.router = Router(default=default)
|
||||
|
||||
super().__init__(
|
||||
specification,
|
||||
base_path=base_path,
|
||||
arguments=arguments,
|
||||
resolver=resolver,
|
||||
resolver_error_handler=resolver_error_handler,
|
||||
debug=debug
|
||||
)
|
||||
|
||||
def add_operation(self, path: str, method: str) -> None:
|
||||
operation = make_operation(
|
||||
self.specification,
|
||||
self,
|
||||
path,
|
||||
method,
|
||||
self.resolver
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def patch_operation_function():
|
||||
"""Patch the operation function so no decorators are set in the middleware. This
|
||||
should be cleaned up by separating the APIs and Operations between the App and
|
||||
middleware"""
|
||||
original_operation_function = AbstractOperation.function
|
||||
AbstractOperation.function = operation._resolution.function
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
AbstractOperation.function = original_operation_function
|
||||
|
||||
with patch_operation_function():
|
||||
self._add_operation_internal(method, path, operation)
|
||||
|
||||
def _add_operation_internal(self, method: str, path: str, operation: AbstractOperation) -> None:
|
||||
self.router.add_route(path, operation.function, methods=[method])
|
||||
|
||||
@staticmethod
|
||||
def make_security_handler_factory(pass_context_arg_name):
|
||||
""" Create default SecurityHandlerFactory to create all security check handlers """
|
||||
pass
|
||||
@@ -13,10 +13,9 @@ from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.apis import AbstractSwaggerUIAPI
|
||||
from connexion.jsonifier import JSONEncoder, Jsonifier
|
||||
from connexion.middleware import AppMiddleware
|
||||
from connexion.utils import yamldumper
|
||||
|
||||
from .base import AppMiddleware
|
||||
|
||||
logger = logging.getLogger('connexion.middleware.swagger_ui')
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ def test_errors(problem_app):
|
||||
error405 = json.loads(get_greeting.data.decode('utf-8', 'replace'))
|
||||
assert error405['type'] == 'about:blank'
|
||||
assert error405['title'] == 'Method Not Allowed'
|
||||
assert error405['detail'] == 'The method is not allowed for the requested URL.'
|
||||
assert error405['status'] == 405
|
||||
assert 'instance' not in error405
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from connexion import App
|
||||
@@ -136,7 +135,7 @@ def json_datetime_dir():
|
||||
return FIXTURES_FOLDER / 'datetime_support'
|
||||
|
||||
|
||||
def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs):
|
||||
def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', middlewares=None, **kwargs):
|
||||
debug = True
|
||||
if 'debug' in kwargs:
|
||||
debug = kwargs['debug']
|
||||
@@ -145,6 +144,7 @@ def build_app_from_fixture(api_spec_folder, spec_file='openapi.yaml', **kwargs):
|
||||
cnx_app = App(__name__,
|
||||
port=5001,
|
||||
specification_dir=FIXTURES_FOLDER / api_spec_folder,
|
||||
middlewares=middlewares,
|
||||
debug=debug)
|
||||
|
||||
cnx_app.add_api(spec_file, **kwargs)
|
||||
@@ -254,4 +254,3 @@ def unordered_definition_app(request):
|
||||
def bad_operations_app(request):
|
||||
return build_app_from_fixture('bad_operations', request.param,
|
||||
resolver_error=501)
|
||||
|
||||
|
||||
@@ -13,15 +13,3 @@ def fake_json_auth(token, required_scopes=None):
|
||||
return json.loads(token)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
async def async_basic_auth(username, password, required_scopes=None, request=None):
|
||||
return fake_basic_auth(username, password, required_scopes)
|
||||
|
||||
|
||||
async def async_json_auth(token, required_scopes=None, request=None):
|
||||
return fake_json_auth(token, required_scopes)
|
||||
|
||||
|
||||
async def async_scope_validation(required_scopes, token_scopes, request):
|
||||
return required_scopes == token_scopes
|
||||
|
||||
44
tests/test_middleware.py
Normal file
44
tests/test_middleware.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import pytest
|
||||
from connexion.middleware import ConnexionMiddleware
|
||||
from connexion.middleware.routing import CONNEXION_CONTEXT
|
||||
from starlette.datastructures import MutableHeaders
|
||||
|
||||
from conftest import SPECS, build_app_from_fixture
|
||||
|
||||
|
||||
class TestMiddleware:
|
||||
"""Middleware to check if operation is accessible on scope."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
operation_id = scope['extensions'][CONNEXION_CONTEXT]['operation_id']
|
||||
|
||||
async def patched_send(message):
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(scope=message)
|
||||
headers["operation_id"] = operation_id
|
||||
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, patched_send)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", params=SPECS)
|
||||
def middleware_app(request):
|
||||
middlewares = ConnexionMiddleware.default_middlewares + [TestMiddleware]
|
||||
return build_app_from_fixture('simple', request.param, middlewares=middlewares)
|
||||
|
||||
|
||||
def test_routing_middleware(middleware_app):
|
||||
app_client = middleware_app.app.test_client()
|
||||
|
||||
response = app_client.post("/v1.0/greeting/robbe")
|
||||
|
||||
assert response.headers.get('operation_id') == 'fakeapi.hello.post_greeting', \
|
||||
response.status_code
|
||||
Reference in New Issue
Block a user