diff --git a/connexion/__init__.py b/connexion/__init__.py index 9d5c1e4..2ba183d 100755 --- a/connexion/__init__.py +++ b/connexion/__init__.py @@ -20,7 +20,7 @@ from .utils import not_installed_error # NOQA try: from flask import request # NOQA - from .apis.flask_api import FlaskApi, context # NOQA + from .apis.flask_api import FlaskApi # NOQA from .apps.flask_app import FlaskApp except ImportError as e: # pragma: no cover _flask_not_installed_error = not_installed_error(e) diff --git a/connexion/apis/abstract.py b/connexion/apis/abstract.py index 6305d92..011b867 100644 --- a/connexion/apis/abstract.py +++ b/connexion/apis/abstract.py @@ -1,7 +1,6 @@ """ This module defines an AbstractAPI, which defines a standardized interface for a Connexion API. """ - import abc import logging import pathlib @@ -9,15 +8,18 @@ import sys import typing as t from enum import Enum -from ..datastructures import NoContent -from ..exceptions import ResolverError -from ..http_facts import METHODS -from ..jsonifier import Jsonifier -from ..lifecycle import ConnexionResponse -from ..operations import make_operation -from ..options import ConnexionOptions -from ..resolver import Resolver -from ..spec import Specification +from starlette.requests import Request + +from connexion.datastructures import NoContent +from connexion.decorators.parameter import parameter_to_arg +from connexion.exceptions import ResolverError +from connexion.http_facts import METHODS +from connexion.jsonifier import Jsonifier +from connexion.lifecycle import ConnexionResponse +from connexion.operations import make_operation +from connexion.options import ConnexionOptions +from connexion.resolver import Resolver +from connexion.spec import Specification MODULE_PATH = pathlib.Path(__file__).absolute().parent.parent SWAGGER_UI_URL = "ui" @@ -231,30 +233,19 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): self.resolver, pythonic_params=self.pythonic_params, uri_parser_class=self.options.uri_parser_class, + parameter_to_arg=parameter_to_arg, ) self._add_operation_internal(method, path, operation) @classmethod @abc.abstractmethod - def get_request(self, *args, **kwargs): + def get_request(cls, uri_parser) -> Request: """ This method converts the user framework request to a ConnexionRequest. """ @classmethod - @abc.abstractmethod - def get_response(self, response, mimetype=None, request=None): - """ - This method converts a handler response to a framework response. - This method should just retrieve response from handler then call `cls._get_response`. - :param response: A response to cast (tuple, framework response, etc). - :param mimetype: The response mimetype. - :type mimetype: Union[None, str] - :param request: The request associated with this response (the user framework request). - """ - - @classmethod - def _get_response(cls, response, mimetype=None, extra_context=None): + def _get_response(cls, response, mimetype=None): """ This method converts a handler response to a framework response. The response can be a ConnexionResponse, an operation handler, a framework response or a tuple. @@ -262,31 +253,24 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): :param response: A response to cast (tuple, framework response, etc). :param mimetype: The response mimetype. :type mimetype: Union[None, str] - :param extra_context: dict of extra details, like url, to include in logs - :type extra_context: Union[None, dict] """ - if extra_context is None: - extra_context = {} logger.debug( "Getting data and status code", - extra={"data": response, "data_type": type(response), **extra_context}, + extra={"data": response, "data_type": type(response)}, ) if isinstance(response, ConnexionResponse): framework_response = cls._connexion_to_framework_response( - response, mimetype, extra_context + response, mimetype ) else: - framework_response = cls._response_from_handler( - response, mimetype, extra_context - ) + framework_response = cls._response_from_handler(response, mimetype) logger.debug( "Got framework response", extra={ "response": framework_response, "response_type": type(framework_response), - **extra_context, }, ) return framework_response @@ -298,7 +282,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): t.Any, str, t.Tuple[str], t.Tuple[str, int], t.Tuple[str, int, dict] ], mimetype: str, - extra_context: t.Optional[dict] = None, ) -> t.Any: """ Create a framework response from the operation handler data. @@ -311,7 +294,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): :param response: A response from an operation handler. :param mimetype: The response mimetype. - :param extra_context: dict of extra details, like url, to include in logs """ if cls._is_framework_response(response): return response @@ -320,9 +302,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): len_response = len(response) if len_response == 1: (data,) = response - return cls._build_response( - mimetype=mimetype, data=data, extra_context=extra_context - ) + return cls._build_response(mimetype=mimetype, data=data) if len_response == 2: if isinstance(response[1], (int, Enum)): data, status_code = response @@ -330,7 +310,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): mimetype=mimetype, data=data, status_code=status_code, - extra_context=extra_context, ) else: data, headers = response @@ -338,7 +317,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): mimetype=mimetype, data=data, headers=headers, - extra_context=extra_context, ) elif len_response == 3: data, status_code, headers = response @@ -347,7 +325,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): data=data, status_code=status_code, headers=headers, - extra_context=extra_context, ) else: raise TypeError( @@ -356,9 +333,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): " (body, status), or (body, headers)." ) else: - return cls._build_response( - mimetype=mimetype, data=response, extra_context=extra_context - ) + return cls._build_response(mimetype=mimetype, data=response) @classmethod def get_connexion_response(cls, response, mimetype=None): @@ -384,7 +359,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): @classmethod @abc.abstractmethod - def _connexion_to_framework_response(cls, response, mimetype, extra_context=None): + def _connexion_to_framework_response(cls, response, mimetype): """Cast ConnexionResponse to framework response class""" @classmethod @@ -396,7 +371,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): content_type=None, status_code=None, headers=None, - extra_context=None, ): """ Create a framework response from the provided arguments. @@ -407,16 +381,12 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): :type status_code: int :param headers: The response status code. :type headers: Union[Iterable[Tuple[str, str]], Dict[str, str]] - :param extra_context: dict of extra details, like url, to include in logs - :type extra_context: Union[None, dict] :return A framework response. :rtype Response """ @classmethod - def _prepare_body_and_status_code( - cls, data, mimetype, status_code=None, extra_context=None - ): + def _prepare_body_and_status_code(cls, data, mimetype, status_code=None): if data is NoContent: data = None @@ -435,12 +405,10 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta): else: body = data - if extra_context is None: - extra_context = {} logger.debug( "Prepared body and status code (%d)", status_code, - extra={"body": body, **extra_context}, + extra={"body": body}, ) return body, status_code, mimetype diff --git a/connexion/apis/flask_api.py b/connexion/apis/flask_api.py index 8489079..4fab1d6 100644 --- a/connexion/apis/flask_api.py +++ b/connexion/apis/flask_api.py @@ -2,13 +2,9 @@ This module defines a Flask Connexion API which implements translations between Flask and Connexion requests / responses. """ - import logging -import warnings -from typing import Any import flask -from werkzeug.local import LocalProxy from connexion.apis import flask_utils from connexion.apis.abstract import AbstractAPI @@ -52,7 +48,7 @@ class FlaskApi(AbstractAPI): ) @classmethod - def get_response(cls, response, mimetype=None, request=None): + def get_response(cls, response, mimetype=None): """Gets ConnexionResponse instance for the operation handler result. Status Code and Headers for response. If only body data is returned by the endpoint function, then the status @@ -64,9 +60,7 @@ class FlaskApi(AbstractAPI): :type response: flask.Response | (flask.Response,) | (flask.Response, int) | (flask.Response, dict) | (flask.Response, int, dict) :rtype: ConnexionResponse """ - return cls._get_response( - response, mimetype=mimetype, extra_context={"url": flask.request.url} - ) + return cls._get_response(response, mimetype=mimetype) @classmethod def _is_framework_response(cls, response): @@ -86,7 +80,7 @@ class FlaskApi(AbstractAPI): ) @classmethod - def _connexion_to_framework_response(cls, response, mimetype, extra_context=None): + def _connexion_to_framework_response(cls, response, mimetype): """Cast ConnexionResponse to framework response class""" flask_response = cls._build_response( mimetype=response.mimetype or mimetype, @@ -94,7 +88,6 @@ class FlaskApi(AbstractAPI): headers=response.headers, status_code=response.status_code, data=response.body, - extra_context=extra_context, ) return flask_response @@ -107,7 +100,6 @@ class FlaskApi(AbstractAPI): headers=None, status_code=None, data=None, - extra_context=None, ): if cls._is_framework_response(data): return flask.current_app.make_response((data, status_code, headers)) @@ -116,7 +108,6 @@ class FlaskApi(AbstractAPI): data=data, mimetype=mimetype, status_code=status_code, - extra_context=extra_context, ) kwargs = { @@ -133,61 +124,14 @@ class FlaskApi(AbstractAPI): def _serialize_data(cls, data, mimetype): if isinstance(mimetype, str) and is_json_mimetype(mimetype): body = cls.jsonifier.dumps(data) - elif not (isinstance(data, bytes) or isinstance(data, str)): - warnings.warn( - "Implicit (flask) JSON serialization will change in the next major version. " - "This is triggered because a response body is being serialized as JSON " - "even though the mimetype is not a JSON type. " - "This will be replaced by something that is mimetype-specific and may " - "raise an error instead of silently converting everything to JSON. " - "Please make sure to specify media/mime types in your specs.", - FutureWarning, # a Deprecation targeted at application users. - ) - body = cls.jsonifier.dumps(data) else: body = data return body, mimetype @classmethod - def get_request(cls, *args, **params): - # type: (*Any, **Any) -> ConnexionRequest - """Gets ConnexionRequest instance for the operation handler - result. Status Code and Headers for response. If only body - data is returned by the endpoint function, then the status - code will be set to 200 and no headers will be added. - - If the returned object is a flask.Response then it will just - pass the information needed to recreate it. - - :rtype: ConnexionRequest - """ - flask_request = flask.request - scope = flask_request.environ["asgi.scope"] - context_dict = scope.get("extensions", {}).get("connexion_context", {}) - setattr(flask.globals.request_ctx, "connexion_context", context_dict) - request = ConnexionRequest( - flask_request.url, - flask_request.method, - headers=flask_request.headers, - form=flask_request.form, - query=flask_request.args, - body=flask_request.get_data(), - json_getter=lambda: flask_request.get_json(silent=True), - files=flask_request.files, - path_params=params, - context=context_dict, - cookies=flask_request.cookies, - ) - logger.debug( - "Getting data and status code", - extra={ - "data": request.body, - "data_type": type(request.body), - "url": request.url, - }, - ) - return request + def get_request(cls, uri_parser) -> ConnexionRequest: + return ConnexionRequest(flask.request, uri_parser=uri_parser) @classmethod def _set_jsonifier(cls): @@ -195,10 +139,3 @@ class FlaskApi(AbstractAPI): Use Flask specific JSON loader """ cls.jsonifier = Jsonifier(flask.json, indent=2) - - -def _get_context(): - return getattr(flask.globals.request_ctx, "connexion_context") - - -context = LocalProxy(_get_context) diff --git a/connexion/apps/flask_app.py b/connexion/apps/flask_app.py index 38769bf..6a3ad65 100644 --- a/connexion/apps/flask_app.py +++ b/connexion/apps/flask_app.py @@ -7,6 +7,7 @@ import pathlib from types import FunctionType # NOQA import a2wsgi +import asgiref.wsgi import flask import werkzeug.exceptions from flask import signals diff --git a/connexion/context.py b/connexion/context.py new file mode 100644 index 0000000..5036866 --- /dev/null +++ b/connexion/context.py @@ -0,0 +1,9 @@ +from asyncio import AbstractEventLoop +from contextvars import ContextVar + +_context: ContextVar[AbstractEventLoop] = ContextVar("CONTEXT") + + +def __getattr__(name): + if name == "context": + return _context.get() diff --git a/connexion/decorators/lifecycle.py b/connexion/decorators/lifecycle.py index 724f673..50dffd7 100644 --- a/connexion/decorators/lifecycle.py +++ b/connexion/decorators/lifecycle.py @@ -22,7 +22,7 @@ class RequestResponseDecorator: self.api = api self.mimetype = mimetype - def __call__(self, function): + def __call__(self, function, uri_parser): """ :type function: types.FunctionType :rtype: types.FunctionType @@ -31,7 +31,7 @@ class RequestResponseDecorator: @functools.wraps(function) async def wrapper(*args, **kwargs): - connexion_request = self.api.get_request(*args, **kwargs) + connexion_request = self.api.get_request(uri_parser=uri_parser) while asyncio.iscoroutine(connexion_request): connexion_request = await connexion_request @@ -40,7 +40,7 @@ class RequestResponseDecorator: connexion_response = await connexion_response framework_response = self.api.get_response( - connexion_response, self.mimetype, connexion_request + connexion_response, self.mimetype ) while asyncio.iscoroutine(framework_response): framework_response = await framework_response @@ -51,8 +51,8 @@ class RequestResponseDecorator: @functools.wraps(function) def wrapper(*args, **kwargs): - request = self.api.get_request(*args, **kwargs) + request = self.api.get_request(uri_parser) response = function(request) - return self.api.get_response(response, self.mimetype, request) + return self.api.get_response(response, self.mimetype) return wrapper diff --git a/connexion/decorators/parameter.py b/connexion/decorators/parameter.py index c089a6f..3cac6cd 100644 --- a/connexion/decorators/parameter.py +++ b/connexion/decorators/parameter.py @@ -1,112 +1,69 @@ """ -This module defines a decorator to convert request parameters to arguments for the view function. +This module defines a utility functions to convert request parameters to arguments for the view +function. """ - import builtins import functools import inspect import keyword import logging import re -from typing import Any +import typing as t +from copy import copy, deepcopy import inflection -from ..http_facts import FORM_CONTENT_TYPES -from ..lifecycle import ConnexionRequest # NOQA -from ..utils import all_json +from connexion.http_facts import FORM_CONTENT_TYPES +from connexion.lifecycle import ConnexionRequest +from connexion.operations import AbstractOperation, Swagger2Operation +from connexion.utils import ( + deep_merge, + is_json_mimetype, + is_null, + is_nullable, + make_type, +) logger = logging.getLogger(__name__) CONTEXT_NAME = "context_" -def inspect_function_arguments(function): # pragma: no cover - """ - Returns the list of variables names of a function and if it - accepts keyword arguments. - - :type function: Callable - :rtype: tuple[list[str], bool] - """ - parameters = inspect.signature(function).parameters - bound_arguments = [ - name - for name, p in parameters.items() - if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD) - ] - has_kwargs = any(p.kind == p.VAR_KEYWORD for p in parameters.values()) - return list(bound_arguments), has_kwargs - - -def snake_and_shadow(name): - """ - Converts the given name into Pythonic form. Firstly it converts CamelCase names to snake_case. Secondly it looks to - see if the name matches a known built-in and if it does it appends an underscore to the name. - :param name: The parameter name - :type name: str - :return: - """ - snake = inflection.underscore(name) - if snake in builtins.__dict__ or keyword.iskeyword(snake): - return f"{snake}_" - return snake - - -def sanitized(name): - return name and re.sub( - "^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", re.sub(r"\[(?!])", "_", name)) - ) - - -def pythonic(name): - name = name and snake_and_shadow(name) - return sanitized(name) - - -def parameter_to_arg(operation, function, pythonic_params=False): - """ - Pass query and body parameters as keyword arguments to handler function. - - See (https://github.com/zalando/connexion/issues/59) - :param operation: The operation being called - :type operation: connexion.operations.AbstractOperation - :param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to - any shadowed built-ins - :type pythonic_params: bool - """ - consumes = operation.consumes - +def parameter_to_arg( + operation: AbstractOperation, + function: t.Callable, + pythonic_params: bool = False, +) -> t.Callable[[ConnexionRequest], t.Any]: sanitize = pythonic if pythonic_params else sanitized arguments, has_kwargs = inspect_function_arguments(function) @functools.wraps(function) - def wrapper(request): - # type: (ConnexionRequest) -> Any - logger.debug("Function Arguments: %s", arguments) + def wrapper(request: ConnexionRequest) -> t.Any: kwargs = {} - if all_json(consumes): - request_body = request.json - elif consumes[0] in FORM_CONTENT_TYPES: - request_body = request.form + body_name = sanitize(operation.body_name(request.content_type)) + if body_name in arguments or has_kwargs: + request_body = get_body(request) + # Pass form contents separately for Swagger2 for backward compatibility with Connexion 2 + # Checking for body_name is not enough + elif request.mimetype in FORM_CONTENT_TYPES and isinstance( + operation, Swagger2Operation + ): + request_body = get_body(request) else: - request_body = request.body - - try: - query = request.query.to_dict(flat=False) - except AttributeError: - query = dict(request.query.items()) + request_body = None kwargs.update( - operation.get_arguments( - request.path_params, - query, - request_body, - request.files, - arguments, - has_kwargs, - sanitize, + get_arguments( + operation, + path_params=request.view_args, + query_params=request.args, + body=request_body, + files=request.files, + arguments=arguments, + has_kwargs=has_kwargs, + sanitize=sanitize, + content_type=request.content_type, ) ) @@ -127,3 +84,348 @@ def parameter_to_arg(operation, function, pythonic_params=False): return function(**kwargs) return wrapper + + +def get_body(request: ConnexionRequest) -> t.Any: + """Get body from the request based on the content type.""" + if is_json_mimetype(request.content_type): + return request.get_json(silent=True) + elif request.mimetype in FORM_CONTENT_TYPES: + return request.form + else: + # Return explicit None instead of empty bytestring so it is handled as null downstream + return request.get_data() or None + + +def inspect_function_arguments(function: t.Callable) -> t.Tuple[t.List[str], bool]: + """ + Returns the list of variables names of a function and if it + accepts keyword arguments. + """ + parameters = inspect.signature(function).parameters + bound_arguments = [ + name + for name, p in parameters.items() + if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD) + ] + has_kwargs = any(p.kind == p.VAR_KEYWORD for p in parameters.values()) + return list(bound_arguments), has_kwargs + + +def snake_and_shadow(name: str) -> str: + """ + Converts the given name into Pythonic form. Firstly it converts CamelCase names to snake_case. Secondly it looks to + see if the name matches a known built-in and if it does it appends an underscore to the name. + :param name: The parameter name + """ + snake = inflection.underscore(name) + if snake in builtins.__dict__ or keyword.iskeyword(snake): + return f"{snake}_" + return snake + + +def sanitized(name: str) -> str: + return name and re.sub( + "^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", re.sub(r"\[(?!])", "_", name)) + ) + + +def pythonic(name: str) -> str: + name = name and snake_and_shadow(name) + return sanitized(name) + + +def get_arguments( + operation: AbstractOperation, + *, + path_params: dict, + query_params: dict, + body: t.Any, + files: dict, + arguments: t.List[str], + has_kwargs: bool, + sanitize: t.Callable, + content_type: str, +) -> t.Dict[str, t.Any]: + """ + get arguments for handler function + """ + ret = {} + ret.update(_get_path_arguments(path_params, operation=operation, sanitize=sanitize)) + ret.update( + _get_query_arguments( + query_params, + operation=operation, + arguments=arguments, + has_kwargs=has_kwargs, + sanitize=sanitize, + ) + ) + + if operation.method.upper() in ["PATCH", "POST", "PUT"]: + ret.update( + _get_body_argument( + body, + operation=operation, + arguments=arguments, + has_kwargs=has_kwargs, + sanitize=sanitize, + content_type=content_type, + ) + ) + ret.update(_get_file_arguments(files, arguments, has_kwargs)) + return ret + + +def _get_path_arguments( + path_params: dict, *, operation: AbstractOperation, sanitize: t.Callable +) -> dict: + """ + Extract handler function arguments from path parameters + """ + kwargs = {} + + path_definitions = { + parameter["name"]: parameter + for parameter in operation.parameters + if parameter["in"] == "path" + } + + for name, value in path_params.items(): + sanitized_key = sanitize(name) + if name in path_definitions: + kwargs[sanitized_key] = _get_val_from_param(value, path_definitions[name]) + else: # Assume path params mechanism used for injection + kwargs[sanitized_key] = value + return kwargs + + +def _get_val_from_param(value: t.Any, param_definitions: t.Dict[str, dict]) -> t.Any: + """Cast a value according to its definition in the specification.""" + param_schema = param_definitions.get("schema", param_definitions) + + if is_nullable(param_schema) and is_null(value): + return None + + if param_schema["type"] == "array": + type_ = param_schema["items"]["type"] + format_ = param_schema["items"].get("format") + return [make_type(part, type_, format_) for part in value] + else: + type_ = param_schema["type"] + format_ = param_schema.get("format") + return make_type(value, type_, format_) + + +def _get_query_arguments( + query_params: dict, + *, + operation: AbstractOperation, + arguments: t.List[str], + has_kwargs: bool, + sanitize: t.Callable, +) -> dict: + """ + extract handler function arguments from the query parameters + """ + query_definitions = { + parameter["name"]: parameter + for parameter in operation.parameters + if parameter["in"] == "query" + } + + default_query_params = _get_query_defaults(query_definitions) + + query_arguments = deepcopy(default_query_params) + query_arguments = deep_merge(query_arguments, query_params) + return _query_args_helper( + query_definitions, query_arguments, arguments, has_kwargs, sanitize + ) + + +def _get_query_defaults(query_definitions: t.Dict[str, dict]) -> t.Dict[str, t.Any]: + """Get the default values for the query parameter from the parameter definition.""" + defaults = {} + for k, v in query_definitions.items(): + try: + if "default" in v: + defaults[k] = v["default"] + elif v["schema"]["type"] == "object": + defaults[k] = _get_default_obj(v["schema"]) + else: + defaults[k] = v["schema"]["default"] + except KeyError: + pass + return defaults + + +def _get_default_obj(schema: dict) -> dict: + try: + return deepcopy(schema["default"]) + except KeyError: + properties = schema.get("properties", {}) + return _build_default_obj_recursive(properties, {}) + + +def _build_default_obj_recursive(properties: dict, default_object: dict) -> dict: + """takes disparate and nested default keys, and builds up a default object""" + for name, property_ in properties.items(): + if "default" in property_ and name not in default_object: + default_object[name] = copy(property_["default"]) + elif property_.get("type") == "object" and "properties" in property_: + default_object.setdefault(name, {}) + default_object[name] = _build_default_obj_recursive( + property_["properties"], default_object[name] + ) + return default_object + + +def _query_args_helper( + query_definitions: dict, + query_arguments: dict, + function_arguments: t.List[str], + has_kwargs: bool, + sanitize: t.Callable, +) -> dict: + result = {} + for key, value in query_arguments.items(): + sanitized_key = sanitize(key) + if not has_kwargs and sanitized_key not in function_arguments: + logger.debug( + "Query Parameter '%s' (sanitized: '%s') not in function arguments", + key, + sanitized_key, + ) + else: + logger.debug( + "Query Parameter '%s' (sanitized: '%s') in function arguments", + key, + sanitized_key, + ) + try: + query_defn = query_definitions[key] + except KeyError: # pragma: no cover + logger.error( + "Function argument '%s' (non-sanitized: %s) not defined in specification", + sanitized_key, + key, + ) + else: + logger.debug("%s is a %s", key, query_defn) + result.update({sanitized_key: _get_val_from_param(value, query_defn)}) + return result + + +def _get_body_argument( + body: t.Any, + *, + operation: AbstractOperation, + arguments: t.List[str], + has_kwargs: bool, + sanitize: t.Callable, + content_type: str, +) -> dict: + if len(arguments) <= 0 and not has_kwargs: + return {} + + body_name = sanitize(operation.body_name(content_type)) + + if content_type in FORM_CONTENT_TYPES: + result = _get_body_argument_form( + body, operation=operation, content_type=content_type + ) + + # Unpack form values for Swagger for compatibility with Connexion 2 behavior + if content_type in FORM_CONTENT_TYPES and isinstance( + operation, Swagger2Operation + ): + if has_kwargs: + return result + else: + return { + sanitize(name): value + for name, value in result.items() + if sanitize(name) in arguments + } + else: + result = _get_body_argument_json( + body, operation=operation, content_type=content_type + ) + + if body_name in arguments or has_kwargs: + return {body_name: result} + + return {} + + +def _get_body_argument_json( + body: t.Any, *, operation: AbstractOperation, content_type: str +) -> t.Any: + # if the body came in null, and the schema says it can be null, we decide + # to include no value for the body argument, rather than the default body + if is_nullable(operation.body_schema(content_type)) and is_null(body): + return None + + if body is None: + default_body = operation.body_schema(content_type).get("default", {}) + return deepcopy(default_body) + + return body + + +def _get_body_argument_form( + body: dict, *, operation: AbstractOperation, content_type: str +) -> dict: + # now determine the actual value for the body (whether it came in or is default) + default_body = operation.body_schema(content_type).get("default", {}) + body_props = { + k: {"schema": v} + for k, v in operation.body_schema(content_type).get("properties", {}).items() + } + + # by OpenAPI specification `additionalProperties` defaults to `true` + # see: https://github.com/OAI/OpenAPI-Specification/blame/3.0.2/versions/3.0.2.md#L2305 + additional_props = operation.body_schema().get("additionalProperties", True) + + body_arg = deepcopy(default_body) + body_arg.update(body or {}) + + if body_props or additional_props: + return _get_typed_body_values(body_arg, body_props, additional_props) + + return {} + + +def _get_typed_body_values(body_arg, body_props, additional_props): + """ + Return a copy of the provided body_arg dictionary + whose values will have the appropriate types + as defined in the provided schemas. + + :type body_arg: type dict + :type body_props: dict + :type additional_props: dict|bool + :rtype: dict + """ + additional_props_defn = ( + {"schema": additional_props} if isinstance(additional_props, dict) else None + ) + res = {} + + for key, value in body_arg.items(): + try: + prop_defn = body_props[key] + res[key] = _get_val_from_param(value, prop_defn) + except KeyError: # pragma: no cover + if not additional_props: + logger.error(f"Body property '{key}' not defined in body schema") + continue + if additional_props_defn is not None: + value = _get_val_from_param(value, additional_props_defn) + res[key] = value + + return res + + +def _get_file_arguments(files, arguments, has_kwargs=False): + return {k: v for k, v in files.items() if k in arguments or has_kwargs} diff --git a/connexion/lifecycle.py b/connexion/lifecycle.py index 2df2636..8af1065 100644 --- a/connexion/lifecycle.py +++ b/connexion/lifecycle.py @@ -2,46 +2,13 @@ This module defines interfaces for requests and responses used in Connexion for authentication, validation, serialization, etc. """ +import typing as t + +from flask import Request as FlaskRequest from starlette.requests import Request as StarletteRequest from starlette.responses import StreamingResponse as StarletteStreamingResponse -class ConnexionRequest: - """Connexion interface for a request.""" - - def __init__( - self, - url, - method, - path_params=None, - query=None, - headers=None, - form=None, - body=None, - json_getter=None, - files=None, - context=None, - cookies=None, - ): - self.url = url - self.method = method - self.path_params = path_params or {} - self.query = query or {} - self.headers = headers or {} - self.form = form or {} - self.body = body - self.json_getter = json_getter - self.files = files - self.context = context if context is not None else {} - self.cookies = cookies or {} - - @property - def json(self): - if not hasattr(self, "_json"): - self._json = self.json_getter() - return self._json - - class ConnexionResponse: """Connexion interface for a response.""" @@ -62,6 +29,41 @@ class ConnexionResponse: self.is_streamed = is_streamed +class ConnexionRequest: + def __init__(self, flask_request: FlaskRequest, uri_parser=None): + self._flask_request = flask_request + self.uri_parser = uri_parser + self._context = None + + @property + def context(self): + if self._context is None: + scope = self._flask_request.environ["asgi.scope"] + extensions = scope.setdefault("extensions", {}) + self._context = extensions.setdefault("connexion_context", {}) + + return self._context + + @property + def view_args(self) -> t.Dict[str, t.Any]: + return self.uri_parser.resolve_path(self._flask_request.view_args) + + @property + def args(self): + query_params = self._flask_request.args + query_params = {k: query_params.getlist(k) for k in query_params} + return self.uri_parser.resolve_query(query_params) + + @property + def form(self): + form = self._flask_request.form.to_dict(flat=False) + form_data = self.uri_parser.resolve_form(form) + return form_data + + def __getattr__(self, item): + return getattr(self._flask_request, item) + + class MiddlewareRequest(StarletteRequest): """Wraps starlette Request so it can easily be extended.""" diff --git a/connexion/middleware/context.py b/connexion/middleware/context.py new file mode 100644 index 0000000..f8fb052 --- /dev/null +++ b/connexion/middleware/context.py @@ -0,0 +1,15 @@ +"""The ContextMiddleware creates a global context based the scope. It should be last in the +middleware stack, so it exposes the scope passed to the application""" +from starlette.types import ASGIApp, Receive, Scope, Send + +from connexion.context import _context + + +class ContextMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + context = scope.get("extensions", {}).get("connexion_context", {}) + _context.set(context) + await self.app(scope, receive, send) diff --git a/connexion/middleware/main.py b/connexion/middleware/main.py index 2e2acea..fcd4ff0 100644 --- a/connexion/middleware/main.py +++ b/connexion/middleware/main.py @@ -4,6 +4,7 @@ import typing as t from starlette.types import ASGIApp, Receive, Scope, Send from connexion.middleware.abstract import AppMiddleware +from connexion.middleware.context import ContextMiddleware from connexion.middleware.exceptions import ExceptionMiddleware from connexion.middleware.request_validation import RequestValidationMiddleware from connexion.middleware.response_validation import ResponseValidationMiddleware @@ -21,6 +22,7 @@ class ConnexionMiddleware: SecurityMiddleware, RequestValidationMiddleware, ResponseValidationMiddleware, + ContextMiddleware, ] def __init__( diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py index 9419712..192f26f 100644 --- a/connexion/middleware/routing.py +++ b/connexion/middleware/routing.py @@ -1,7 +1,10 @@ +import functools import pathlib +import re import typing as t from contextvars import ContextVar +import starlette.convertors from starlette.routing import Router from starlette.types import ASGIApp, Receive, Scope, Send @@ -53,7 +56,7 @@ class RoutingAPI(AbstractRoutingAPI): resolver: t.Optional[Resolver] = None, resolver_error_handler: t.Optional[t.Callable] = None, debug: bool = False, - **kwargs + **kwargs, ) -> None: """API implementation on top of Starlette Router for Connexion middleware.""" self.next_app = next_app @@ -76,6 +79,8 @@ class RoutingAPI(AbstractRoutingAPI): routing_operation = RoutingOperation.from_operation( operation, next_app=self.next_app ) + types = operation.get_path_parameter_types() + path = starlettify_path(path, types) self._add_operation_internal(method, path, routing_operation) def _add_operation_internal( @@ -94,13 +99,15 @@ class RoutingMiddleware(AppMiddleware): self.app = app # Pass unknown routes to next app self.router = Router(default=RoutingOperation(None, self.app)) + starlette.convertors.register_url_convertor("float", FloatConverter()) + starlette.convertors.register_url_convertor("int", IntegerConverter()) def add_api( self, specification: t.Union[pathlib.Path, str, dict], base_path: t.Optional[str] = None, arguments: t.Optional[dict] = None, - **kwargs + **kwargs, ) -> None: """Add an API to the router based on a OpenAPI spec. @@ -113,7 +120,7 @@ class RoutingMiddleware(AppMiddleware): base_path=base_path, arguments=arguments, next_app=self.app, - **kwargs + **kwargs, ) self.router.mount(api.base_path, app=api.router) @@ -129,3 +136,46 @@ class RoutingMiddleware(AppMiddleware): # Needs to be set so starlette router throws exceptions instead of returning error responses scope["app"] = self await self.router(scope, receive, send) + + +PATH_PARAMETER = re.compile(r"\{([^}]*)\}") +PATH_PARAMETER_CONVERTERS = {"integer": "int", "number": "float", "path": "path"} + + +def convert_path_parameter(match, types): + name = match.group(1) + swagger_type = types.get(name) + converter = PATH_PARAMETER_CONVERTERS.get(swagger_type) + return f'{{{name.replace("-", "_")}{":" if converter else ""}{converter or ""}}}' + + +def starlettify_path(swagger_path, types=None): + """ + Convert swagger path templates to flask path templates + + :type swagger_path: str + :type types: dict + :rtype: str + + >>> starlettify_path('/foo-bar/{my-param}') + '/foo-bar/{my_param}' + + >>> starlettify_path('/foo/{someint}', {'someint': 'int'}) + '/foo/{someint:int}' + """ + if types is None: + types = {} + convert_match = functools.partial(convert_path_parameter, types=types) + return PATH_PARAMETER.sub(convert_match, swagger_path) + + +class FloatConverter(starlette.convertors.FloatConvertor): + """Starlette converter for OpenAPI number type""" + + regex = r"[+-]?[0-9]*(\.[0-9]*)?" + + +class IntegerConverter(starlette.convertors.IntegerConvertor): + """Starlette converter for OpenAPI integer type""" + + regex = r"[+-]?[0-9]+" diff --git a/connexion/operations/abstract.py b/connexion/operations/abstract.py index eff4ed9..cad28a8 100644 --- a/connexion/operations/abstract.py +++ b/connexion/operations/abstract.py @@ -6,9 +6,8 @@ and functionality shared between Swagger 2 and OpenAPI 3 specifications. import abc import logging -from ..decorators.lifecycle import RequestResponseDecorator -from ..decorators.parameter import parameter_to_arg -from ..utils import all_json +from connexion.decorators.lifecycle import RequestResponseDecorator +from connexion.utils import all_json logger = logging.getLogger("connexion.operations.abstract") @@ -46,6 +45,7 @@ class AbstractOperation(metaclass=abc.ABCMeta): randomize_endpoint=None, pythonic_params=False, uri_parser_class=None, + parameter_to_arg=None, ): """ :param api: api that this operation is attached to @@ -87,6 +87,8 @@ class AbstractOperation(metaclass=abc.ABCMeta): self._responses = self._operation.get("responses", {}) + self.parameter_to_arg = parameter_to_arg + @property def api(self): return self._api @@ -148,75 +150,6 @@ class AbstractOperation(metaclass=abc.ABCMeta): """ return self._pythonic_params - @staticmethod - def _get_file_arguments(files, arguments, has_kwargs=False): - return {k: v for k, v in files.items() if k in arguments or has_kwargs} - - @abc.abstractmethod - def _get_val_from_param(self, value, query_defn): - """ - Convert input parameters into the correct type - """ - - def _query_args_helper( - self, query_defns, query_arguments, function_arguments, has_kwargs, sanitize - ): - res = {} - for key, value in query_arguments.items(): - sanitized_key = sanitize(key) - if not has_kwargs and sanitized_key not in function_arguments: - logger.debug( - "Query Parameter '%s' (sanitized: '%s') not in function arguments", - key, - sanitized_key, - ) - else: - logger.debug( - "Query Parameter '%s' (sanitized: '%s') in function arguments", - key, - sanitized_key, - ) - try: - query_defn = query_defns[key] - except KeyError: # pragma: no cover - logger.error( - "Function argument '%s' (non-sanitized: %s) not defined in specification", - sanitized_key, - key, - ) - else: - logger.debug("%s is a %s", key, query_defn) - res.update( - {sanitized_key: self._get_val_from_param(value, query_defn)} - ) - return res - - @abc.abstractmethod - def _get_query_arguments(self, query, arguments, has_kwargs, sanitize): - """ - extract handler function arguments from the query parameters - """ - - @abc.abstractmethod - def _get_body_argument(self, body, arguments, has_kwargs, sanitize): - """ - extract handler function arguments from the request body - """ - - def _get_path_arguments(self, path_params, sanitize): - """ - extract handler function arguments from path parameters - """ - kwargs = {} - path_defns = {p["name"]: p for p in self.parameters if p["in"] == "path"} - for key, value in path_params.items(): - sanitized_key = sanitize(key) - if key in path_defns: - kwargs[sanitized_key] = self._get_val_from_param(value, path_defns[key]) - else: # Assume path params mechanism used for injection - kwargs[sanitized_key] = value - return kwargs - @property @abc.abstractmethod def parameters(self): @@ -238,6 +171,12 @@ class AbstractOperation(metaclass=abc.ABCMeta): Content-Types that the operation consumes """ + @abc.abstractmethod + def body_name(self, content_type: str) -> str: + """ + Name of the body in the spec. + """ + @abc.abstractmethod def body_schema(self, content_type: str = None) -> dict: """ @@ -251,23 +190,6 @@ class AbstractOperation(metaclass=abc.ABCMeta): :rtype: dict """ - def get_arguments( - self, path_params, query_params, body, files, arguments, has_kwargs, sanitize - ): - """ - get arguments for handler function - """ - ret = {} - ret.update(self._get_path_arguments(path_params, sanitize)) - ret.update( - self._get_query_arguments(query_params, arguments, has_kwargs, sanitize) - ) - - if self.method.upper() in ["PATCH", "POST", "PUT"]: - ret.update(self._get_body_argument(body, arguments, has_kwargs, sanitize)) - ret.update(self._get_file_arguments(files, arguments, has_kwargs)) - return ret - def response_definition(self, status_code=None, content_type=None): """ response definition for this endpoint @@ -335,17 +257,19 @@ class AbstractOperation(metaclass=abc.ABCMeta): :rtype: types.FunctionType """ - function = parameter_to_arg( - self, - self._resolution.function, - self.pythonic_params, + function = self._resolution.function + + if self.parameter_to_arg: + function = self.parameter_to_arg( + self, + function, + self.pythonic_params, + ) + + function = self._request_response_decorator( + function, self._uri_parsing_decorator ) - uri_parsing_decorator = self._uri_parsing_decorator - function = uri_parsing_decorator(function) - - function = self._request_response_decorator(function) - return function @property diff --git a/connexion/operations/openapi.py b/connexion/operations/openapi.py index 3d253dc..3a2a327 100644 --- a/connexion/operations/openapi.py +++ b/connexion/operations/openapi.py @@ -3,14 +3,11 @@ This module defines an OpenAPIOperation class, a Connexion operation specific fo """ import logging -from copy import copy, deepcopy from connexion.datastructures import MediaTypeDict from connexion.operations.abstract import AbstractOperation - -from ..decorators.uri_parsing import OpenAPIURIParser -from ..http_facts import FORM_CONTENT_TYPES -from ..utils import deep_get, deep_merge, is_null, is_nullable, make_type +from connexion.uri_parsing import OpenAPIURIParser +from connexion.utils import deep_get logger = logging.getLogger("connexion.operations.openapi3") @@ -35,6 +32,7 @@ class OpenAPIOperation(AbstractOperation): randomize_endpoint=None, pythonic_params=False, uri_parser_class=None, + parameter_to_arg=None, ): """ This class uses the OperationID identify the module and function that will handle the operation @@ -88,6 +86,7 @@ class OpenAPIOperation(AbstractOperation): randomize_endpoint=randomize_endpoint, pythonic_params=pythonic_params, uri_parser_class=uri_parser_class, + parameter_to_arg=parameter_to_arg, ) self._request_body = operation.get("requestBody", {}) @@ -144,8 +143,9 @@ class OpenAPIOperation(AbstractOperation): def produces(self): return self._produces - def with_definitions(self, schema): + def with_definitions(self, schema: dict): if self.components: + schema.setdefault("schema", {}) schema["schema"]["components"] = self.components return schema @@ -240,6 +240,9 @@ class OpenAPIOperation(AbstractOperation): types[path_defn["name"]] = path_type return types + def body_name(self, _content_type: str) -> str: + return self.request_body.get("x-body-name", "body") + def body_schema(self, content_type: str = None) -> dict: """ The body schema definition for this operation. @@ -267,131 +270,3 @@ class OpenAPIOperation(AbstractOperation): res = content_type_dict.get(content_type, {}) return self.with_definitions(res) return {} - - def _get_body_argument(self, body, arguments, has_kwargs, sanitize): - if len(arguments) <= 0 and not has_kwargs: - return {} - - x_body_name = sanitize(self.request_body.get("x-body-name", "body")) - - if self.consumes[0] in FORM_CONTENT_TYPES: - result = self._get_body_argument_form(body) - else: - result = self._get_body_argument_json(body) - - if x_body_name in arguments or has_kwargs: - return {x_body_name: result} - return {} - - def _get_body_argument_json(self, body): - # if the body came in null, and the schema says it can be null, we decide - # to include no value for the body argument, rather than the default body - if is_nullable(self.body_schema()) and is_null(body): - return None - - if body is None: - default_body = self.body_schema().get("default", {}) - return deepcopy(default_body) - - return body - - def _get_body_argument_form(self, body): - # now determine the actual value for the body (whether it came in or is default) - default_body = self.body_schema().get("default", {}) - body_props = { - k: {"schema": v} - for k, v in self.body_schema().get("properties", {}).items() - } - - # by OpenAPI specification `additionalProperties` defaults to `true` - # see: https://github.com/OAI/OpenAPI-Specification/blame/3.0.2/versions/3.0.2.md#L2305 - additional_props = self.body_schema().get("additionalProperties", True) - - body_arg = deepcopy(default_body) - body_arg.update(body or {}) - - if body_props or additional_props: - return self._get_typed_body_values(body_arg, body_props, additional_props) - return {} - - def _get_typed_body_values(self, body_arg, body_props, additional_props): - """ - Return a copy of the provided body_arg dictionary - whose values will have the appropriate types - as defined in the provided schemas. - - :type body_arg: type dict - :type body_props: dict - :type additional_props: dict|bool - :rtype: dict - """ - additional_props_defn = ( - {"schema": additional_props} if isinstance(additional_props, dict) else None - ) - res = {} - - for key, value in body_arg.items(): - try: - prop_defn = body_props[key] - res[key] = self._get_val_from_param(value, prop_defn) - except KeyError: # pragma: no cover - if not additional_props: - logger.error(f"Body property '{key}' not defined in body schema") - continue - if additional_props_defn is not None: - value = self._get_val_from_param(value, additional_props_defn) - res[key] = value - - return res - - def _build_default_obj_recursive(self, _properties, res): - """takes disparate and nested default keys, and builds up a default object""" - for key, prop in _properties.items(): - if "default" in prop and key not in res: - res[key] = copy(prop["default"]) - elif prop.get("type") == "object" and "properties" in prop: - res.setdefault(key, {}) - res[key] = self._build_default_obj_recursive( - prop["properties"], res[key] - ) - return res - - def _get_default_obj(self, schema): - try: - return deepcopy(schema["default"]) - except KeyError: - _properties = schema.get("properties", {}) - return self._build_default_obj_recursive(_properties, {}) - - def _get_query_defaults(self, query_defns): - defaults = {} - for k, v in query_defns.items(): - try: - if v["schema"]["type"] == "object": - defaults[k] = self._get_default_obj(v["schema"]) - else: - defaults[k] = v["schema"]["default"] - except KeyError: - pass - return defaults - - def _get_query_arguments(self, query, arguments, has_kwargs, sanitize): - query_defns = {p["name"]: p for p in self.parameters if p["in"] == "query"} - default_query_params = self._get_query_defaults(query_defns) - - query_arguments = deepcopy(default_query_params) - query_arguments = deep_merge(query_arguments, query) - return self._query_args_helper( - query_defns, query_arguments, arguments, has_kwargs, sanitize - ) - - def _get_val_from_param(self, value, query_defn): - query_schema = query_defn["schema"] - - if is_nullable(query_schema) and is_null(value): - return None - - if query_schema["type"] == "array": - return [make_type(part, query_schema["items"]["type"]) for part in value] - else: - return make_type(value, query_schema["type"]) diff --git a/connexion/operations/swagger2.py b/connexion/operations/swagger2.py index ad5c104..b7aa7fc 100644 --- a/connexion/operations/swagger2.py +++ b/connexion/operations/swagger2.py @@ -4,14 +4,12 @@ This module defines a Swagger2Operation class, a Connexion operation specific fo import logging import typing as t -from copy import deepcopy +from connexion.exceptions import InvalidSpecification +from connexion.http_facts import FORM_CONTENT_TYPES from connexion.operations.abstract import AbstractOperation - -from ..decorators.uri_parsing import Swagger2URIParser -from ..exceptions import InvalidSpecification -from ..http_facts import FORM_CONTENT_TYPES -from ..utils import deep_get, is_null, is_nullable, make_type +from connexion.uri_parsing import Swagger2URIParser +from connexion.utils import deep_get logger = logging.getLogger("connexion.operations.swagger2") @@ -52,6 +50,7 @@ class Swagger2Operation(AbstractOperation): randomize_endpoint=None, pythonic_params=False, uri_parser_class=None, + parameter_to_arg=None, ): """ :param api: api that this operation is attached to @@ -101,6 +100,7 @@ class Swagger2Operation(AbstractOperation): randomize_endpoint=randomize_endpoint, pythonic_params=pythonic_params, uri_parser_class=uri_parser_class, + parameter_to_arg=parameter_to_arg, ) self._produces = operation.get("produces", app_produces) @@ -222,6 +222,9 @@ class Swagger2Operation(AbstractOperation): except KeyError: raise + def body_name(self, content_type: str = None) -> str: + return self.body_definition(content_type).get("name", "body") + def body_schema(self, content_type: str = None) -> dict: """ The body schema definition for this operation. @@ -237,6 +240,7 @@ class Swagger2Operation(AbstractOperation): :rtype: dict """ + # TODO: cache if content_type in FORM_CONTENT_TYPES: form_parameters = [p for p in self.parameters if p["in"] == "formData"] body_definition = self._transform_form(form_parameters) @@ -248,12 +252,21 @@ class Swagger2Operation(AbstractOperation): method=self.method, path=self.path ) ) - body_definition = body_parameters[0] if body_parameters else {} + body_parameter = body_parameters[0] if body_parameters else {} + body_definition = self._transform_json(body_parameter) return body_definition + def _transform_json(self, body_parameter: dict) -> dict: + """Translate Swagger2 json parameters into OpenAPI 3 jsonschema spec.""" + nullable = body_parameter.get("x-nullable") + if nullable is not None: + body_parameter["schema"]["nullable"] = nullable + return body_parameter + def _transform_form(self, form_parameters: t.List[dict]) -> dict: """Translate Swagger2 form parameters into OpenAPI 3 jsonschema spec.""" properties = {} + defaults = {} required = [] encoding = {} @@ -276,7 +289,7 @@ class Swagger2Operation(AbstractOperation): default = param.get("default") if default is not None: - prop["default"] = default + defaults[param["name"]] = default nullable = param.get("x-nullable") if nullable is not None: @@ -305,6 +318,7 @@ class Swagger2Operation(AbstractOperation): "schema": { "type": "object", "properties": properties, + "default": defaults, "required": required, } } @@ -313,76 +327,3 @@ class Swagger2Operation(AbstractOperation): definition["encoding"] = encoding return definition - - def _get_query_arguments(self, query, arguments, has_kwargs, sanitize): - query_defns = {p["name"]: p for p in self.parameters if p["in"] == "query"} - default_query_params = { - k: v["default"] for k, v in query_defns.items() if "default" in v - } - query_arguments = deepcopy(default_query_params) - query_arguments.update(query) - return self._query_args_helper( - query_defns, query_arguments, arguments, has_kwargs, sanitize - ) - - def _get_body_argument(self, body, arguments, has_kwargs, sanitize): - kwargs = {} - body_parameters = [p for p in self.parameters if p["in"] == "body"] or [{}] - if body is None: - body = deepcopy(body_parameters[0].get("schema", {}).get("default")) - body_name = sanitize(body_parameters[0].get("name")) - - form_defns = {p["name"]: p for p in self.parameters if p["in"] == "formData"} - - default_form_params = { - k: v["default"] for k, v in form_defns.items() if "default" in v - } - - # Add body parameters - if body_name: - if not has_kwargs and body_name not in arguments: - logger.debug("Body parameter '%s' not in function arguments", body_name) - else: - logger.debug("Body parameter '%s' in function arguments", body_name) - kwargs[body_name] = body - - # Add formData parameters - form_arguments = deepcopy(default_form_params) - if form_defns and body: - form_arguments.update(body) - for key, value in form_arguments.items(): - sanitized_key = sanitize(key) - if not has_kwargs and sanitized_key not in arguments: - logger.debug( - "FormData parameter '%s' (sanitized: '%s') not in function arguments", - key, - sanitized_key, - ) - else: - logger.debug( - "FormData parameter '%s' (sanitized: '%s') in function arguments", - key, - sanitized_key, - ) - try: - form_defn = form_defns[key] - except KeyError: # pragma: no cover - logger.error( - "Function argument '%s' (non-sanitized: %s) not defined in specification", - key, - sanitized_key, - ) - else: - kwargs[sanitized_key] = self._get_val_from_param(value, form_defn) - return kwargs - - def _get_val_from_param(self, value, query_defn): - if is_nullable(query_defn) and is_null(value): - return None - - query_schema = query_defn - - if query_schema["type"] == "array": - return [make_type(part, query_defn["items"]["type"]) for part in value] - else: - return make_type(value, query_defn["type"]) diff --git a/connexion/options.py b/connexion/options.py index 28c2655..71f74b3 100644 --- a/connexion/options.py +++ b/connexion/options.py @@ -10,7 +10,7 @@ try: except ImportError: swagger_ui_2_path = swagger_ui_3_path = None -from connexion.decorators.uri_parsing import AbstractURIParser +from connexion.uri_parsing import AbstractURIParser NO_UI_MSG = """The swagger_ui directory could not be found. Please install connexion with extra install: pip install connexion[swagger-ui] diff --git a/connexion/security/security_handler_factory.py b/connexion/security.py similarity index 99% rename from connexion/security/security_handler_factory.py rename to connexion/security.py index 4e42b9c..2edd529 100644 --- a/connexion/security/security_handler_factory.py +++ b/connexion/security.py @@ -13,14 +13,14 @@ import typing as t import httpx -from ..decorators.parameter import inspect_function_arguments -from ..exceptions import ( +from connexion.decorators.parameter import inspect_function_arguments +from connexion.exceptions import ( ConnexionException, OAuthProblem, OAuthResponseProblem, OAuthScopeProblem, ) -from ..utils import get_function_from_name +from connexion.utils import get_function_from_name logger = logging.getLogger("connexion.api.security") diff --git a/connexion/security/__init__.py b/connexion/security/__init__.py deleted file mode 100644 index 136011c..0000000 --- a/connexion/security/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -This module defines SecurityHandlerFactories which support the creation of security -handlers for operations. - -isort:skip_file -""" - -from .security_handler_factory import SecurityHandlerFactory # NOQA diff --git a/connexion/uri_parsing.py b/connexion/uri_parsing.py index f78c5a4..b7133f7 100644 --- a/connexion/uri_parsing.py +++ b/connexion/uri_parsing.py @@ -1,9 +1,9 @@ """ -This module defines view function decorators to split query and path parameters. +This module defines URIParsers which parse query and path parameters according to OpenAPI +serialization rules. """ import abc -import functools import json import logging import re @@ -130,33 +130,6 @@ class AbstractURIParser(metaclass=abc.ABCMeta): return resolved_param - def __call__(self, function): - """ - :type function: types.FunctionType - :rtype: types.FunctionType - """ - - @functools.wraps(function) - def wrapper(request): - def coerce_dict(md): - """MultiDict -> dict of lists""" - try: - return md.to_dict(flat=False) - except AttributeError: - return dict(md.items()) - - query = coerce_dict(request.query) - path_params = coerce_dict(request.path_params) - form = coerce_dict(request.form) - - request.query = self.resolve_query(query) - request.path_params = self.resolve_path(path_params) - request.form = self.resolve_form(form) - response = function(request) - return response - - return wrapper - class OpenAPIURIParser(AbstractURIParser): style_defaults = { @@ -281,7 +254,7 @@ class OpenAPIURIParser(AbstractURIParser): class Swagger2URIParser(AbstractURIParser): """ Adheres to the Swagger2 spec, - Assumes the the last defined query parameter should be used. + Assumes that the last defined query parameter should be used. """ parsable_parameters = ["query", "path", "formData"] diff --git a/connexion/utils.py b/connexion/utils.py index e37dba7..c1907f3 100644 --- a/connexion/utils.py +++ b/connexion/utils.py @@ -35,18 +35,24 @@ def boolean(s): # https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#data-types -TYPE_MAP = { +TYPE_MAP: t.Dict[str, t.Any] = { "integer": int, "number": float, "string": str, "boolean": boolean, "array": list, "object": dict, + "file": lambda x: x, # Don't cast files } # map of swagger types to python types -def make_type(value, _type): - type_func = TYPE_MAP[_type] # convert value to right type +def make_type(value: t.Any, type_: str, format_: t.Optional[str]) -> t.Any: + """Cast a value to the type defined in the specification.""" + # In OpenAPI, files are represented with string type and binary format + if type_ == "string" and format_ == "binary": + type_ = "file" + + type_func = TYPE_MAP[type_] return type_func(value) @@ -141,6 +147,8 @@ def is_json_mimetype(mimetype): :type mimetype: str :rtype: bool """ + if mimetype is None: + return False maintype, subtype = mimetype.split("/") # type: str, str if ";" in subtype: @@ -209,9 +217,7 @@ def has_coroutine(function, api=None): return iscorofunc(function) else: - return any( - iscorofunc(func) for func in (function, api.get_request, api.get_response) - ) + return any(iscorofunc(func) for func in (function, api.get_response)) def yamldumper(openapi): diff --git a/connexion/validators/form_data.py b/connexion/validators/form_data.py index 23a3d11..28c628e 100644 --- a/connexion/validators/form_data.py +++ b/connexion/validators/form_data.py @@ -6,13 +6,13 @@ from starlette.datastructures import FormData, Headers, UploadFile from starlette.formparsers import FormParser, MultiPartParser from starlette.types import Receive, Scope -from connexion.decorators.uri_parsing import AbstractURIParser from connexion.exceptions import ( BadRequestProblem, ExtraParameterProblem, TypeValidationError, ) from connexion.json_schema import Draft4RequestValidator +from connexion.uri_parsing import AbstractURIParser from connexion.utils import coerce_type, is_null logger = logging.getLogger("connexion.validators.form_data") diff --git a/connexion/validators/parameter.py b/connexion/validators/parameter.py index 4c02f62..dc8f7a8 100644 --- a/connexion/validators/parameter.py +++ b/connexion/validators/parameter.py @@ -102,6 +102,9 @@ class ParameterValidator: return self.validate_parameter("query", val, param) def validate_path_parameter(self, param, request): + # TODO: activate + # path_params = self.uri_parser.resolve_path(request.path_params) + # val = path_params.get(param["name"].replace("-", "_")) val = request.path_params.get(param["name"].replace("-", "_")) return self.validate_parameter("path", val, param) diff --git a/setup.cfg b/setup.cfg index 3c6e79c..3480374 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [bdist_wheel] -universal=1 +universal=1 \ No newline at end of file diff --git a/tests/api/test_bootstrap.py b/tests/api/test_bootstrap.py index cd95f8b..d71755f 100644 --- a/tests/api/test_bootstrap.py +++ b/tests/api/test_bootstrap.py @@ -65,7 +65,7 @@ def test_app_with_different_server_option(simple_api_spec_dir, spec): def test_app_with_different_uri_parser(simple_api_spec_dir): - from connexion.decorators.uri_parsing import FirstValueURIParser + from connexion.uri_parsing import FirstValueURIParser app = App( __name__, diff --git a/tests/api/test_parameters.py b/tests/api/test_parameters.py index 2219ded..1ae37ed 100644 --- a/tests/api/test_parameters.py +++ b/tests/api/test_parameters.py @@ -172,7 +172,7 @@ def test_path_parameter_someint__bad(simple_app): # non-integer values will not match Flask route app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-int-path/foo") # type: flask.Response - assert resp.status_code == 400, resp.text + assert resp.status_code == 404, resp.text @pytest.mark.parametrize( @@ -205,7 +205,7 @@ def test_path_parameter_somefloat__bad(simple_app): # non-float values will not match Flask route app_client = simple_app.app.test_client() resp = app_client.get("/v1.0/test-float-path/123,45") # type: flask.Response - assert resp.status_code == 400, resp.text + assert resp.status_code == 404, resp.text def test_default_param(strict_app): @@ -280,7 +280,7 @@ def test_formdata_file_upload(simple_app): app_client = simple_app.app.test_client() resp = app_client.post( "/v1.0/test-formData-file-upload", - data={"formData": (BytesIO(b"file contents"), "filename.txt")}, + data={"fileData": (BytesIO(b"file contents"), "filename.txt")}, ) assert resp.status_code == 200 response = json.loads(resp.data.decode("utf-8", "replace")) @@ -293,8 +293,8 @@ def test_formdata_file_upload_bad_request(simple_app): assert resp.status_code == 400 response = json.loads(resp.data.decode("utf-8", "replace")) assert response["detail"] in [ - "Missing formdata parameter 'formData'", - "'formData' is a required property", # OAS3 + "Missing formdata parameter 'fileData'", + "'fileData' is a required property", # OAS3 ] @@ -302,7 +302,7 @@ def test_formdata_file_upload_missing_param(simple_app): app_client = simple_app.app.test_client() resp = app_client.post( "/v1.0/test-formData-file-upload-missing-param", - data={"missing_formData": (BytesIO(b"file contents"), "example.txt")}, + data={"missing_fileData": (BytesIO(b"file contents"), "example.txt")}, ) assert resp.status_code == 200 diff --git a/tests/decorators/test_parameter.py b/tests/decorators/test_parameter.py index dac2638..685c9f8 100644 --- a/tests/decorators/test_parameter.py +++ b/tests/decorators/test_parameter.py @@ -3,11 +3,19 @@ from unittest.mock import MagicMock from connexion.decorators.parameter import parameter_to_arg, pythonic -def test_injection(): - request = MagicMock(name="request", path_params={"p1": "123"}) - request.args = {} +async def test_injection(): + request = MagicMock(name="request") + request.query_params = {} + request.path_params = {"p1": "123"} request.headers = {} - request.params = {} + request.content_type = "application/json" + + async def coro(): + return + + request.json = coro + request.loop = None + request.context = {} func = MagicMock() @@ -16,17 +24,29 @@ def test_injection(): class Op: consumes = ["application/json"] + parameters = [] + method = "GET" - def get_arguments(self, *args, **kwargs): - return {"p1": "123"} + def body_name(self, *args, **kwargs): + return "body" - parameter_to_arg(Op(), handler)(request) + parameter_decorator = parameter_to_arg(Op(), handler) + await parameter_decorator(request) func.assert_called_with(p1="123") -def test_injection_with_context(): +async def test_injection_with_context(): request = MagicMock(name="request") + async def coro(): + return + + request.json = coro + request.loop = None + request.context = {} + request.content_type = "application/json" + request.path_params = {"p1": "123"} + func = MagicMock() def handler(context_, **kwargs): @@ -34,11 +54,14 @@ def test_injection_with_context(): class Op2: consumes = ["application/json"] + parameters = [] + method = "GET" - def get_arguments(self, *args, **kwargs): - return {"p1": "123"} + def body_name(self, *args, **kwargs): + return "body" - parameter_to_arg(Op2(), handler)(request) + parameter_decorator = parameter_to_arg(Op2(), handler) + await parameter_decorator(request) func.assert_called_with(request.context, p1="123") diff --git a/tests/decorators/test_uri_parsing.py b/tests/decorators/test_uri_parsing.py index b5dc2a4..58e6c45 100644 --- a/tests/decorators/test_uri_parsing.py +++ b/tests/decorators/test_uri_parsing.py @@ -1,5 +1,5 @@ import pytest -from connexion.decorators.uri_parsing import ( +from connexion.uri_parsing import ( AlwaysMultiURIParser, FirstValueURIParser, OpenAPIURIParser, @@ -46,7 +46,9 @@ MULTI = "multi" (AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY2, PIPES), ], ) -def test_uri_parser_query_params(parser_class, expected, query_in, collection_format): +async def test_uri_parser_query_params( + parser_class, expected, query_in, collection_format +): class Request: query = query_in path_params = {} @@ -63,9 +65,9 @@ def test_uri_parser_query_params(parser_class, expected, query_in, collection_fo } ] body_defn = {} - p = parser_class(parameters, body_defn) - res = p(lambda x: x)(request) - assert res.query["letters"] == expected + parser = parser_class(parameters, body_defn) + res = parser.resolve_query(request.query.to_dict(flat=False)) + assert res["letters"] == expected @pytest.mark.parametrize( @@ -82,7 +84,9 @@ def test_uri_parser_query_params(parser_class, expected, query_in, collection_fo (AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY2, PIPES), ], ) -def test_uri_parser_form_params(parser_class, expected, query_in, collection_format): +async def test_uri_parser_form_params( + parser_class, expected, query_in, collection_format +): class Request: query = {} form = query_in @@ -99,9 +103,9 @@ def test_uri_parser_form_params(parser_class, expected, query_in, collection_for } ] body_defn = {} - p = parser_class(parameters, body_defn) - res = p(lambda x: x)(request) - assert res.form["letters"] == expected + parser = parser_class(parameters, body_defn) + res = parser.resolve_form(request.form.to_dict(flat=False)) + assert res["letters"] == expected @pytest.mark.parametrize( @@ -115,7 +119,9 @@ def test_uri_parser_form_params(parser_class, expected, query_in, collection_for (AlwaysMultiURIParser, ["d", "e", "f"], PATH2, PIPES), ], ) -def test_uri_parser_path_params(parser_class, expected, query_in, collection_format): +async def test_uri_parser_path_params( + parser_class, expected, query_in, collection_format +): class Request: query = {} form = {} @@ -132,9 +138,9 @@ def test_uri_parser_path_params(parser_class, expected, query_in, collection_for } ] body_defn = {} - p = parser_class(parameters, body_defn) - res = p(lambda x: x)(request) - assert res.path_params["letters"] == expected + parser = parser_class(parameters, body_defn) + res = parser.resolve_path(request.path_params) + assert res["letters"] == expected @pytest.mark.parametrize( @@ -149,7 +155,7 @@ def test_uri_parser_path_params(parser_class, expected, query_in, collection_for (AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY4, PIPES), ], ) -def test_uri_parser_query_params_with_square_brackets( +async def test_uri_parser_query_params_with_square_brackets( parser_class, expected, query_in, collection_format ): class Request: @@ -168,9 +174,9 @@ def test_uri_parser_query_params_with_square_brackets( } ] body_defn = {} - p = parser_class(parameters, body_defn) - res = p(lambda x: x)(request) - assert res.query["letters[eq]"] == expected + parser = parser_class(parameters, body_defn) + res = parser.resolve_query(request.query.to_dict(flat=False)) + assert res["letters[eq]"] == expected @pytest.mark.parametrize( @@ -188,7 +194,7 @@ def test_uri_parser_query_params_with_square_brackets( (AlwaysMultiURIParser, ["a"], QUERY6, PIPES), ], ) -def test_uri_parser_query_params_with_underscores( +async def test_uri_parser_query_params_with_underscores( parser_class, expected, query_in, collection_format ): class Request: @@ -207,9 +213,9 @@ def test_uri_parser_query_params_with_underscores( } ] body_defn = {} - p = parser_class(parameters, body_defn) - res = p(lambda x: x)(request) - assert res.query["letters_eq"] == expected + parser = parser_class(parameters, body_defn) + res = parser.resolve_query(request.query.to_dict(flat=False)) + assert res["letters_eq"] == expected @pytest.mark.parametrize( @@ -231,7 +237,7 @@ def test_uri_parser_query_params_with_underscores( ), ], ) -def test_uri_parser_query_params_with_malformed_names( +async def test_uri_parser_query_params_with_malformed_names( parser_class, query_in, collection_format, explode, expected ): class Request: @@ -253,6 +259,6 @@ def test_uri_parser_query_params_with_malformed_names( } ] body_defn = {} - p = parser_class(parameters, body_defn) - res = p(lambda x: x)(request) - assert res.query == expected + parser = parser_class(parameters, body_defn) + res = parser.resolve_query(request.query.to_dict(flat=False)) + assert res == expected diff --git a/tests/fakeapi/hello/__init__.py b/tests/fakeapi/hello/__init__.py index 5fb61ad..b1c36cf 100644 --- a/tests/fakeapi/hello/__init__.py +++ b/tests/fakeapi/hello/__init__.py @@ -84,7 +84,7 @@ def get_bye_secure(name, user, token_info): def get_bye_secure_from_flask(): - return "Goodbye {user} (Secure!)".format(user=context["user"]) + return "Goodbye {user} (Secure!)".format(user=context.context["user"]) def get_bye_secure_from_connexion(context_): @@ -314,9 +314,11 @@ def test_formdata_missing_param(): return "" -def test_formdata_file_upload(formData, **kwargs): - filename = formData.filename - contents = formData.read().decode("utf-8", "replace") +def test_formdata_file_upload(fileData, **kwargs): + """In Swagger, form paramaeters and files are passed separately""" + filename = fileData.filename + contents = fileData.read() + contents = contents.decode("utf-8", "replace") return {filename: contents} diff --git a/tests/fixtures/simple/openapi.yaml b/tests/fixtures/simple/openapi.yaml index fdcfea8..7beb631 100644 --- a/tests/fixtures/simple/openapi.yaml +++ b/tests/fixtures/simple/openapi.yaml @@ -3,25 +3,6 @@ info: title: '{{title}}' version: '1.0' paths: - '/greeting/{name}': - post: - summary: Generate greeting - description: Generates a greeting message. - operationId: fakeapi.hello.post_greeting - responses: - '200': - description: greeting response - content: - 'application/json': - schema: - type: object - parameters: - - name: name - in: path - description: Name of the person to greet. - required: true - schema: - type: string '/greeting/{name}/{remainder}': post: summary: Generate greeting and collect the remainder of the url @@ -48,6 +29,25 @@ paths: schema: type: string format: path + '/greeting/{name}': + post: + summary: Generate greeting + description: Generates a greeting message. + operationId: fakeapi.hello.post_greeting + responses: + '200': + description: greeting response + content: + 'application/json': + schema: + type: object + parameters: + - name: name + in: path + description: Name of the person to greet. + required: true + schema: + type: string '/greetings/{name}': get: summary: Generate greeting @@ -600,11 +600,11 @@ paths: schema: type: object properties: - formData: + fileData: type: string format: binary required: - - formData + - fileData /test-formData-file-upload-missing-param: post: summary: 'Test formData with file type, missing parameter in handler' @@ -618,11 +618,11 @@ paths: schema: type: object properties: - missing_formData: + missing_fileData: type: string format: binary required: - - missing_formData + - missing_fileData /test-bool-param: get: summary: Test usage of boolean default value @@ -690,6 +690,24 @@ paths: responses: '200': description: OK + /goodday/noheader: + post: + summary: Generate good day greeting + description: Generates a good day message. + operationId: fakeapi.hello.post_goodday_no_header + responses: + '201': + description: goodday response + headers: + Location: + description: The URI of the created resource + schema: + type: string + required: true + content: + 'application/json': + schema: + type: object '/goodday/{name}': post: summary: Generate good day greeting @@ -715,24 +733,6 @@ paths: required: true schema: type: string - /goodday/noheader: - post: - summary: Generate good day greeting - description: Generates a good day message. - operationId: fakeapi.hello.post_goodday_no_header - responses: - '201': - description: goodday response - headers: - Location: - description: The URI of the created resource - schema: - type: string - required: true - content: - 'application/json': - schema: - type: object '/goodevening/{name}': post: summary: Generate good evening diff --git a/tests/fixtures/simple/swagger.yaml b/tests/fixtures/simple/swagger.yaml index f6543da..d271d72 100644 --- a/tests/fixtures/simple/swagger.yaml +++ b/tests/fixtures/simple/swagger.yaml @@ -7,22 +7,6 @@ info: basePath: /v1.0 paths: - /greeting/{name}: - post: - summary: Generate greeting - description: Generates a greeting message. - operationId: fakeapi.hello.post_greeting - responses: - '200': - description: greeting response - schema: - type: object - parameters: - - name: name - in: path - description: Name of the person to greet. - required: true - type: string /greeting/{name}/{remainder}: post: summary: Generate greeting and collect the remainder of the url @@ -45,6 +29,22 @@ paths: required: true type: string format: path + /greeting/{name}: + post: + summary: Generate greeting + description: Generates a greeting message. + operationId: fakeapi.hello.post_greeting + responses: + '200': + description: greeting response + schema: + type: object + parameters: + - name: name + in: path + description: Name of the person to greet. + required: true + type: string /greetings/{name}: get: summary: Generate greeting @@ -457,7 +457,7 @@ paths: consumes: - multipart/form-data parameters: - - name: formData + - name: fileData type: file in: formData required: true @@ -472,7 +472,7 @@ paths: consumes: - multipart/form-data parameters: - - name: missing_formData + - name: missing_fileData type: file in: formData required: true diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py new file mode 100644 index 0000000..c196f72 --- /dev/null +++ b/tests/test_datastructures.py @@ -0,0 +1,22 @@ +from connexion.datastructures import MediaTypeDict + + +def test_media_type_dict(): + d = MediaTypeDict( + { + "*/*": "*/*", + "*/json": "*/json", + "*/*json": "*/*json", + "multipart/*": "multipart/*", + "multipart/form-data": "multipart/form-data", + } + ) + + assert d["application/json"] == "*/json" + assert d["application/problem+json"] == "*/*json" + assert d["application/x-www-form-urlencoded"] == "*/*" + assert d["multipart/form-data"] == "multipart/form-data" + assert d["multipart/byteranges"] == "multipart/*" + + # Test __contains__ + assert "application/json" in d diff --git a/tests/test_operation2.py b/tests/test_operation2.py index 2ac2083..e22d5b8 100644 --- a/tests/test_operation2.py +++ b/tests/test_operation2.py @@ -735,7 +735,6 @@ def test_form_transformation(api): "properties": { "param": { "type": "string", - "default": "foo@bar.com", "format": "email", }, "array_param": { @@ -746,6 +745,7 @@ def test_form_transformation(api): "nullable": True, }, }, + "default": {"param": "foo@bar.com"}, "required": ["param"], }, "encoding": { diff --git a/tests/test_validation.py b/tests/test_validation.py index eb0a606..8a8bd45 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock from urllib.parse import quote_plus import pytest -from connexion.decorators.uri_parsing import Swagger2URIParser from connexion.exceptions import BadRequestProblem +from connexion.uri_parsing import Swagger2URIParser from connexion.validators.parameter import ParameterValidator from starlette.datastructures import QueryParams