mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 04:19:26 +00:00
Add lifespan middleware (#1676)
This PR adds a new middleware to handle lifespan events. I added this as a separate middleware so it is encapsulated and aligned for both the `FlaskApp` and `AsyncApp`. It leverages a Starlette `Router` to register and call the lifespan handler.
This commit is contained in:
@@ -11,6 +11,7 @@ from starlette.types import Receive, Scope, Send
|
||||
|
||||
from connexion.jsonifier import Jsonifier
|
||||
from connexion.middleware import ConnexionMiddleware, SpecMiddleware
|
||||
from connexion.middleware.lifespan import Lifespan
|
||||
from connexion.resolver import Resolver
|
||||
from connexion.uri_parsing import AbstractURIParser
|
||||
|
||||
@@ -32,8 +33,9 @@ class AbstractApp:
|
||||
self,
|
||||
import_name: str,
|
||||
*,
|
||||
specification_dir: t.Union[pathlib.Path, str] = "",
|
||||
lifespan: t.Optional[Lifespan] = None,
|
||||
middlewares: t.Optional[list] = None,
|
||||
specification_dir: t.Union[pathlib.Path, str] = "",
|
||||
arguments: t.Optional[dict] = None,
|
||||
auth_all_paths: t.Optional[bool] = None,
|
||||
jsonifier: t.Optional[Jsonifier] = None,
|
||||
@@ -50,11 +52,11 @@ class AbstractApp:
|
||||
:param import_name: The name of the package or module that this object belongs to. If you
|
||||
are using a single module, __name__ is always the correct value. If you however are
|
||||
using a package, it’s usually recommended to hardcode the name of your package there.
|
||||
:param middlewares: The list of middlewares to wrap around the application. Defaults to
|
||||
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
|
||||
:param specification_dir: The directory holding the specification(s). The provided path
|
||||
should either be absolute or relative to the root path of the application. Defaults to
|
||||
the root path.
|
||||
:param middlewares: The list of middlewares to wrap around the application. Defaults to
|
||||
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
|
||||
:param arguments: Arguments to substitute the specification using Jinja.
|
||||
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
|
||||
Defaults to False.
|
||||
@@ -79,8 +81,9 @@ class AbstractApp:
|
||||
self.middleware = ConnexionMiddleware(
|
||||
self.middleware_app,
|
||||
import_name=import_name,
|
||||
specification_dir=specification_dir,
|
||||
lifespan=lifespan,
|
||||
middlewares=middlewares,
|
||||
specification_dir=specification_dir,
|
||||
arguments=arguments,
|
||||
auth_all_paths=auth_all_paths,
|
||||
jsonifier=jsonifier,
|
||||
|
||||
@@ -15,6 +15,7 @@ from connexion.apps.abstract import AbstractApp
|
||||
from connexion.decorators import StarletteDecorator
|
||||
from connexion.jsonifier import Jsonifier
|
||||
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
|
||||
from connexion.middleware.lifespan import Lifespan
|
||||
from connexion.operations import AbstractOperation
|
||||
from connexion.resolver import Resolver
|
||||
from connexion.uri_parsing import AbstractURIParser
|
||||
@@ -120,8 +121,9 @@ class AsyncApp(AbstractApp):
|
||||
self,
|
||||
import_name: str,
|
||||
*,
|
||||
specification_dir: t.Union[pathlib.Path, str] = "",
|
||||
lifespan: t.Optional[Lifespan] = None,
|
||||
middlewares: t.Optional[list] = None,
|
||||
specification_dir: t.Union[pathlib.Path, str] = "",
|
||||
arguments: t.Optional[dict] = None,
|
||||
auth_all_paths: t.Optional[bool] = None,
|
||||
jsonifier: t.Optional[Jsonifier] = None,
|
||||
@@ -138,11 +140,11 @@ class AsyncApp(AbstractApp):
|
||||
:param import_name: The name of the package or module that this object belongs to. If you
|
||||
are using a single module, __name__ is always the correct value. If you however are
|
||||
using a package, it’s usually recommended to hardcode the name of your package there.
|
||||
:param middlewares: The list of middlewares to wrap around the application. Defaults to
|
||||
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
|
||||
:param specification_dir: The directory holding the specification(s). The provided path
|
||||
should either be absolute or relative to the root path of the application. Defaults to
|
||||
the root path.
|
||||
:param middlewares: The list of middlewares to wrap around the application. Defaults to
|
||||
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
|
||||
:param arguments: Arguments to substitute the specification using Jinja.
|
||||
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
|
||||
Defaults to False.
|
||||
@@ -168,8 +170,9 @@ class AsyncApp(AbstractApp):
|
||||
|
||||
super().__init__(
|
||||
import_name,
|
||||
specification_dir=specification_dir,
|
||||
lifespan=lifespan,
|
||||
middlewares=middlewares,
|
||||
specification_dir=specification_dir,
|
||||
arguments=arguments,
|
||||
auth_all_paths=auth_all_paths,
|
||||
jsonifier=jsonifier,
|
||||
|
||||
@@ -18,6 +18,7 @@ from connexion.exceptions import InternalServerError, ProblemException, Resolver
|
||||
from connexion.frameworks import flask as flask_utils
|
||||
from connexion.jsonifier import Jsonifier
|
||||
from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware
|
||||
from connexion.middleware.lifespan import Lifespan
|
||||
from connexion.operations import AbstractOperation
|
||||
from connexion.problem import problem
|
||||
from connexion.resolver import Resolver
|
||||
@@ -176,9 +177,10 @@ class FlaskApp(AbstractApp):
|
||||
self,
|
||||
import_name: str,
|
||||
*,
|
||||
lifespan: t.Optional[Lifespan] = None,
|
||||
middlewares: t.Optional[list] = None,
|
||||
server_args: t.Optional[dict] = None,
|
||||
specification_dir: t.Union[pathlib.Path, str] = "",
|
||||
middlewares: t.Optional[list] = None,
|
||||
arguments: t.Optional[dict] = None,
|
||||
auth_all_paths: t.Optional[bool] = None,
|
||||
jsonifier: t.Optional[Jsonifier] = None,
|
||||
@@ -195,12 +197,14 @@ class FlaskApp(AbstractApp):
|
||||
:param import_name: The name of the package or module that this object belongs to. If you
|
||||
are using a single module, __name__ is always the correct value. If you however are
|
||||
using a package, it’s usually recommended to hardcode the name of your package there.
|
||||
:param lifespan: A lifespan context function, which can be used to perform startup and
|
||||
shutdown tasks.
|
||||
:param middlewares: The list of middlewares to wrap around the application. Defaults to
|
||||
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
|
||||
:param server_args: Arguments to pass to the Flask application.
|
||||
:param specification_dir: The directory holding the specification(s). The provided path
|
||||
should either be absolute or relative to the root path of the application. Defaults to
|
||||
the root path.
|
||||
:param middlewares: The list of middlewares to wrap around the application. Defaults to
|
||||
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
|
||||
:param arguments: Arguments to substitute the specification using Jinja.
|
||||
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
|
||||
Defaults to False.
|
||||
@@ -226,8 +230,9 @@ class FlaskApp(AbstractApp):
|
||||
self.app = self.middleware_app.app
|
||||
super().__init__(
|
||||
import_name,
|
||||
specification_dir=specification_dir,
|
||||
lifespan=lifespan,
|
||||
middlewares=middlewares,
|
||||
specification_dir=specification_dir,
|
||||
arguments=arguments,
|
||||
auth_all_paths=auth_all_paths,
|
||||
jsonifier=jsonifier,
|
||||
|
||||
@@ -230,7 +230,7 @@ class RoutedMiddleware(SpecMiddleware, t.Generic[API]):
|
||||
api_cls: t.Type[API]
|
||||
"""The subclass of RoutedAPI this middleware uses."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
def __init__(self, app: ASGIApp, **kwargs) -> None:
|
||||
self.app = app
|
||||
self.apis: t.Dict[str, API] = {}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from starlette.exceptions import ExceptionMiddleware as StarletteExceptionMiddle
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.exceptions import InternalServerError, ProblemException, problem
|
||||
|
||||
@@ -15,8 +15,8 @@ class ExceptionMiddleware(StarletteExceptionMiddleware):
|
||||
"""Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
|
||||
existing connexion behavior."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, next_app: ASGIApp, *args, **kwargs):
|
||||
super().__init__(next_app)
|
||||
self.add_exception_handler(ProblemException, self.problem_handler)
|
||||
self.add_exception_handler(Exception, self.common_error_handler)
|
||||
|
||||
|
||||
28
connexion/middleware/lifespan.py
Normal file
28
connexion/middleware/lifespan.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import typing as t
|
||||
|
||||
from starlette.routing import Router
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
Lifespan = t.Callable[[t.Any], t.AsyncContextManager]
|
||||
|
||||
|
||||
class LifespanMiddleware:
|
||||
"""
|
||||
Middleware that adds support for Starlette lifespan handlers
|
||||
(https://www.starlette.io/lifespan/).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, next_app: ASGIApp, *, lifespan: t.Optional[Lifespan], **kwargs
|
||||
) -> None:
|
||||
self.next_app = next_app
|
||||
self._lifespan = lifespan
|
||||
# Leverage a Starlette Router for lifespan handling only
|
||||
self.router = Router(lifespan=lifespan)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# If no lifespan is registered, pass to next app so it can be handled downstream.
|
||||
if scope["type"] == "lifespan" and self._lifespan:
|
||||
await self.router(scope, receive, send)
|
||||
else:
|
||||
await self.next_app(scope, receive, send)
|
||||
@@ -12,6 +12,7 @@ from connexion.jsonifier import Jsonifier
|
||||
from connexion.middleware.abstract import SpecMiddleware
|
||||
from connexion.middleware.context import ContextMiddleware
|
||||
from connexion.middleware.exceptions import ExceptionMiddleware
|
||||
from connexion.middleware.lifespan import Lifespan, LifespanMiddleware
|
||||
from connexion.middleware.request_validation import RequestValidationMiddleware
|
||||
from connexion.middleware.response_validation import ResponseValidationMiddleware
|
||||
from connexion.middleware.routing import RoutingMiddleware
|
||||
@@ -92,6 +93,7 @@ class ConnexionMiddleware:
|
||||
RequestValidationMiddleware,
|
||||
ResponseValidationMiddleware,
|
||||
ContextMiddleware,
|
||||
LifespanMiddleware,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -99,8 +101,9 @@ class ConnexionMiddleware:
|
||||
app: ASGIApp,
|
||||
*,
|
||||
import_name: t.Optional[str] = None,
|
||||
specification_dir: t.Union[pathlib.Path, str] = "",
|
||||
lifespan: t.Optional[Lifespan] = None,
|
||||
middlewares: t.Optional[list] = None,
|
||||
specification_dir: t.Union[pathlib.Path, str] = "",
|
||||
arguments: t.Optional[dict] = None,
|
||||
auth_all_paths: t.Optional[bool] = None,
|
||||
jsonifier: t.Optional[Jsonifier] = None,
|
||||
@@ -117,11 +120,11 @@ class ConnexionMiddleware:
|
||||
:param import_name: The name of the package or module that this object belongs to. If you
|
||||
are using a single module, __name__ is always the correct value. If you however are
|
||||
using a package, it’s usually recommended to hardcode the name of your package there.
|
||||
:param middlewares: The list of middlewares to wrap around the application. Defaults to
|
||||
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
|
||||
:param specification_dir: The directory holding the specification(s). The provided path
|
||||
should either be absolute or relative to the root path of the application. Defaults to
|
||||
the root path.
|
||||
:param middlewares: The list of middlewares to wrap around the application. Defaults to
|
||||
:obj:`middleware.main.ConnexionmMiddleware.default_middlewares`
|
||||
:param arguments: Arguments to substitute the specification using Jinja.
|
||||
:param auth_all_paths: whether to authenticate not paths not defined in the specification.
|
||||
Defaults to False.
|
||||
@@ -150,7 +153,9 @@ class ConnexionMiddleware:
|
||||
|
||||
if middlewares is None:
|
||||
middlewares = self.default_middlewares
|
||||
self.app, self.apps = self._apply_middlewares(app, middlewares)
|
||||
self.app, self.apps = self._apply_middlewares(
|
||||
app, middlewares, lifespan=lifespan
|
||||
)
|
||||
|
||||
self.options = _Options(
|
||||
arguments=arguments,
|
||||
@@ -177,9 +182,8 @@ class ConnexionMiddleware:
|
||||
else:
|
||||
return self.root_path / path
|
||||
|
||||
@staticmethod
|
||||
def _apply_middlewares(
|
||||
app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]]
|
||||
self, app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]], **kwargs
|
||||
) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]:
|
||||
"""Apply all middlewares to the provided app.
|
||||
|
||||
@@ -193,13 +197,14 @@ class ConnexionMiddleware:
|
||||
# Include the wrapped application in the returned list.
|
||||
apps = [app]
|
||||
for middleware in reversed(middlewares):
|
||||
app = middleware(app) # type: ignore
|
||||
app = middleware(app, **kwargs) # type: ignore
|
||||
apps.append(app)
|
||||
return app, list(reversed(apps))
|
||||
|
||||
def add_api(
|
||||
self,
|
||||
specification: t.Union[pathlib.Path, str, dict],
|
||||
*,
|
||||
base_path: t.Optional[str] = None,
|
||||
arguments: t.Optional[dict] = None,
|
||||
auth_all_paths: t.Optional[bool] = None,
|
||||
|
||||
@@ -92,7 +92,7 @@ class RoutingAPI(AbstractRoutingAPI):
|
||||
|
||||
|
||||
class RoutingMiddleware(SpecMiddleware):
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
def __init__(self, app: ASGIApp, **kwargs) -> None:
|
||||
"""Middleware that resolves the Operation for an incoming request and attaches it to the
|
||||
scope.
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ class SwaggerUIAPI(AbstractSpecAPI):
|
||||
|
||||
|
||||
class SwaggerUIMiddleware(SpecMiddleware):
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
def __init__(self, app: ASGIApp, **kwargs) -> None:
|
||||
"""Middleware that hosts a swagger UI.
|
||||
|
||||
:param app: app to wrap in middleware.
|
||||
|
||||
38
tests/test_lifespan.py
Normal file
38
tests/test_lifespan.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import contextlib
|
||||
import sys
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from connexion import AsyncApp, ConnexionMiddleware
|
||||
|
||||
|
||||
def test_lifespan_handler(app_class):
|
||||
m = mock.MagicMock()
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app):
|
||||
m.startup()
|
||||
yield
|
||||
m.shutdown()
|
||||
|
||||
app = AsyncApp(__name__, lifespan=lifespan)
|
||||
with app.test_client():
|
||||
m.startup.assert_called()
|
||||
m.shutdown.assert_not_called()
|
||||
m.shutdown.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
|
||||
)
|
||||
async def test_lifespan():
|
||||
"""Test that lifespan events are passed through if no handler is registered."""
|
||||
lifecycle_handler = mock.Mock()
|
||||
|
||||
async def check_lifecycle(scope, receive, send):
|
||||
if scope["type"] == "lifespan":
|
||||
lifecycle_handler.handle()
|
||||
|
||||
test_app = ConnexionMiddleware(check_lifecycle)
|
||||
await test_app({"type": "lifespan"}, mock.AsyncMock(), mock.AsyncMock())
|
||||
lifecycle_handler.handle.assert_called()
|
||||
@@ -13,7 +13,7 @@ class TestMiddleware:
|
||||
|
||||
__test__ = False
|
||||
|
||||
def __init__(self, app):
|
||||
def __init__(self, app, **kwargs):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
@@ -49,19 +49,3 @@ def test_routing_middleware(middleware_app):
|
||||
assert (
|
||||
response.headers.get("operation_id") == "fakeapi.hello.post_greeting"
|
||||
), response.status_code
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8), reason="AsyncMock only available from 3.8."
|
||||
)
|
||||
async def test_lifecycle():
|
||||
"""Test that lifecycle events are passed correctly."""
|
||||
lifecycle_handler = mock.Mock()
|
||||
|
||||
async def check_lifecycle(scope, receive, send):
|
||||
if scope["type"] == "lifecycle":
|
||||
lifecycle_handler.handle()
|
||||
|
||||
test_app = ConnexionMiddleware(check_lifecycle)
|
||||
await test_app({"type": "lifecycle"}, mock.AsyncMock, mock.AsyncMock)
|
||||
lifecycle_handler.handle.assert_called()
|
||||
|
||||
Reference in New Issue
Block a user