Extract Swagger UI functionality into middleware (#1496)

* Extract swagger UI functionality from AbstractAPI

* Extract Swagger UI functionality into middleware

Co-authored-by: Wojciech Paciorek <arkkors@users.noreply.github.com>

* Add additional docstrings

Co-authored-by: Wojciech Paciorek <arkkors@users.noreply.github.com>
This commit is contained in:
Robbe Sneyders
2022-04-10 17:15:27 +02:00
committed by GitHub
parent 895d3d475a
commit 41c19c1127
10 changed files with 335 additions and 177 deletions

View File

@@ -13,4 +13,4 @@ on the framework app.
""" """
from .abstract import AbstractAPI # NOQA from .abstract import AbstractAPI, AbstractSwaggerUIAPI # NOQA

View File

@@ -32,7 +32,84 @@ class AbstractAPIMeta(abc.ABCMeta):
cls._set_jsonifier() cls._set_jsonifier()
class AbstractAPI(metaclass=AbstractAPIMeta): class AbstractSpecAPI(metaclass=AbstractAPIMeta):
def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
options: t.Optional[dict] = None,
*args,
**kwargs
):
"""Base API class with only minimal behavior related to the specification."""
logger.debug('Loading specification: %s', specification,
extra={'swagger_yaml': specification,
'base_path': base_path,
'arguments': arguments})
# Avoid validator having ability to modify specification
self.specification = Specification.load(specification, arguments=arguments)
logger.debug('Read specification', extra={'spec': self.specification})
self.options = ConnexionOptions(options, oas_version=self.specification.version)
logger.debug('Options Loaded',
extra={'swagger_ui': self.options.openapi_console_ui_available,
'swagger_path': self.options.openapi_console_ui_from_dir,
'swagger_url': self.options.openapi_console_ui_path})
self._set_base_path(base_path)
def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
if base_path is not None:
# update spec to include user-provided base_path
self.specification.base_path = base_path
self.base_path = base_path
else:
self.base_path = self.specification.base_path
@classmethod
def _set_jsonifier(cls):
cls.jsonifier = Jsonifier()
class AbstractSwaggerUIAPI(AbstractSpecAPI):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.options.openapi_spec_available:
self.add_openapi_json()
self.add_openapi_yaml()
if self.options.openapi_console_ui_available:
self.add_swagger_ui()
@abc.abstractmethod
def add_openapi_json(self):
"""
Adds openapi spec to {base_path}/openapi.json
(or {base_path}/swagger.json for swagger2)
"""
@abc.abstractmethod
def add_openapi_yaml(self):
"""
Adds openapi spec to {base_path}/openapi.yaml
(or {base_path}/swagger.yaml for swagger2)
"""
@abc.abstractmethod
def add_swagger_ui(self):
"""
Adds swagger ui to {base_path}/ui/
"""
class AbstractAPI(AbstractSpecAPI):
""" """
Defines an abstract interface for a Swagger API Defines an abstract interface for a Swagger API
""" """
@@ -107,12 +184,7 @@ class AbstractAPI(metaclass=AbstractAPIMeta):
self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name) self.security_handler_factory = self.make_security_handler_factory(pass_context_arg_name)
if self.options.openapi_spec_available: super().__init__(specification, base_path=base_path, arguments=arguments, options=options)
self.add_openapi_json()
self.add_openapi_yaml()
if self.options.openapi_console_ui_available:
self.add_swagger_ui()
self.add_paths() self.add_paths()
@@ -122,27 +194,6 @@ class AbstractAPI(metaclass=AbstractAPIMeta):
self.specification.security_definitions self.specification.security_definitions
) )
def _set_base_path(self, base_path=None):
if base_path is not None:
# update spec to include user-provided base_path
self.specification.base_path = base_path
self.base_path = base_path
else:
self.base_path = self.specification.base_path
@abc.abstractmethod
def add_openapi_json(self):
"""
Adds openapi spec to {base_path}/openapi.json
(or {base_path}/swagger.json for swagger2)
"""
@abc.abstractmethod
def add_swagger_ui(self):
"""
Adds swagger ui to {base_path}/ui/
"""
@abc.abstractmethod @abc.abstractmethod
def add_auth_on_not_found(self, security, security_definitions): def add_auth_on_not_found(self, security, security_definitions):
""" """
@@ -422,7 +473,3 @@ class AbstractAPI(metaclass=AbstractAPIMeta):
def json_loads(self, data): def json_loads(self, data):
return self.jsonifier.loads(data) return self.jsonifier.loads(data)
@classmethod
def _set_jsonifier(cls):
cls.jsonifier = Jsonifier()

View File

@@ -4,7 +4,6 @@ Connexion requests / responses.
""" """
import logging import logging
import pathlib
import warnings import warnings
from typing import Any from typing import Any
@@ -18,7 +17,7 @@ from connexion.handlers import AuthErrorHandler
from connexion.jsonifier import Jsonifier from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionRequest, ConnexionResponse from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.security import FlaskSecurityHandlerFactory from connexion.security import FlaskSecurityHandlerFactory
from connexion.utils import is_json_mimetype, yamldumper from connexion.utils import is_json_mimetype
logger = logging.getLogger('connexion.apis.flask_api') logger = logging.getLogger('connexion.apis.flask_api')
@@ -40,72 +39,6 @@ class FlaskApi(AbstractAPI):
self.blueprint = flask.Blueprint(endpoint, __name__, url_prefix=self.base_path, self.blueprint = flask.Blueprint(endpoint, __name__, url_prefix=self.base_path,
template_folder=str(self.options.openapi_console_ui_from_dir)) template_folder=str(self.options.openapi_console_ui_from_dir))
def add_openapi_json(self):
"""
Adds spec json to {base_path}/swagger.json
or {base_path}/openapi.json (for oas3)
"""
logger.debug('Adding spec json: %s/%s', self.base_path,
self.options.openapi_spec_path)
endpoint_name = f"{self.blueprint.name}_openapi_json"
self.blueprint.add_url_rule(self.options.openapi_spec_path,
endpoint_name,
self._handlers.get_json_spec)
def add_openapi_yaml(self):
"""
Adds spec yaml to {base_path}/swagger.yaml
or {base_path}/openapi.yaml (for oas3)
"""
if not self.options.openapi_spec_path.endswith("json"):
return
openapi_spec_path_yaml = \
self.options.openapi_spec_path[:-len("json")] + "yaml"
logger.debug('Adding spec yaml: %s/%s', self.base_path,
openapi_spec_path_yaml)
endpoint_name = f"{self.blueprint.name}_openapi_yaml"
self.blueprint.add_url_rule(
openapi_spec_path_yaml,
endpoint_name,
self._handlers.get_yaml_spec
)
def add_swagger_ui(self):
"""
Adds swagger ui to {base_path}/ui/
"""
console_ui_path = self.options.openapi_console_ui_path.strip('/')
logger.debug('Adding swagger-ui: %s/%s/',
self.base_path,
console_ui_path)
if self.options.openapi_console_ui_config is not None:
config_endpoint_name = f"{self.blueprint.name}_swagger_ui_config"
config_file_url = '/{console_ui_path}/swagger-ui-config.json'.format(
console_ui_path=console_ui_path)
self.blueprint.add_url_rule(config_file_url,
config_endpoint_name,
lambda: flask.jsonify(self.options.openapi_console_ui_config))
static_endpoint_name = f"{self.blueprint.name}_swagger_ui_static"
static_files_url = '/{console_ui_path}/<path:filename>'.format(
console_ui_path=console_ui_path)
self.blueprint.add_url_rule(static_files_url,
static_endpoint_name,
self._handlers.console_ui_static_files)
index_endpoint_name = f"{self.blueprint.name}_swagger_ui_index"
console_ui_url = '/{console_ui_path}/'.format(
console_ui_path=console_ui_path)
self.blueprint.add_url_rule(console_ui_url,
index_endpoint_name,
self._handlers.console_ui_home)
def add_auth_on_not_found(self, security, security_definitions): def add_auth_on_not_found(self, security, security_definitions):
""" """
Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass. Adds a 404 error handler to authenticate and only expose the 404 status if the security validation pass.
@@ -127,13 +60,6 @@ class FlaskApi(AbstractAPI):
function = operation.function function = operation.function
self.blueprint.add_url_rule(flask_path, endpoint_name, function, methods=[method]) self.blueprint.add_url_rule(flask_path, endpoint_name, function, methods=[method])
@property
def _handlers(self):
# type: () -> InternalHandlers
if not hasattr(self, '_internal_handlers'):
self._internal_handlers = InternalHandlers(self.base_path, self.options, self.specification)
return self._internal_handlers
@classmethod @classmethod
def get_response(cls, response, mimetype=None, request=None): def get_response(cls, response, mimetype=None, request=None):
"""Gets ConnexionResponse instance for the operation handler """Gets ConnexionResponse instance for the operation handler
@@ -267,65 +193,3 @@ def _get_context():
context = LocalProxy(_get_context) context = LocalProxy(_get_context)
class InternalHandlers:
"""
Flask handlers for internally registered endpoints.
"""
def __init__(self, base_path, options, specification):
self.base_path = base_path
self.options = options
self.specification = specification
def console_ui_home(self):
"""
Home page of the OpenAPI Console UI.
:return:
"""
openapi_json_route_name = "{blueprint}.{prefix}_openapi_json"
escaped = flask_utils.flaskify_endpoint(self.base_path)
openapi_json_route_name = openapi_json_route_name.format(
blueprint=escaped,
prefix=escaped
)
template_variables = {
'openapi_spec_url': flask.url_for(openapi_json_route_name),
**self.options.openapi_console_ui_index_template_variables,
}
if self.options.openapi_console_ui_config is not None:
template_variables['configUrl'] = 'swagger-ui-config.json'
# Use `render_template_string` instead of `render_template` to circumvent the flask
# template lookup mechanism and explicitly render the template of the current blueprint.
# https://github.com/zalando/connexion/issues/1289#issuecomment-884105076
template_dir = pathlib.Path(self.options.openapi_console_ui_from_dir)
index_path = template_dir / 'index.j2'
return flask.render_template_string(index_path.read_text(), **template_variables)
def console_ui_static_files(self, filename):
"""
Servers the static files for the OpenAPI Console UI.
:param filename: Requested file contents.
:return:
"""
# convert PosixPath to str
static_dir = str(self.options.openapi_console_ui_from_dir)
return flask.send_from_directory(static_dir, filename)
def get_json_spec(self):
return flask.jsonify(self._spec_for_prefix())
def get_yaml_spec(self):
return yamldumper(self._spec_for_prefix()), 200, {"Content-Type": "text/yaml"}
def _spec_for_prefix(self):
"""
Modify base_path in the spec based on incoming url
This fixes problems with reverse proxies changing the path.
"""
base_path = flask.url_for(flask.request.endpoint).rsplit("/", 1)[0]
return self.specification.with_base_path(base_path).raw

View File

@@ -55,7 +55,7 @@ class AbstractApp(metaclass=abc.ABCMeta):
self.server = server self.server = server
self.server_args = dict() if server_args is None else server_args self.server_args = dict() if server_args is None else server_args
self.app = self.create_app() self.app = self.create_app()
self._apply_middleware() self.middleware = self._apply_middleware()
# we get our application root path to avoid duplicating logic # we get our application root path to avoid duplicating logic
self.root_path = self.get_root_path() self.root_path = self.get_root_path()
@@ -153,6 +153,22 @@ class AbstractApp(metaclass=abc.ABCMeta):
api_options = self.options.extend(options) api_options = self.options.extend(options)
self.middleware.add_api(
specification,
base_path=base_path,
arguments=arguments,
resolver=resolver,
resolver_error_handler=resolver_error_handler,
validate_responses=validate_responses,
strict_validation=strict_validation,
auth_all_paths=auth_all_paths,
debug=self.debug,
validator_map=validator_map,
pythonic_params=pythonic_params,
pass_context_arg_name=pass_context_arg_name,
options=api_options.as_dict()
)
api = self.api_cls(specification, api = self.api_cls(specification,
base_path=base_path, base_path=base_path,
arguments=arguments, arguments=arguments,

View File

@@ -31,7 +31,6 @@ class FlaskApp(AbstractApp):
See :class:`~connexion.AbstractApp` for additional parameters. See :class:`~connexion.AbstractApp` for additional parameters.
""" """
self.extra_files = extra_files or [] self.extra_files = extra_files or []
self.middleware = None
super().__init__(import_name, FlaskApi, server=server, **kwargs) super().__init__(import_name, FlaskApi, server=server, **kwargs)
@@ -45,10 +44,12 @@ class FlaskApp(AbstractApp):
def _apply_middleware(self): def _apply_middleware(self):
middlewares = [*ConnexionMiddleware.default_middlewares, middlewares = [*ConnexionMiddleware.default_middlewares,
a2wsgi.WSGIMiddleware] a2wsgi.WSGIMiddleware]
self.middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares) middleware = ConnexionMiddleware(self.app.wsgi_app, middlewares=middlewares)
# Wrap with ASGI to WSGI middleware for usage with development server and test client # Wrap with ASGI to WSGI middleware for usage with development server and test client
self.app.wsgi_app = a2wsgi.ASGIMiddleware(self.middleware) self.app.wsgi_app = a2wsgi.ASGIMiddleware(middleware)
return middleware
def get_root_path(self): def get_root_path(self):
return pathlib.Path(self.app.root_path) return pathlib.Path(self.app.root_path)

View File

@@ -1 +1,2 @@
from .main import ConnexionMiddleware # NOQA from .main import ConnexionMiddleware # NOQA
from .swagger_ui import SwaggerUIMiddleware # NOQA

View File

@@ -0,0 +1,10 @@
import abc
import pathlib
import typing as t
class AppMiddleware(abc.ABC):
@abc.abstractmethod
def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> None:
pass

View File

@@ -1,12 +1,18 @@
import pathlib import pathlib
import typing as t import typing as t
from starlette.exceptions import ExceptionMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.middleware.base import AppMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
class ConnexionMiddleware: class ConnexionMiddleware:
default_middlewares = [ default_middlewares = [
ExceptionMiddleware,
SwaggerUIMiddleware,
] ]
def __init__( def __init__(
@@ -25,8 +31,6 @@ class ConnexionMiddleware:
middlewares = self.default_middlewares middlewares = self.default_middlewares
self.app, self.apps = self._apply_middlewares(app, middlewares) self.app, self.apps = self._apply_middlewares(app, middlewares)
self._routing_middleware = None
@staticmethod @staticmethod
def _apply_middlewares(app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]]) \ def _apply_middlewares(app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]]) \
-> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]: -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
@@ -49,6 +53,7 @@ class ConnexionMiddleware:
specification: t.Union[pathlib.Path, str, dict], specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None, base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None, arguments: t.Optional[dict] = None,
**kwargs
) -> None: ) -> None:
"""Add an API to the underlying routing middleware based on a OpenAPI spec. """Add an API to the underlying routing middleware based on a OpenAPI spec.
@@ -56,6 +61,9 @@ class ConnexionMiddleware:
:param base_path: Base path where to add this API. :param base_path: Base path where to add this API.
:param arguments: Jinja arguments to replace in the spec. :param arguments: Jinja arguments to replace in the spec.
""" """
for app in self.apps:
if isinstance(app, AppMiddleware):
app.add_api(specification, base_path=base_path, arguments=arguments, **kwargs)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send) await self.app(scope, receive, send)

View File

@@ -0,0 +1,211 @@
import logging
import pathlib
import re
import typing as t
from contextvars import ContextVar
from starlette.responses import RedirectResponse
from starlette.responses import Response as StarletteResponse
from starlette.routing import Router
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis import AbstractSwaggerUIAPI
from connexion.jsonifier import JSONEncoder, Jsonifier
from connexion.utils import yamldumper
from .base import AppMiddleware
logger = logging.getLogger('connexion.middleware.swagger_ui')
_original_scope: ContextVar[Scope] = ContextVar('SCOPE')
class SwaggerUIMiddleware(AppMiddleware):
def __init__(self, app: ASGIApp) -> None:
"""Middleware that hosts a swagger UI.
:param app: app to wrap in middleware.
"""
self.app = app
# Set default to pass unknown routes to next app
self.router = Router(default=self.default_fn)
def add_api(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
**kwargs
) -> None:
"""Add an API to the router based on a OpenAPI spec.
:param specification: OpenAPI spec as dict or path to file.
:param base_path: Base path where to add this API.
:param arguments: Jinja arguments to replace in the spec.
"""
api = SwaggerUIAPI(specification, base_path=base_path, arguments=arguments,
default=self.default_fn, **kwargs)
self.router.mount(api.base_path, app=api.router)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
_original_scope.set(scope.copy())
await self.router(scope, receive, send)
async def default_fn(self, _scope: Scope, receive: Receive, send: Send) -> None:
"""
Callback to call next app as default when no matching route is found.
Unfortunately we cannot just pass the next app as default, since the router manipulates
the scope when descending into mounts, losing information about the base path. Therefore,
we use the original scope instead.
This is caused by https://github.com/encode/starlette/issues/1336.
"""
original_scope = _original_scope.get()
await self.app(original_scope, receive, send)
class SwaggerUIAPI(AbstractSwaggerUIAPI):
def __init__(self, *args, default: ASGIApp, **kwargs):
self.router = Router(default=default)
super().__init__(*args, **kwargs)
self._templates = Jinja2Templates(
directory=str(self.options.openapi_console_ui_from_dir)
)
@staticmethod
def normalize_string(string):
return re.sub(r"[^a-zA-Z0-9]", "_", string.strip("/"))
def _base_path_for_prefix(self, request):
"""
returns a modified basePath which includes the incoming request's
path prefix.
"""
base_path = self.base_path
if not request.url.path.startswith(self.base_path):
prefix = request.url.path.split(self.base_path)[0]
base_path = prefix + base_path
return base_path
def _spec_for_prefix(self, request):
"""
returns a spec with a modified basePath / servers block
which corresponds to the incoming request path.
This is needed when behind a path-altering reverse proxy.
"""
base_path = self._base_path_for_prefix(request)
return self.specification.with_base_path(base_path).raw
def add_openapi_json(self):
"""
Adds openapi json to {base_path}/openapi.json
(or {base_path}/swagger.json for swagger2)
"""
logger.info(
"Adding spec json: %s/%s", self.base_path, self.options.openapi_spec_path
)
self.router.add_route(
methods=["GET"],
path=self.options.openapi_spec_path,
endpoint=self._get_openapi_json,
)
def add_openapi_yaml(self):
"""
Adds openapi json to {base_path}/openapi.json
(or {base_path}/swagger.json for swagger2)
"""
if not self.options.openapi_spec_path.endswith("json"):
return
openapi_spec_path_yaml = self.options.openapi_spec_path[: -len("json")] + "yaml"
logger.debug("Adding spec yaml: %s/%s", self.base_path, openapi_spec_path_yaml)
self.router.add_route(
methods=["GET"],
path=openapi_spec_path_yaml,
endpoint=self._get_openapi_yaml,
)
async def _get_openapi_json(self, request):
return StarletteResponse(
content=self.jsonifier.dumps(self._spec_for_prefix(request)),
status_code=200,
media_type="application/json",
)
async def _get_openapi_yaml(self, request):
return StarletteResponse(
content=yamldumper(self._spec_for_prefix(request)),
status_code=200,
media_type="text/yaml",
)
def add_swagger_ui(self):
"""
Adds swagger ui to {base_path}/ui/
"""
console_ui_path = self.options.openapi_console_ui_path.strip().rstrip("/")
logger.debug("Adding swagger-ui: %s%s/", self.base_path, console_ui_path)
for path in (
console_ui_path + "/",
console_ui_path + "/index.html",
):
self.router.add_route(
methods=["GET"], path=path, endpoint=self._get_swagger_ui_home
)
if self.options.openapi_console_ui_config is not None:
self.router.add_route(
methods=["GET"],
path=console_ui_path + "/swagger-ui-config.json",
endpoint=self._get_swagger_ui_config,
)
# we have to add an explicit redirect instead of relying on the
# normalize_path_middleware because we also serve static files
# from this dir (below)
async def redirect(_request):
return RedirectResponse(url=self.base_path + console_ui_path + "/")
self.router.add_route(methods=["GET"], path=console_ui_path, endpoint=redirect)
# this route will match and get a permission error when trying to
# serve index.html, so we add the redirect above.
self.router.mount(
path=console_ui_path,
app=StaticFiles(directory=str(self.options.openapi_console_ui_from_dir)),
name="swagger_ui_static",
)
async def _get_swagger_ui_home(self, req):
base_path = self._base_path_for_prefix(req)
template_variables = {
"request": req,
"openapi_spec_url": (base_path + self.options.openapi_spec_path),
**self.options.openapi_console_ui_index_template_variables,
}
if self.options.openapi_console_ui_config is not None:
template_variables["configUrl"] = "swagger-ui-config.json"
return self._templates.TemplateResponse("index.j2", template_variables)
async def _get_swagger_ui_config(self, request):
return StarletteResponse(
status_code=200,
media_type="application/json",
content=self.jsonifier.dumps(self.options.openapi_console_ui_config),
)
@classmethod
def _set_jsonifier(cls):
cls.jsonifier = Jsonifier(cls=JSONEncoder)

View File

@@ -59,7 +59,7 @@ def test_openapi_yaml_behind_proxy(reverse_proxied_app):
headers=headers headers=headers
) )
assert openapi_yaml.status_code == 200 assert openapi_yaml.status_code == 200
assert openapi_yaml.headers.get('Content-Type') == 'text/yaml' assert openapi_yaml.headers.get('Content-Type').startswith('text/yaml')
spec = yaml.load(openapi_yaml.data.decode('utf-8'), Loader=yaml.BaseLoader) spec = yaml.load(openapi_yaml.data.decode('utf-8'), Loader=yaml.BaseLoader)
if reverse_proxied_app._spec_file == 'swagger.yaml': if reverse_proxied_app._spec_file == 'swagger.yaml':