mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 12:27:45 +00:00
Extract Swagger UI functionality into middleware (#1496)
* Extract swagger UI functionality from AbstractAPI * Extract Swagger UI functionality into middleware Co-authored-by: Wojciech Paciorek <arkkors@users.noreply.github.com> * Add additional docstrings Co-authored-by: Wojciech Paciorek <arkkors@users.noreply.github.com>
This commit is contained in:
@@ -13,4 +13,4 @@ on the framework app.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from .abstract import AbstractAPI # NOQA
|
from .abstract import AbstractAPI, AbstractSwaggerUIAPI # NOQA
|
||||||
|
|||||||
@@ -32,7 +32,84 @@ class AbstractAPIMeta(abc.ABCMeta):
|
|||||||
cls._set_jsonifier()
|
cls._set_jsonifier()
|
||||||
|
|
||||||
|
|
||||||
class AbstractAPI(metaclass=AbstractAPIMeta):
|
class AbstractSpecAPI(metaclass=AbstractAPIMeta):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
specification: t.Union[pathlib.Path, str, dict],
|
||||||
|
base_path: t.Optional[str] = None,
|
||||||
|
arguments: t.Optional[dict] = None,
|
||||||
|
options: t.Optional[dict] = None,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""Base API class with only minimal behavior related to the specification."""
|
||||||
|
logger.debug('Loading specification: %s', specification,
|
||||||
|
extra={'swagger_yaml': specification,
|
||||||
|
'base_path': base_path,
|
||||||
|
'arguments': arguments})
|
||||||
|
|
||||||
|
# Avoid validator having ability to modify specification
|
||||||
|
self.specification = Specification.load(specification, arguments=arguments)
|
||||||
|
|
||||||
|
logger.debug('Read specification', extra={'spec': self.specification})
|
||||||
|
|
||||||
|
self.options = ConnexionOptions(options, oas_version=self.specification.version)
|
||||||
|
|
||||||
|
logger.debug('Options Loaded',
|
||||||
|
extra={'swagger_ui': self.options.openapi_console_ui_available,
|
||||||
|
'swagger_path': self.options.openapi_console_ui_from_dir,
|
||||||
|
'swagger_url': self.options.openapi_console_ui_path})
|
||||||
|
|
||||||
|
self._set_base_path(base_path)
|
||||||
|
|
||||||
|
def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
|
||||||
|
if base_path is not None:
|
||||||
|
# update spec to include user-provided base_path
|
||||||
|
self.specification.base_path = base_path
|
||||||
|
self.base_path = base_path
|
||||||
|
else:
|
||||||
|
self.base_path = self.specification.base_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_jsonifier(cls):
|
||||||
|
cls.jsonifier = Jsonifier()
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractSwaggerUIAPI(AbstractSpecAPI):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
if self.options.openapi_spec_available:
|
||||||
|
self.add_openapi_json()
|
||||||
|
self.add_openapi_yaml()
|
||||||
|
|
||||||
|
if self.options.openapi_console_ui_available:
|
||||||
|
self.add_swagger_ui()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def add_openapi_json(self):
|
||||||
|
"""
|
||||||
|
Adds openapi spec to {base_path}/openapi.json
|
||||||
|
(or {base_path}/swagger.json for swagger2)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def add_openapi_yaml(self):
|
||||||
|
"""
|
||||||
|
Adds openapi spec to {base_path}/openapi.yaml
|
||||||
|
(or {base_path}/swagger.yaml for swagger2)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def add_swagger_ui(self):
|
||||||
|
"""
|
||||||
|
Adds swagger ui to {base_path}/ui/
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractAPI(AbstractSpecAPI):
|
||||||
"""
|
"""
|
||||||
Defines an abstract interface for a Swagger API
|
Defines an abstract interface for a Swagger API
|
||||||
"""
|
"""
|
||||||
@@ -107,12 +184,7 @@ class AbstractAPI(metaclass=AbstractAPIMeta):
|
|||||||
|
|
||||||
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
|
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
|
||||||
|
|
||||||
if self.options.openapi_spec_available:
|
super().__init__(specification, base_path=base_path, arguments=arguments, options=options)
|
||||||
self.add_openapi_json()
|
|
||||||
self.add_openapi_yaml()
|
|
||||||
|
|
||||||
if self.options.openapi_console_ui_available:
|
|
||||||
self.add_swagger_ui()
|
|
||||||
|
|
||||||
self.add_paths()
|
self.add_paths()
|
||||||
|
|
||||||
@@ -122,27 +194,6 @@ class AbstractAPI(metaclass=AbstractAPIMeta):
|
|||||||
self.specification.security_definitions
|
self.specification.security_definitions
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_base_path(self, base_path=None):
|
|
||||||
if base_path is not None:
|
|
||||||
# update spec to include user-provided base_path
|
|
||||||
self.specification.base_path = base_path
|
|
||||||
self.base_path = base_path
|
|
||||||
else:
|
|
||||||
self.base_path = self.specification.base_path
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def add_openapi_json(self):
|
|
||||||
"""
|
|
||||||
Adds openapi spec to {base_path}/openapi.json
|
|
||||||
(or {base_path}/swagger.json for swagger2)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def add_swagger_ui(self):
|
|
||||||
"""
|
|
||||||
Adds swagger ui to {base_path}/ui/
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def add_auth_on_not_found(self, security, security_definitions):
|
def add_auth_on_not_found(self, security, security_definitions):
|
||||||
"""
|
"""
|
||||||
@@ -422,7 +473,3 @@ class AbstractAPI(metaclass=AbstractAPIMeta):
|
|||||||
|
|
||||||
def json_loads(self, data):
|
def json_loads(self, data):
|
||||||
return self.jsonifier.loads(data)
|
return self.jsonifier.loads(data)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _set_jsonifier(cls):
|
|
||||||
cls.jsonifier = Jsonifier()
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Connexion requests / responses.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import pathlib
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -18,7 +17,7 @@ from connexion.handlers import AuthErrorHandler
|
|||||||
from connexion.jsonifier import Jsonifier
|
from connexion.jsonifier import Jsonifier
|
||||||
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
|
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
|
||||||
from connexion.security import FlaskSecurityHandlerFactory
|
from connexion.security import FlaskSecurityHandlerFactory
|
||||||
from connexion.utils import is_json_mimetype, yamldumper
|
from connexion.utils import is_json_mimetype
|
||||||
|
|
||||||
logger = logging.getLogger('connexion.apis.flask_api')
|
logger = logging.getLogger('connexion.apis.flask_api')
|
||||||
|
|
||||||
@@ -40,72 +39,6 @@ class FlaskApi(AbstractAPI):
|
|||||||
self.blueprint = flask.Blueprint(endpoint, __name__, url_prefix=self.base_path,
|
self.blueprint = flask.Blueprint(endpoint, __name__, url_prefix=self.base_path,
|
||||||
template_folder=str(self.options.openapi_console_ui_from_dir))
|
template_folder=str(self.options.openapi_console_ui_from_dir))
|
||||||
|
|
||||||
def add_openapi_json(self):
|
|
||||||
"""
|
|
||||||
Adds spec json to {base_path}/swagger.json
|
|
||||||
or {base_path}/openapi.json (for oas3)
|
|
||||||
"""
|
|
||||||
logger.debug('Adding spec json: %s/%s', self.base_path,
|
|
||||||
self.options.openapi_spec_path)
|
|
||||||
endpoint_name = f"{self.blueprint.name}_openapi_json"
|
|
||||||
|
|
||||||
self.blueprint.add_url_rule(self.options.openapi_spec_path,
|
|
||||||
endpoint_name,
|
|
||||||
self._handlers.get_json_spec)
|
|
||||||
|
|
||||||
def add_openapi_yaml(self):
|
|
||||||
"""
|
|
||||||
Adds spec yaml to {base_path}/swagger.yaml
|
|
||||||
or {base_path}/openapi.yaml (for oas3)
|
|
||||||
"""
|
|
||||||
if not self.options.openapi_spec_path.endswith("json"):
|
|
||||||
return
|
|
||||||
|
|
||||||
openapi_spec_path_yaml = \
|
|
||||||
self.options.openapi_spec_path[:-len("json")] + "yaml"
|
|
||||||
logger.debug('Adding spec yaml: %s/%s', self.base_path,
|
|
||||||
openapi_spec_path_yaml)
|
|
||||||
endpoint_name = f"{self.blueprint.name}_openapi_yaml"
|
|
||||||
self.blueprint.add_url_rule(
|
|
||||||
openapi_spec_path_yaml,
|
|
||||||
endpoint_name,
|
|
||||||
self._handlers.get_yaml_spec
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_swagger_ui(self):
|
|
||||||
"""
|
|
||||||
Adds swagger ui to {base_path}/ui/
|
|
||||||
"""
|
|
||||||
console_ui_path = self.options.openapi_console_ui_path.strip('/')
|
|
||||||
logger.debug('Adding swagger-ui: %s/%s/',
|
|
||||||
self.base_path,
|
|
||||||
console_ui_path)
|
|
||||||
|
|
||||||
if self.options.openapi_console_ui_config is not None:
|
|
||||||
config_endpoint_name = f"{self.blueprint.name}_swagger_ui_config"
|
|
||||||
config_file_url = '/{console_ui_path}/swagger-ui-config.json'.format(
|
|
||||||
console_ui_path=console_ui_path)
|
|
||||||
|
|
||||||
self.blueprint.add_url_rule(config_file_url,
|
|
||||||
config_endpoint_name,
|
|
||||||
lambda: flask.jsonify(self.options.openapi_console_ui_config))
|
|
||||||
|
|
||||||
static_endpoint_name = f"{self.blueprint.name}_swagger_ui_static"
|
|
||||||
static_files_url = '/{console_ui_path}/<path:filename>'.format(
|
|
||||||
console_ui_path=console_ui_path)
|
|
||||||
|
|
||||||
self.blueprint.add_url_rule(static_files_url,
|
|
||||||
static_endpoint_name,
|
|
||||||
self._handlers.console_ui_static_files)
|
|
||||||
|
|
||||||
index_endpoint_name = f"{self.blueprint.name}_swagger_ui_index"
|
|
||||||
console_ui_url = '/{console_ui_path}/'.format(
|
|
||||||
console_ui_path=console_ui_path)
|
|
||||||
|
|
||||||
self.blueprint.add_url_rule(console_ui_url,
|
|
||||||
index_endpoint_name,
|
|
||||||
self._handlers.console_ui_home)
|
|
||||||
|
|
||||||
def add_auth_on_not_found(self, security, security_definitions):
|
def add_auth_on_not_found(self, security, security_definitions):
|
||||||
"""
|
"""
|
||||||
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
|
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
|
||||||
@@ -127,13 +60,6 @@ class FlaskApi(AbstractAPI):
|
|||||||
function = operation.function
|
function = operation.function
|
||||||
self.blueprint.add_url_rule(flask_path, endpoint_name, function, methods=[method])
|
self.blueprint.add_url_rule(flask_path, endpoint_name, function, methods=[method])
|
||||||
|
|
||||||
@property
|
|
||||||
def _handlers(self):
|
|
||||||
# type: () -> InternalHandlers
|
|
||||||
if not hasattr(self, '_internal_handlers'):
|
|
||||||
self._internal_handlers = InternalHandlers(self.base_path, self.options, self.specification)
|
|
||||||
return self._internal_handlers
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_response(cls, response, mimetype=None, request=None):
|
def get_response(cls, response, mimetype=None, request=None):
|
||||||
"""Gets ConnexionResponse instance for the operation handler
|
"""Gets ConnexionResponse instance for the operation handler
|
||||||
@@ -267,65 +193,3 @@ def _get_context():
|
|||||||
|
|
||||||
|
|
||||||
context = LocalProxy(_get_context)
|
context = LocalProxy(_get_context)
|
||||||
|
|
||||||
|
|
||||||
class InternalHandlers:
|
|
||||||
"""
|
|
||||||
Flask handlers for internally registered endpoints.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, base_path, options, specification):
|
|
||||||
self.base_path = base_path
|
|
||||||
self.options = options
|
|
||||||
self.specification = specification
|
|
||||||
|
|
||||||
def console_ui_home(self):
|
|
||||||
"""
|
|
||||||
Home page of the OpenAPI Console UI.
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
openapi_json_route_name = "{blueprint}.{prefix}_openapi_json"
|
|
||||||
escaped = flask_utils.flaskify_endpoint(self.base_path)
|
|
||||||
openapi_json_route_name = openapi_json_route_name.format(
|
|
||||||
blueprint=escaped,
|
|
||||||
prefix=escaped
|
|
||||||
)
|
|
||||||
template_variables = {
|
|
||||||
'openapi_spec_url': flask.url_for(openapi_json_route_name),
|
|
||||||
**self.options.openapi_console_ui_index_template_variables,
|
|
||||||
}
|
|
||||||
if self.options.openapi_console_ui_config is not None:
|
|
||||||
template_variables['configUrl'] = 'swagger-ui-config.json'
|
|
||||||
|
|
||||||
# Use `render_template_string` instead of `render_template` to circumvent the flask
|
|
||||||
# template lookup mechanism and explicitly render the template of the current blueprint.
|
|
||||||
# https://github.com/zalando/connexion/issues/1289#issuecomment-884105076
|
|
||||||
template_dir = pathlib.Path(self.options.openapi_console_ui_from_dir)
|
|
||||||
index_path = template_dir / 'index.j2'
|
|
||||||
return flask.render_template_string(index_path.read_text(), **template_variables)
|
|
||||||
|
|
||||||
def console_ui_static_files(self, filename):
|
|
||||||
"""
|
|
||||||
Servers the static files for the OpenAPI Console UI.
|
|
||||||
|
|
||||||
:param filename: Requested file contents.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# convert PosixPath to str
|
|
||||||
static_dir = str(self.options.openapi_console_ui_from_dir)
|
|
||||||
return flask.send_from_directory(static_dir, filename)
|
|
||||||
|
|
||||||
def get_json_spec(self):
|
|
||||||
return flask.jsonify(self._spec_for_prefix())
|
|
||||||
|
|
||||||
def get_yaml_spec(self):
|
|
||||||
return yamldumper(self._spec_for_prefix()), 200, {"Content-Type": "text/yaml"}
|
|
||||||
|
|
||||||
def _spec_for_prefix(self):
|
|
||||||
"""
|
|
||||||
Modify base_path in the spec based on incoming url
|
|
||||||
This fixes problems with reverse proxies changing the path.
|
|
||||||
"""
|
|
||||||
base_path = flask.url_for(flask.request.endpoint).rsplit("/", 1)[0]
|
|
||||||
return self.specification.with_base_path(base_path).raw
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class AbstractApp(metaclass=abc.ABCMeta):
|
|||||||
self.server = server
|
self.server = server
|
||||||
self.server_args = dict() if server_args is None else server_args
|
self.server_args = dict() if server_args is None else server_args
|
||||||
self.app = self.create_app()
|
self.app = self.create_app()
|
||||||
self._apply_middleware()
|
self.middleware = self._apply_middleware()
|
||||||
|
|
||||||
# we get our application root path to avoid duplicating logic
|
# we get our application root path to avoid duplicating logic
|
||||||
self.root_path = self.get_root_path()
|
self.root_path = self.get_root_path()
|
||||||
@@ -153,6 +153,22 @@ class AbstractApp(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
api_options = self.options.extend(options)
|
api_options = self.options.extend(options)
|
||||||
|
|
||||||
|
self.middleware.add_api(
|
||||||
|
specification,
|
||||||
|
base_path=base_path,
|
||||||
|
arguments=arguments,
|
||||||
|
resolver=resolver,
|
||||||
|
resolver_error_handler=resolver_error_handler,
|
||||||
|
validate_responses=validate_responses,
|
||||||
|
strict_validation=strict_validation,
|
||||||
|
auth_all_paths=auth_all_paths,
|
||||||
|
debug=self.debug,
|
||||||
|
validator_map=validator_map,
|
||||||
|
pythonic_params=pythonic_params,
|
||||||
|
pass_context_arg_name=pass_context_arg_name,
|
||||||
|
options=api_options.as_dict()
|
||||||
|
)
|
||||||
|
|
||||||
api = self.api_cls(specification,
|
api = self.api_cls(specification,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
arguments=arguments,
|
arguments=arguments,
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ class FlaskApp(AbstractApp):
|
|||||||
See :class:`~connexion.AbstractApp` for additional parameters.
|
See :class:`~connexion.AbstractApp` for additional parameters.
|
||||||
"""
|
"""
|
||||||
self.extra_files = extra_files or []
|
self.extra_files = extra_files or []
|
||||||
self.middleware = None
|
|
||||||
|
|
||||||
super().__init__(import_name, FlaskApi, server=server, **kwargs)
|
super().__init__(import_name, FlaskApi, server=server, **kwargs)
|
||||||
|
|
||||||
@@ -45,10 +44,12 @@ class FlaskApp(AbstractApp):
|
|||||||
def _apply_middleware(self):
|
def _apply_middleware(self):
|
||||||
middlewares = [*ConnexionMiddleware.default_middlewares,
|
middlewares = [*ConnexionMiddleware.default_middlewares,
|
||||||
a2wsgi.WSGIMiddleware]
|
a2wsgi.WSGIMiddleware]
|
||||||
self.middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares)
|
middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares)
|
||||||
|
|
||||||
# Wrap with ASGI to WSGI middleware for usage with development server and test client
|
# Wrap with ASGI to WSGI middleware for usage with development server and test client
|
||||||
self.app.wsgi_app = a2wsgi.ASGIMiddleware(self.middleware)
|
self.app.wsgi_app = a2wsgi.ASGIMiddleware(middleware)
|
||||||
|
|
||||||
|
return middleware
|
||||||
|
|
||||||
def get_root_path(self):
|
def get_root_path(self):
|
||||||
return pathlib.Path(self.app.root_path)
|
return pathlib.Path(self.app.root_path)
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
from .main import ConnexionMiddleware # NOQA
|
from .main import ConnexionMiddleware # NOQA
|
||||||
|
from .swagger_ui import SwaggerUIMiddleware # NOQA
|
||||||
|
|||||||
10
connexion/middleware/base.py
Normal file
10
connexion/middleware/base.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import abc
|
||||||
|
import pathlib
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
|
||||||
|
class AppMiddleware(abc.ABC):
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None:
|
||||||
|
pass
|
||||||
@@ -1,12 +1,18 @@
|
|||||||
import pathlib
|
import pathlib
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
|
from starlette.exceptions import ExceptionMiddleware
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
from connexion.middleware.base import AppMiddleware
|
||||||
|
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
||||||
|
|
||||||
|
|
||||||
class ConnexionMiddleware:
|
class ConnexionMiddleware:
|
||||||
|
|
||||||
default_middlewares = [
|
default_middlewares = [
|
||||||
|
ExceptionMiddleware,
|
||||||
|
SwaggerUIMiddleware,
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -25,8 +31,6 @@ class ConnexionMiddleware:
|
|||||||
middlewares = self.default_middlewares
|
middlewares = self.default_middlewares
|
||||||
self.app, self.apps = self._apply_middlewares(app, middlewares)
|
self.app, self.apps = self._apply_middlewares(app, middlewares)
|
||||||
|
|
||||||
self._routing_middleware = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _apply_middlewares(app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]]) \
|
def _apply_middlewares(app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]]) \
|
||||||
-> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
|
-> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
|
||||||
@@ -49,6 +53,7 @@ class ConnexionMiddleware:
|
|||||||
specification: t.Union[pathlib.Path, str, dict],
|
specification: t.Union[pathlib.Path, str, dict],
|
||||||
base_path: t.Optional[str] = None,
|
base_path: t.Optional[str] = None,
|
||||||
arguments: t.Optional[dict] = None,
|
arguments: t.Optional[dict] = None,
|
||||||
|
**kwargs
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add an API to the underlying routing middleware based on a OpenAPI spec.
|
"""Add an API to the underlying routing middleware based on a OpenAPI spec.
|
||||||
|
|
||||||
@@ -56,6 +61,9 @@ class ConnexionMiddleware:
|
|||||||
:param base_path: Base path where to add this API.
|
:param base_path: Base path where to add this API.
|
||||||
:param arguments: Jinja arguments to replace in the spec.
|
:param arguments: Jinja arguments to replace in the spec.
|
||||||
"""
|
"""
|
||||||
|
for app in self.apps:
|
||||||
|
if isinstance(app, AppMiddleware):
|
||||||
|
app.add_api(specification, base_path=base_path, arguments=arguments, **kwargs)
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
await self.app(scope, receive, send)
|
await self.app(scope, receive, send)
|
||||||
|
|||||||
211
connexion/middleware/swagger_ui.py
Normal file
211
connexion/middleware/swagger_ui.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
import logging
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
import typing as t
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from starlette.responses import RedirectResponse
|
||||||
|
from starlette.responses import Response as StarletteResponse
|
||||||
|
from starlette.routing import Router
|
||||||
|
from starlette.staticfiles import StaticFiles
|
||||||
|
from starlette.templating import Jinja2Templates
|
||||||
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
from connexion.apis import AbstractSwaggerUIAPI
|
||||||
|
from connexion.jsonifier import JSONEncoder, Jsonifier
|
||||||
|
from connexion.utils import yamldumper
|
||||||
|
|
||||||
|
from .base import AppMiddleware
|
||||||
|
|
||||||
|
logger = logging.getLogger('connexion.middleware.swagger_ui')
|
||||||
|
|
||||||
|
|
||||||
|
_original_scope: ContextVar[Scope] = ContextVar('SCOPE')
|
||||||
|
|
||||||
|
|
||||||
|
class SwaggerUIMiddleware(AppMiddleware):
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
"""Middleware that hosts a swagger UI.
|
||||||
|
|
||||||
|
:param app: app to wrap in middleware.
|
||||||
|
"""
|
||||||
|
self.app = app
|
||||||
|
# Set default to pass unknown routes to next app
|
||||||
|
self.router = Router(default=self.default_fn)
|
||||||
|
|
||||||
|
def add_api(
|
||||||
|
self,
|
||||||
|
specification: t.Union[pathlib.Path, str, dict],
|
||||||
|
base_path: t.Optional[str] = None,
|
||||||
|
arguments: t.Optional[dict] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
"""Add an API to the router based on a OpenAPI spec.
|
||||||
|
|
||||||
|
:param specification: OpenAPI spec as dict or path to file.
|
||||||
|
:param base_path: Base path where to add this API.
|
||||||
|
:param arguments: Jinja arguments to replace in the spec.
|
||||||
|
"""
|
||||||
|
api = SwaggerUIAPI(specification, base_path=base_path, arguments=arguments,
|
||||||
|
default=self.default_fn, **kwargs)
|
||||||
|
self.router.mount(api.base_path, app=api.router)
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
_original_scope.set(scope.copy())
|
||||||
|
await self.router(scope, receive, send)
|
||||||
|
|
||||||
|
async def default_fn(self, _scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
"""
|
||||||
|
Callback to call next app as default when no matching route is found.
|
||||||
|
|
||||||
|
Unfortunately we cannot just pass the next app as default, since the router manipulates
|
||||||
|
the scope when descending into mounts, losing information about the base path. Therefore,
|
||||||
|
we use the original scope instead.
|
||||||
|
|
||||||
|
This is caused by https://github.com/encode/starlette/issues/1336.
|
||||||
|
"""
|
||||||
|
original_scope = _original_scope.get()
|
||||||
|
await self.app(original_scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
|
class SwaggerUIAPI(AbstractSwaggerUIAPI):
|
||||||
|
|
||||||
|
def __init__(self, *args, default: ASGIApp, **kwargs):
|
||||||
|
self.router = Router(default=default)
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self._templates = Jinja2Templates(
|
||||||
|
directory=str(self.options.openapi_console_ui_from_dir)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_string(string):
|
||||||
|
return re.sub(r"[^a-zA-Z0-9]", "_", string.strip("/"))
|
||||||
|
|
||||||
|
def _base_path_for_prefix(self, request):
|
||||||
|
"""
|
||||||
|
returns a modified basePath which includes the incoming request's
|
||||||
|
path prefix.
|
||||||
|
"""
|
||||||
|
base_path = self.base_path
|
||||||
|
if not request.url.path.startswith(self.base_path):
|
||||||
|
prefix = request.url.path.split(self.base_path)[0]
|
||||||
|
base_path = prefix + base_path
|
||||||
|
return base_path
|
||||||
|
|
||||||
|
def _spec_for_prefix(self, request):
|
||||||
|
"""
|
||||||
|
returns a spec with a modified basePath / servers block
|
||||||
|
which corresponds to the incoming request path.
|
||||||
|
This is needed when behind a path-altering reverse proxy.
|
||||||
|
"""
|
||||||
|
base_path = self._base_path_for_prefix(request)
|
||||||
|
return self.specification.with_base_path(base_path).raw
|
||||||
|
|
||||||
|
def add_openapi_json(self):
|
||||||
|
"""
|
||||||
|
Adds openapi json to {base_path}/openapi.json
|
||||||
|
(or {base_path}/swagger.json for swagger2)
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Adding spec json: %s/%s", self.base_path, self.options.openapi_spec_path
|
||||||
|
)
|
||||||
|
self.router.add_route(
|
||||||
|
methods=["GET"],
|
||||||
|
path=self.options.openapi_spec_path,
|
||||||
|
endpoint=self._get_openapi_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_openapi_yaml(self):
|
||||||
|
"""
|
||||||
|
Adds openapi json to {base_path}/openapi.json
|
||||||
|
(or {base_path}/swagger.json for swagger2)
|
||||||
|
"""
|
||||||
|
if not self.options.openapi_spec_path.endswith("json"):
|
||||||
|
return
|
||||||
|
|
||||||
|
openapi_spec_path_yaml = self.options.openapi_spec_path[: -len("json")] + "yaml"
|
||||||
|
logger.debug("Adding spec yaml: %s/%s", self.base_path, openapi_spec_path_yaml)
|
||||||
|
self.router.add_route(
|
||||||
|
methods=["GET"],
|
||||||
|
path=openapi_spec_path_yaml,
|
||||||
|
endpoint=self._get_openapi_yaml,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_openapi_json(self, request):
|
||||||
|
return StarletteResponse(
|
||||||
|
content=self.jsonifier.dumps(self._spec_for_prefix(request)),
|
||||||
|
status_code=200,
|
||||||
|
media_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_openapi_yaml(self, request):
|
||||||
|
return StarletteResponse(
|
||||||
|
content=yamldumper(self._spec_for_prefix(request)),
|
||||||
|
status_code=200,
|
||||||
|
media_type="text/yaml",
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_swagger_ui(self):
|
||||||
|
"""
|
||||||
|
Adds swagger ui to {base_path}/ui/
|
||||||
|
"""
|
||||||
|
console_ui_path = self.options.openapi_console_ui_path.strip().rstrip("/")
|
||||||
|
logger.debug("Adding swagger-ui: %s%s/", self.base_path, console_ui_path)
|
||||||
|
|
||||||
|
for path in (
|
||||||
|
console_ui_path + "/",
|
||||||
|
console_ui_path + "/index.html",
|
||||||
|
):
|
||||||
|
self.router.add_route(
|
||||||
|
methods=["GET"], path=path, endpoint=self._get_swagger_ui_home
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.options.openapi_console_ui_config is not None:
|
||||||
|
self.router.add_route(
|
||||||
|
methods=["GET"],
|
||||||
|
path=console_ui_path + "/swagger-ui-config.json",
|
||||||
|
endpoint=self._get_swagger_ui_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# we have to add an explicit redirect instead of relying on the
|
||||||
|
# normalize_path_middleware because we also serve static files
|
||||||
|
# from this dir (below)
|
||||||
|
|
||||||
|
async def redirect(_request):
|
||||||
|
return RedirectResponse(url=self.base_path + console_ui_path + "/")
|
||||||
|
|
||||||
|
self.router.add_route(methods=["GET"], path=console_ui_path, endpoint=redirect)
|
||||||
|
|
||||||
|
# this route will match and get a permission error when trying to
|
||||||
|
# serve index.html, so we add the redirect above.
|
||||||
|
self.router.mount(
|
||||||
|
path=console_ui_path,
|
||||||
|
app=StaticFiles(directory=str(self.options.openapi_console_ui_from_dir)),
|
||||||
|
name="swagger_ui_static",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_swagger_ui_home(self, req):
|
||||||
|
base_path = self._base_path_for_prefix(req)
|
||||||
|
template_variables = {
|
||||||
|
"request": req,
|
||||||
|
"openapi_spec_url": (base_path + self.options.openapi_spec_path),
|
||||||
|
**self.options.openapi_console_ui_index_template_variables,
|
||||||
|
}
|
||||||
|
if self.options.openapi_console_ui_config is not None:
|
||||||
|
template_variables["configUrl"] = "swagger-ui-config.json"
|
||||||
|
|
||||||
|
return self._templates.TemplateResponse("index.j2", template_variables)
|
||||||
|
|
||||||
|
async def _get_swagger_ui_config(self, request):
|
||||||
|
return StarletteResponse(
|
||||||
|
status_code=200,
|
||||||
|
media_type="application/json",
|
||||||
|
content=self.jsonifier.dumps(self.options.openapi_console_ui_config),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_jsonifier(cls):
|
||||||
|
cls.jsonifier = Jsonifier(cls=JSONEncoder)
|
||||||
@@ -59,7 +59,7 @@ def test_openapi_yaml_behind_proxy(reverse_proxied_app):
|
|||||||
headers=headers
|
headers=headers
|
||||||
)
|
)
|
||||||
assert openapi_yaml.status_code == 200
|
assert openapi_yaml.status_code == 200
|
||||||
assert openapi_yaml.headers.get('Content-Type') == 'text/yaml'
|
assert openapi_yaml.headers.get('Content-Type').startswith('text/yaml')
|
||||||
spec = yaml.load(openapi_yaml.data.decode('utf-8'), Loader=yaml.BaseLoader)
|
spec = yaml.load(openapi_yaml.data.decode('utf-8'), Loader=yaml.BaseLoader)
|
||||||
|
|
||||||
if reverse_proxied_app._spec_file == 'swagger.yaml':
|
if reverse_proxied_app._spec_file == 'swagger.yaml':
|
||||||
|
|||||||
Reference in New Issue
Block a user