Add lifespan middleware (#1676)

This PR adds a new middleware to handle lifespan events.

I added this as a separate middleware so it is encapsulated and aligned
for both the `FlaskApp` and `AsyncApp`. It leverages a Starlette
`Router` to register and call the lifespan handler.
This commit is contained in:
Robbe Sneyders
2023-03-23 19:11:43 +01:00
committed by GitHub
parent 8a85a4fe01
commit 79c0852c93
11 changed files with 108 additions and 42 deletions

View File

@@ -11,6 +11,7 @@ from starlette.types import Receive, Scope, Send
from connexion.jsonifier import Jsonifier from connexion.jsonifier import Jsonifier
from connexion.middleware import ConnexionMiddleware, SpecMiddleware from connexion.middleware import ConnexionMiddleware, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.resolver import Resolver from connexion.resolver import Resolver
from connexion.uri_parsing import AbstractURIParser from connexion.uri_parsing import AbstractURIParser
@@ -32,8 +33,9 @@ class AbstractApp:
self, self,
import_name: str, import_name: str,
*, *,
specification_dir: t.Union[pathlib.Path, str] = "", lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None, middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None, arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None, auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None, jsonifier: t.Optional[Jsonifier] = None,
@@ -50,11 +52,11 @@ class AbstractApp:
:param import_name: The name of the package or module that this object belongs to. If you :param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are are using a single module, __name__ is always the correct value. If you however are
using a package, its usually recommended to hardcode the name of your package there. using a package, its usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path :param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to should either be absolute or relative to the root path of the application. Defaults to
the root path. the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja. :param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification. :param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False. Defaults to False.
@@ -79,8 +81,9 @@ class AbstractApp:
self.middleware = ConnexionMiddleware( self.middleware = ConnexionMiddleware(
self.middleware_app, self.middleware_app,
import_name=import_name, import_name=import_name,
specification_dir=specification_dir, lifespan=lifespan,
middlewares=middlewares, middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments, arguments=arguments,
auth_all_paths=auth_all_paths, auth_all_paths=auth_all_paths,
jsonifier=jsonifier, jsonifier=jsonifier,

View File

@@ -15,6 +15,7 @@ from connexion.apps.abstract import AbstractApp
from connexion.decorators import StarletteDecorator from connexion.decorators import StarletteDecorator
from connexion.jsonifier import Jsonifier from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation from connexion.operations import AbstractOperation
from connexion.resolver import Resolver from connexion.resolver import Resolver
from connexion.uri_parsing import AbstractURIParser from connexion.uri_parsing import AbstractURIParser
@@ -120,8 +121,9 @@ class AsyncApp(AbstractApp):
self, self,
import_name: str, import_name: str,
*, *,
specification_dir: t.Union[pathlib.Path, str] = "", lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None, middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None, arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None, auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None, jsonifier: t.Optional[Jsonifier] = None,
@@ -138,11 +140,11 @@ class AsyncApp(AbstractApp):
:param import_name: The name of the package or module that this object belongs to. If you :param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are are using a single module, __name__ is always the correct value. If you however are
using a package, its usually recommended to hardcode the name of your package there. using a package, its usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path :param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to should either be absolute or relative to the root path of the application. Defaults to
the root path. the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja. :param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification. :param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False. Defaults to False.
@@ -168,8 +170,9 @@ class AsyncApp(AbstractApp):
super().__init__( super().__init__(
import_name, import_name,
specification_dir=specification_dir, lifespan=lifespan,
middlewares=middlewares, middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments, arguments=arguments,
auth_all_paths=auth_all_paths, auth_all_paths=auth_all_paths,
jsonifier=jsonifier, jsonifier=jsonifier,

View File

@@ -18,6 +18,7 @@ from connexion.exceptions import InternalServerError, ProblemException, Resolver
from connexion.frameworks import flask as flask_utils from connexion.frameworks import flask as flask_utils
from connexion.jsonifier import Jsonifier from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation from connexion.operations import AbstractOperation
from connexion.problem import problem from connexion.problem import problem
from connexion.resolver import Resolver from connexion.resolver import Resolver
@@ -176,9 +177,10 @@ class FlaskApp(AbstractApp):
self, self,
import_name: str, import_name: str,
*, *,
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
server_args: t.Optional[dict] = None, server_args: t.Optional[dict] = None,
specification_dir: t.Union[pathlib.Path, str] = "", specification_dir: t.Union[pathlib.Path, str] = "",
middlewares: t.Optional[list] = None,
arguments: t.Optional[dict] = None, arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None, auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None, jsonifier: t.Optional[Jsonifier] = None,
@@ -195,12 +197,14 @@ class FlaskApp(AbstractApp):
:param import_name: The name of the package or module that this object belongs to. If you :param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are are using a single module, __name__ is always the correct value. If you however are
using a package, its usually recommended to hardcode the name of your package there. using a package, its usually recommended to hardcode the name of your package there.
:param lifespan: A lifespan context function, which can be used to perform startup and
shutdown tasks.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param server_args: Arguments to pass to the Flask application. :param server_args: Arguments to pass to the Flask application.
:param specification_dir: The directory holding the specification(s). The provided path :param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to should either be absolute or relative to the root path of the application. Defaults to
the root path. the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja. :param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification. :param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False. Defaults to False.
@@ -226,8 +230,9 @@ class FlaskApp(AbstractApp):
self.app = self.middleware_app.app self.app = self.middleware_app.app
super().__init__( super().__init__(
import_name, import_name,
specification_dir=specification_dir, lifespan=lifespan,
middlewares=middlewares, middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments, arguments=arguments,
auth_all_paths=auth_all_paths, auth_all_paths=auth_all_paths,
jsonifier=jsonifier, jsonifier=jsonifier,

View File

@@ -230,7 +230,7 @@ class RoutedMiddleware(SpecMiddleware, t.Generic[API]):
api_cls: t.Type[API] api_cls: t.Type[API]
"""The subclass of RoutedAPI this middleware uses.""" """The subclass of RoutedAPI this middleware uses."""
def __init__(self, app: ASGIApp) -> None: def __init__(self, app: ASGIApp, **kwargs) -> None:
self.app = app self.app = app
self.apis: t.Dict[str, API] = {} self.apis: t.Dict[str, API] = {}

View File

@@ -4,7 +4,7 @@ from starlette.exceptions import ExceptionMiddleware as StarletteExceptionMiddle
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request as StarletteRequest from starlette.requests import Request as StarletteRequest
from starlette.responses import Response from starlette.responses import Response
from starlette.types import Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.exceptions import InternalServerError, ProblemException, problem from connexion.exceptions import InternalServerError, ProblemException, problem
@@ -15,8 +15,8 @@ class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to """Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior.""" existing connexion behavior."""
def __init__(self, *args, **kwargs): def __init__(self, next_app: ASGIApp, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(next_app)
self.add_exception_handler(ProblemException, self.problem_handler) self.add_exception_handler(ProblemException, self.problem_handler)
self.add_exception_handler(Exception, self.common_error_handler) self.add_exception_handler(Exception, self.common_error_handler)

View File

@@ -0,0 +1,28 @@
import typing as t
from starlette.routing import Router
from starlette.types import ASGIApp, Receive, Scope, Send
Lifespan = t.Callable[[t.Any], t.AsyncContextManager]
class LifespanMiddleware:
"""
Middleware that adds support for Starlette lifespan handlers
(https://www.starlette.io/lifespan/).
"""
def __init__(
self, next_app: ASGIApp, *, lifespan: t.Optional[Lifespan], **kwargs
) -> None:
self.next_app = next_app
self._lifespan = lifespan
# Leverage a Starlette Router for lifespan handling only
self.router = Router(lifespan=lifespan)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# If no lifespan is registered, pass to next app so it can be handled downstream.
if scope["type"] == "lifespan" and self._lifespan:
await self.router(scope, receive, send)
else:
await self.next_app(scope, receive, send)

View File

@@ -12,6 +12,7 @@ from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import SpecMiddleware from connexion.middleware.abstract import SpecMiddleware
from connexion.middleware.context import ContextMiddleware from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.lifespan import Lifespan, LifespanMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware from connexion.middleware.routing import RoutingMiddleware
@@ -92,6 +93,7 @@ class ConnexionMiddleware:
RequestValidationMiddleware, RequestValidationMiddleware,
ResponseValidationMiddleware, ResponseValidationMiddleware,
ContextMiddleware, ContextMiddleware,
LifespanMiddleware,
] ]
def __init__( def __init__(
@@ -99,8 +101,9 @@ class ConnexionMiddleware:
app: ASGIApp, app: ASGIApp,
*, *,
import_name: t.Optional[str] = None, import_name: t.Optional[str] = None,
specification_dir: t.Union[pathlib.Path, str] = "", lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None, middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None, arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None, auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None, jsonifier: t.Optional[Jsonifier] = None,
@@ -117,11 +120,11 @@ class ConnexionMiddleware:
:param import_name: The name of the package or module that this object belongs to. If you :param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are are using a single module, __name__ is always the correct value. If you however are
using a package, its usually recommended to hardcode the name of your package there. using a package, its usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path :param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to should either be absolute or relative to the root path of the application. Defaults to
the root path. the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja. :param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification. :param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False. Defaults to False.
@@ -150,7 +153,9 @@ class ConnexionMiddleware:
if middlewares is None: if middlewares is None:
middlewares = self.default_middlewares middlewares = self.default_middlewares
self.app, self.apps = self._apply_middlewares(app, middlewares) self.app, self.apps = self._apply_middlewares(
app, middlewares, lifespan=lifespan
)
self.options = _Options( self.options = _Options(
arguments=arguments, arguments=arguments,
@@ -177,9 +182,8 @@ class ConnexionMiddleware:
else: else:
return self.root_path / path return self.root_path / path
@staticmethod
def _apply_middlewares( def _apply_middlewares(
app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]] self, app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]], **kwargs
) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]: ) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
"""Apply all middlewares to the provided app. """Apply all middlewares to the provided app.
@@ -193,13 +197,14 @@ class ConnexionMiddleware:
# Include the wrapped application in the returned list. # Include the wrapped application in the returned list.
apps = [app] apps = [app]
for middleware in reversed(middlewares): for middleware in reversed(middlewares):
app = middleware(app) # type: ignore app = middleware(app, **kwargs) # type: ignore
apps.append(app) apps.append(app)
return app, list(reversed(apps)) return app, list(reversed(apps))
def add_api( def add_api(
self, self,
specification: t.Union[pathlib.Path, str, dict], specification: t.Union[pathlib.Path, str, dict],
*,
base_path: t.Optional[str] = None, base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None, arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None, auth_all_paths: t.Optional[bool] = None,

View File

@@ -92,7 +92,7 @@ class RoutingAPI(AbstractRoutingAPI):
class RoutingMiddleware(SpecMiddleware): class RoutingMiddleware(SpecMiddleware):
def __init__(self, app: ASGIApp) -> None: def __init__(self, app: ASGIApp, **kwargs) -> None:
"""Middleware that resolves the Operation for an incoming request and attaches it to the """Middleware that resolves the Operation for an incoming request and attaches it to the
scope. scope.

View File

@@ -180,7 +180,7 @@ class SwaggerUIAPI(AbstractSpecAPI):
class SwaggerUIMiddleware(SpecMiddleware): class SwaggerUIMiddleware(SpecMiddleware):
def __init__(self, app: ASGIApp) -> None: def __init__(self, app: ASGIApp, **kwargs) -> None:
"""Middleware that hosts a swagger UI. """Middleware that hosts a swagger UI.
:param app: app to wrap in middleware. :param app: app to wrap in middleware.

38
tests/test_lifespan.py Normal file
View File

@@ -0,0 +1,38 @@
import contextlib
import sys
from unittest import mock
import pytest
from connexion import AsyncApp, ConnexionMiddleware
def test_lifespan_handler(app_class):
m = mock.MagicMock()
@contextlib.asynccontextmanager
async def lifespan(app):
m.startup()
yield
m.shutdown()
app = AsyncApp(__name__, lifespan=lifespan)
with app.test_client():
m.startup.assert_called()
m.shutdown.assert_not_called()
m.shutdown.assert_called()
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
)
async def test_lifespan():
"""Test that lifespan events are passed through if no handler is registered."""
lifecycle_handler = mock.Mock()
async def check_lifecycle(scope, receive, send):
if scope["type"] == "lifespan":
lifecycle_handler.handle()
test_app = ConnexionMiddleware(check_lifecycle)
await test_app({"type": "lifespan"}, mock.AsyncMock(), mock.AsyncMock())
lifecycle_handler.handle.assert_called()

View File

@@ -13,7 +13,7 @@ class TestMiddleware:
__test__ = False __test__ = False
def __init__(self, app): def __init__(self, app, **kwargs):
self.app = app self.app = app
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
@@ -49,19 +49,3 @@ def test_routing_middleware(middleware_app):
assert ( assert (
response.headers.get("operation_id") == "fakeapi.hello.post_greeting" response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
), response.status_code ), response.status_code
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
)
async def test_lifecycle():
"""Test that lifecycle events are passed correctly."""
lifecycle_handler = mock.Mock()
async def check_lifecycle(scope, receive, send):
if scope["type"] == "lifecycle":
lifecycle_handler.handle()
test_app = ConnexionMiddleware(check_lifecycle)
await test_app({"type": "lifecycle"}, mock.AsyncMock, mock.AsyncMock)
lifecycle_handler.handle.assert_called()