Add lifespan middleware (#1676)

This PR adds a new middleware to handle lifespan events.

I added this as a separate middleware so it is encapsulated and aligned
for both the `FlaskApp` and `AsyncApp`. It leverages a Starlette
`Router` to register and call the lifespan handler.
This commit is contained in:
Robbe Sneyders
2023-03-23 19:11:43 +01:00
committed by GitHub
parent 8a85a4fe01
commit 79c0852c93
11 changed files with 108 additions and 42 deletions

View File

@@ -11,6 +11,7 @@ from starlette.types import Receive, Scope, Send
from connexion.jsonifier import Jsonifier
from connexion.middleware import ConnexionMiddleware, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.resolver import Resolver
from connexion.uri_parsing import AbstractURIParser
@@ -32,8 +33,9 @@ class AbstractApp:
self,
import_name: str,
*,
specification_dir: t.Union[pathlib.Path, str] = "",
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None,
@@ -50,11 +52,11 @@ class AbstractApp:
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, its usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to
the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False.
@@ -79,8 +81,9 @@ class AbstractApp:
self.middleware = ConnexionMiddleware(
self.middleware_app,
import_name=import_name,
specification_dir=specification_dir,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
auth_all_paths=auth_all_paths,
jsonifier=jsonifier,

View File

@@ -15,6 +15,7 @@ from connexion.apps.abstract import AbstractApp
from connexion.decorators import StarletteDecorator
from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver
from connexion.uri_parsing import AbstractURIParser
@@ -120,8 +121,9 @@ class AsyncApp(AbstractApp):
self,
import_name: str,
*,
specification_dir: t.Union[pathlib.Path, str] = "",
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None,
@@ -138,11 +140,11 @@ class AsyncApp(AbstractApp):
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, its usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to
the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False.
@@ -168,8 +170,9 @@ class AsyncApp(AbstractApp):
super().__init__(
import_name,
specification_dir=specification_dir,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
auth_all_paths=auth_all_paths,
jsonifier=jsonifier,

View File

@@ -18,6 +18,7 @@ from connexion.exceptions import InternalServerError, ProblemException, Resolver
from connexion.frameworks import flask as flask_utils
from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
from connexion.middleware.lifespan import Lifespan
from connexion.operations import AbstractOperation
from connexion.problem import problem
from connexion.resolver import Resolver
@@ -176,9 +177,10 @@ class FlaskApp(AbstractApp):
self,
import_name: str,
*,
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
server_args: t.Optional[dict] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
middlewares: t.Optional[list] = None,
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None,
@@ -195,12 +197,14 @@ class FlaskApp(AbstractApp):
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, its usually recommended to hardcode the name of your package there.
:param lifespan: A lifespan context function, which can be used to perform startup and
shutdown tasks.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param server_args: Arguments to pass to the Flask application.
:param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to
the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False.
@@ -226,8 +230,9 @@ class FlaskApp(AbstractApp):
self.app = self.middleware_app.app
super().__init__(
import_name,
specification_dir=specification_dir,
lifespan=lifespan,
middlewares=middlewares,
specification_dir=specification_dir,
arguments=arguments,
auth_all_paths=auth_all_paths,
jsonifier=jsonifier,

View File

@@ -230,7 +230,7 @@ class RoutedMiddleware(SpecMiddleware, t.Generic[API]):
api_cls: t.Type[API]
"""The subclass of RoutedAPI this middleware uses."""
def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
self.app = app
self.apis: t.Dict[str, API] = {}

View File

@@ -4,7 +4,7 @@ from starlette.exceptions import ExceptionMiddleware as StarletteExceptionMiddle
from starlette.exceptions import HTTPException
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.exceptions import InternalServerError, ProblemException, problem
@@ -15,8 +15,8 @@ class ExceptionMiddleware(StarletteExceptionMiddleware):
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
existing connexion behavior."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, next_app: ASGIApp, *args, **kwargs):
super().__init__(next_app)
self.add_exception_handler(ProblemException, self.problem_handler)
self.add_exception_handler(Exception, self.common_error_handler)

View File

@@ -0,0 +1,28 @@
import typing as t
from starlette.routing import Router
from starlette.types import ASGIApp, Receive, Scope, Send
Lifespan = t.Callable[[t.Any], t.AsyncContextManager]
class LifespanMiddleware:
"""
Middleware that adds support for Starlette lifespan handlers
(https://www.starlette.io/lifespan/).
"""
def __init__(
self, next_app: ASGIApp, *, lifespan: t.Optional[Lifespan], **kwargs
) -> None:
self.next_app = next_app
self._lifespan = lifespan
# Leverage a Starlette Router for lifespan handling only
self.router = Router(lifespan=lifespan)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# If no lifespan is registered, pass to next app so it can be handled downstream.
if scope["type"] == "lifespan" and self._lifespan:
await self.router(scope, receive, send)
else:
await self.next_app(scope, receive, send)

View File

@@ -12,6 +12,7 @@ from connexion.jsonifier import Jsonifier
from connexion.middleware.abstract import SpecMiddleware
from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.lifespan import Lifespan, LifespanMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware
@@ -92,6 +93,7 @@ class ConnexionMiddleware:
RequestValidationMiddleware,
ResponseValidationMiddleware,
ContextMiddleware,
LifespanMiddleware,
]
def __init__(
@@ -99,8 +101,9 @@ class ConnexionMiddleware:
app: ASGIApp,
*,
import_name: t.Optional[str] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
lifespan: t.Optional[Lifespan] = None,
middlewares: t.Optional[list] = None,
specification_dir: t.Union[pathlib.Path, str] = "",
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,
jsonifier: t.Optional[Jsonifier] = None,
@@ -117,11 +120,11 @@ class ConnexionMiddleware:
:param import_name: The name of the package or module that this object belongs to. If you
are using a single module, __name__ is always the correct value. If you however are
using a package, its usually recommended to hardcode the name of your package there.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param specification_dir: The directory holding the specification(s). The provided path
should either be absolute or relative to the root path of the application. Defaults to
the root path.
:param middlewares: The list of middlewares to wrap around the application. Defaults to
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
:param arguments: Arguments to substitute the specification using Jinja.
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
Defaults to False.
@@ -150,7 +153,9 @@ class ConnexionMiddleware:
if middlewares is None:
middlewares = self.default_middlewares
self.app, self.apps = self._apply_middlewares(app, middlewares)
self.app, self.apps = self._apply_middlewares(
app, middlewares, lifespan=lifespan
)
self.options = _Options(
arguments=arguments,
@@ -177,9 +182,8 @@ class ConnexionMiddleware:
else:
return self.root_path / path
@staticmethod
def _apply_middlewares(
app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]]
self, app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]], **kwargs
) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
"""Apply all middlewares to the provided app.
@@ -193,13 +197,14 @@ class ConnexionMiddleware:
# Include the wrapped application in the returned list.
apps = [app]
for middleware in reversed(middlewares):
app = middleware(app) # type: ignore
app = middleware(app, **kwargs) # type: ignore
apps.append(app)
return app, list(reversed(apps))
def add_api(
self,
specification: t.Union[pathlib.Path, str, dict],
*,
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
auth_all_paths: t.Optional[bool] = None,

View File

@@ -92,7 +92,7 @@ class RoutingAPI(AbstractRoutingAPI):
class RoutingMiddleware(SpecMiddleware):
def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
"""Middleware that resolves the Operation for an incoming request and attaches it to the
scope.

View File

@@ -180,7 +180,7 @@ class SwaggerUIAPI(AbstractSpecAPI):
class SwaggerUIMiddleware(SpecMiddleware):
def __init__(self, app: ASGIApp) -> None:
def __init__(self, app: ASGIApp, **kwargs) -> None:
"""Middleware that hosts a swagger UI.
:param app: app to wrap in middleware.

38
tests/test_lifespan.py Normal file
View File

@@ -0,0 +1,38 @@
import contextlib
import sys
from unittest import mock
import pytest
from connexion import AsyncApp, ConnexionMiddleware
def test_lifespan_handler(app_class):
m = mock.MagicMock()
@contextlib.asynccontextmanager
async def lifespan(app):
m.startup()
yield
m.shutdown()
app = AsyncApp(__name__, lifespan=lifespan)
with app.test_client():
m.startup.assert_called()
m.shutdown.assert_not_called()
m.shutdown.assert_called()
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
)
async def test_lifespan():
"""Test that lifespan events are passed through if no handler is registered."""
lifecycle_handler = mock.Mock()
async def check_lifecycle(scope, receive, send):
if scope["type"] == "lifespan":
lifecycle_handler.handle()
test_app = ConnexionMiddleware(check_lifecycle)
await test_app({"type": "lifespan"}, mock.AsyncMock(), mock.AsyncMock())
lifecycle_handler.handle.assert_called()

View File

@@ -13,7 +13,7 @@ class TestMiddleware:
__test__ = False
def __init__(self, app):
def __init__(self, app, **kwargs):
self.app = app
async def __call__(self, scope, receive, send):
@@ -49,19 +49,3 @@ def test_routing_middleware(middleware_app):
assert (
response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
), response.status_code
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
)
async def test_lifecycle():
"""Test that lifecycle events are passed correctly."""
lifecycle_handler = mock.Mock()
async def check_lifecycle(scope, receive, send):
if scope["type"] == "lifecycle":
lifecycle_handler.handle()
test_app = ConnexionMiddleware(check_lifecycle)
await test_app({"type": "lifecycle"}, mock.AsyncMock, mock.AsyncMock)
lifecycle_handler.handle.assert_called()