diff --git a/connexion/apps/abstract.py b/connexion/apps/abstract.py index 69fc261..641e3b9 100644 --- a/connexion/apps/abstract.py +++ b/connexion/apps/abstract.py @@ -11,6 +11,7 @@ from starlette.types import Receive, Scope, Send from connexion.jsonifier import Jsonifier from connexion.middleware import ConnexionMiddleware, SpecMiddleware +from connexion.middleware.lifespan import Lifespan from connexion.resolver import Resolver from connexion.uri_parsing import AbstractURIParser @@ -32,8 +33,9 @@ class AbstractApp: self, import_name: str, *, - specification_dir: t.Union[pathlib.Path, str] = "", + lifespan: t.Optional[Lifespan] = None, middlewares: t.Optional[list] = None, + specification_dir: t.Union[pathlib.Path, str] = "", arguments: t.Optional[dict] = None, auth_all_paths: t.Optional[bool] = None, jsonifier: t.Optional[Jsonifier] = None, @@ -50,11 +52,11 @@ class AbstractApp: :param import_name: The name of the package or module that this object belongs to. If you are using a single module, __name__ is always the correct value. If you however are using a package, it’s usually recommended to hardcode the name of your package there. + :param middlewares: The list of middlewares to wrap around the application. Defaults to + :obj:`middleware.main.ConnexionmMiddleware.default_middlewares` :param specification_dir: The directory holding the specification(s). The provided path should either be absolute or relative to the root path of the application. Defaults to the root path. - :param middlewares: The list of middlewares to wrap around the application. Defaults to - :obj:`middleware.main.ConnexionmMiddleware.default_middlewares` :param arguments: Arguments to substitute the specification using Jinja. :param auth_all_paths: whether to authenticate not paths not defined in the specification. Defaults to False. @@ -79,8 +81,9 @@ class AbstractApp: self.middleware = ConnexionMiddleware( self.middleware_app, import_name=import_name, - specification_dir=specification_dir, + lifespan=lifespan, middlewares=middlewares, + specification_dir=specification_dir, arguments=arguments, auth_all_paths=auth_all_paths, jsonifier=jsonifier, diff --git a/connexion/apps/asynchronous.py b/connexion/apps/asynchronous.py index f825ad3..da28df5 100644 --- a/connexion/apps/asynchronous.py +++ b/connexion/apps/asynchronous.py @@ -15,6 +15,7 @@ from connexion.apps.abstract import AbstractApp from connexion.decorators import StarletteDecorator from connexion.jsonifier import Jsonifier from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware +from connexion.middleware.lifespan import Lifespan from connexion.operations import AbstractOperation from connexion.resolver import Resolver from connexion.uri_parsing import AbstractURIParser @@ -120,8 +121,9 @@ class AsyncApp(AbstractApp): self, import_name: str, *, - specification_dir: t.Union[pathlib.Path, str] = "", + lifespan: t.Optional[Lifespan] = None, middlewares: t.Optional[list] = None, + specification_dir: t.Union[pathlib.Path, str] = "", arguments: t.Optional[dict] = None, auth_all_paths: t.Optional[bool] = None, jsonifier: t.Optional[Jsonifier] = None, @@ -138,11 +140,11 @@ class AsyncApp(AbstractApp): :param import_name: The name of the package or module that this object belongs to. If you are using a single module, __name__ is always the correct value. If you however are using a package, it’s usually recommended to hardcode the name of your package there. + :param middlewares: The list of middlewares to wrap around the application. Defaults to + :obj:`middleware.main.ConnexionmMiddleware.default_middlewares` :param specification_dir: The directory holding the specification(s). The provided path should either be absolute or relative to the root path of the application. Defaults to the root path. - :param middlewares: The list of middlewares to wrap around the application. Defaults to - :obj:`middleware.main.ConnexionmMiddleware.default_middlewares` :param arguments: Arguments to substitute the specification using Jinja. :param auth_all_paths: whether to authenticate not paths not defined in the specification. Defaults to False. @@ -168,8 +170,9 @@ class AsyncApp(AbstractApp): super().__init__( import_name, - specification_dir=specification_dir, + lifespan=lifespan, middlewares=middlewares, + specification_dir=specification_dir, arguments=arguments, auth_all_paths=auth_all_paths, jsonifier=jsonifier, diff --git a/connexion/apps/flask.py b/connexion/apps/flask.py index 6e7aedb..af03a41 100644 --- a/connexion/apps/flask.py +++ b/connexion/apps/flask.py @@ -18,6 +18,7 @@ from connexion.exceptions import InternalServerError, ProblemException, Resolver from connexion.frameworks import flask as flask_utils from connexion.jsonifier import Jsonifier from connexion.middleware.abstract import AbstractRoutingAPI, SpecMiddleware +from connexion.middleware.lifespan import Lifespan from connexion.operations import AbstractOperation from connexion.problem import problem from connexion.resolver import Resolver @@ -176,9 +177,10 @@ class FlaskApp(AbstractApp): self, import_name: str, *, + lifespan: t.Optional[Lifespan] = None, + middlewares: t.Optional[list] = None, server_args: t.Optional[dict] = None, specification_dir: t.Union[pathlib.Path, str] = "", - middlewares: t.Optional[list] = None, arguments: t.Optional[dict] = None, auth_all_paths: t.Optional[bool] = None, jsonifier: t.Optional[Jsonifier] = None, @@ -195,12 +197,14 @@ class FlaskApp(AbstractApp): :param import_name: The name of the package or module that this object belongs to. If you are using a single module, __name__ is always the correct value. If you however are using a package, it’s usually recommended to hardcode the name of your package there. + :param lifespan: A lifespan context function, which can be used to perform startup and + shutdown tasks. + :param middlewares: The list of middlewares to wrap around the application. Defaults to + :obj:`middleware.main.ConnexionmMiddleware.default_middlewares` :param server_args: Arguments to pass to the Flask application. :param specification_dir: The directory holding the specification(s). The provided path should either be absolute or relative to the root path of the application. Defaults to the root path. - :param middlewares: The list of middlewares to wrap around the application. Defaults to - :obj:`middleware.main.ConnexionmMiddleware.default_middlewares` :param arguments: Arguments to substitute the specification using Jinja. :param auth_all_paths: whether to authenticate not paths not defined in the specification. Defaults to False. @@ -226,8 +230,9 @@ class FlaskApp(AbstractApp): self.app = self.middleware_app.app super().__init__( import_name, - specification_dir=specification_dir, + lifespan=lifespan, middlewares=middlewares, + specification_dir=specification_dir, arguments=arguments, auth_all_paths=auth_all_paths, jsonifier=jsonifier, diff --git a/connexion/middleware/abstract.py b/connexion/middleware/abstract.py index 781d532..272b6e1 100644 --- a/connexion/middleware/abstract.py +++ b/connexion/middleware/abstract.py @@ -230,7 +230,7 @@ class RoutedMiddleware(SpecMiddleware, t.Generic[API]): api_cls: t.Type[API] """The subclass of RoutedAPI this middleware uses.""" - def __init__(self, app: ASGIApp) -> None: + def __init__(self, app: ASGIApp, **kwargs) -> None: self.app = app self.apis: t.Dict[str, API] = {} diff --git a/connexion/middleware/exceptions.py b/connexion/middleware/exceptions.py index c8f0c51..7ca4add 100644 --- a/connexion/middleware/exceptions.py +++ b/connexion/middleware/exceptions.py @@ -4,7 +4,7 @@ from starlette.exceptions import ExceptionMiddleware as StarletteExceptionMiddle from starlette.exceptions import HTTPException from starlette.requests import Request as StarletteRequest from starlette.responses import Response -from starlette.types import Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send from connexion.exceptions import InternalServerError, ProblemException, problem @@ -15,8 +15,8 @@ class ExceptionMiddleware(StarletteExceptionMiddleware): """Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to existing connexion behavior.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, next_app: ASGIApp, *args, **kwargs): + super().__init__(next_app) self.add_exception_handler(ProblemException, self.problem_handler) self.add_exception_handler(Exception, self.common_error_handler) diff --git a/connexion/middleware/lifespan.py b/connexion/middleware/lifespan.py new file mode 100644 index 0000000..fccefd8 --- /dev/null +++ b/connexion/middleware/lifespan.py @@ -0,0 +1,28 @@ +import typing as t + +from starlette.routing import Router +from starlette.types import ASGIApp, Receive, Scope, Send + +Lifespan = t.Callable[[t.Any], t.AsyncContextManager] + + +class LifespanMiddleware: + """ + Middleware that adds support for Starlette lifespan handlers + (https://www.starlette.io/lifespan/). + """ + + def __init__( + self, next_app: ASGIApp, *, lifespan: t.Optional[Lifespan], **kwargs + ) -> None: + self.next_app = next_app + self._lifespan = lifespan + # Leverage a Starlette Router for lifespan handling only + self.router = Router(lifespan=lifespan) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + # If no lifespan is registered, pass to next app so it can be handled downstream. + if scope["type"] == "lifespan" and self._lifespan: + await self.router(scope, receive, send) + else: + await self.next_app(scope, receive, send) diff --git a/connexion/middleware/main.py b/connexion/middleware/main.py index e5f4006..5a0509d 100644 --- a/connexion/middleware/main.py +++ b/connexion/middleware/main.py @@ -12,6 +12,7 @@ from connexion.jsonifier import Jsonifier from connexion.middleware.abstract import SpecMiddleware from connexion.middleware.context import ContextMiddleware from connexion.middleware.exceptions import ExceptionMiddleware +from connexion.middleware.lifespan import Lifespan, LifespanMiddleware from connexion.middleware.request_validation import RequestValidationMiddleware from connexion.middleware.response_validation import ResponseValidationMiddleware from connexion.middleware.routing import RoutingMiddleware @@ -92,6 +93,7 @@ class ConnexionMiddleware: RequestValidationMiddleware, ResponseValidationMiddleware, ContextMiddleware, + LifespanMiddleware, ] def __init__( @@ -99,8 +101,9 @@ class ConnexionMiddleware: app: ASGIApp, *, import_name: t.Optional[str] = None, - specification_dir: t.Union[pathlib.Path, str] = "", + lifespan: t.Optional[Lifespan] = None, middlewares: t.Optional[list] = None, + specification_dir: t.Union[pathlib.Path, str] = "", arguments: t.Optional[dict] = None, auth_all_paths: t.Optional[bool] = None, jsonifier: t.Optional[Jsonifier] = None, @@ -117,11 +120,11 @@ class ConnexionMiddleware: :param import_name: The name of the package or module that this object belongs to. If you are using a single module, __name__ is always the correct value. If you however are using a package, it’s usually recommended to hardcode the name of your package there. + :param middlewares: The list of middlewares to wrap around the application. Defaults to + :obj:`middleware.main.ConnexionmMiddleware.default_middlewares` :param specification_dir: The directory holding the specification(s). The provided path should either be absolute or relative to the root path of the application. Defaults to the root path. - :param middlewares: The list of middlewares to wrap around the application. Defaults to - :obj:`middleware.main.ConnexionmMiddleware.default_middlewares` :param arguments: Arguments to substitute the specification using Jinja. :param auth_all_paths: whether to authenticate not paths not defined in the specification. Defaults to False. @@ -150,7 +153,9 @@ class ConnexionMiddleware: if middlewares is None: middlewares = self.default_middlewares - self.app, self.apps = self._apply_middlewares(app, middlewares) + self.app, self.apps = self._apply_middlewares( + app, middlewares, lifespan=lifespan + ) self.options = _Options( arguments=arguments, @@ -177,9 +182,8 @@ class ConnexionMiddleware: else: return self.root_path / path - @staticmethod def _apply_middlewares( - app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]] + self, app: ASGIApp, middlewares: t.List[t.Type[ASGIApp]], **kwargs ) -> t.Tuple[ASGIApp, t.Iterable[ASGIApp]]: """Apply all middlewares to the provided app. @@ -193,13 +197,14 @@ class ConnexionMiddleware: # Include the wrapped application in the returned list. apps = [app] for middleware in reversed(middlewares): - app = middleware(app) # type: ignore + app = middleware(app, **kwargs) # type: ignore apps.append(app) return app, list(reversed(apps)) def add_api( self, specification: t.Union[pathlib.Path, str, dict], + *, base_path: t.Optional[str] = None, arguments: t.Optional[dict] = None, auth_all_paths: t.Optional[bool] = None, diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py index 065fb7c..4351e7a 100644 --- a/connexion/middleware/routing.py +++ b/connexion/middleware/routing.py @@ -92,7 +92,7 @@ class RoutingAPI(AbstractRoutingAPI): class RoutingMiddleware(SpecMiddleware): - def __init__(self, app: ASGIApp) -> None: + def __init__(self, app: ASGIApp, **kwargs) -> None: """Middleware that resolves the Operation for an incoming request and attaches it to the scope. diff --git a/connexion/middleware/swagger_ui.py b/connexion/middleware/swagger_ui.py index b157e9b..eea534d 100644 --- a/connexion/middleware/swagger_ui.py +++ b/connexion/middleware/swagger_ui.py @@ -180,7 +180,7 @@ class SwaggerUIAPI(AbstractSpecAPI): class SwaggerUIMiddleware(SpecMiddleware): - def __init__(self, app: ASGIApp) -> None: + def __init__(self, app: ASGIApp, **kwargs) -> None: """Middleware that hosts a swagger UI. :param app: app to wrap in middleware. diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py new file mode 100644 index 0000000..150ace8 --- /dev/null +++ b/tests/test_lifespan.py @@ -0,0 +1,38 @@ +import contextlib +import sys +from unittest import mock + +import pytest +from connexion import AsyncApp, ConnexionMiddleware + + +def test_lifespan_handler(app_class): + m = mock.MagicMock() + + @contextlib.asynccontextmanager + async def lifespan(app): + m.startup() + yield + m.shutdown() + + app = AsyncApp(__name__, lifespan=lifespan) + with app.test_client(): + m.startup.assert_called() + m.shutdown.assert_not_called() + m.shutdown.assert_called() + + +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="AsyncMock only available from 3.8." +) +async def test_lifespan(): + """Test that lifespan events are passed through if no handler is registered.""" + lifecycle_handler = mock.Mock() + + async def check_lifecycle(scope, receive, send): + if scope["type"] == "lifespan": + lifecycle_handler.handle() + + test_app = ConnexionMiddleware(check_lifecycle) + await test_app({"type": "lifespan"}, mock.AsyncMock(), mock.AsyncMock()) + lifecycle_handler.handle.assert_called() diff --git a/tests/test_middleware.py b/tests/test_middleware.py index d598371..f7dac3c 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -13,7 +13,7 @@ class TestMiddleware: __test__ = False - def __init__(self, app): + def __init__(self, app, **kwargs): self.app = app async def __call__(self, scope, receive, send): @@ -49,19 +49,3 @@ def test_routing_middleware(middleware_app): assert ( response.headers.get("operation_id") == "fakeapi.hello.post_greeting" ), response.status_code - - -@pytest.mark.skipif( - sys.version_info < (3, 8), reason="AsyncMock only available from 3.8." -) -async def test_lifecycle(): - """Test that lifecycle events are passed correctly.""" - lifecycle_handler = mock.Mock() - - async def check_lifecycle(scope, receive, send): - if scope["type"] == "lifecycle": - lifecycle_handler.handle() - - test_app = ConnexionMiddleware(check_lifecycle) - await test_app({"type": "lifecycle"}, mock.AsyncMock, mock.AsyncMock) - lifecycle_handler.handle.assert_called()