Move parameter decorator related methods out of operation classes

This commit is contained in:
Robbe Sneyders
2022-12-16 09:49:21 +01:00
parent 2438114b71
commit 7acbad0691
32 changed files with 775 additions and 722 deletions

View File

@@ -20,7 +20,7 @@ from .utils import not_installed_error # NOQA
try:
from flask import request # NOQA
from .apis.flask_api import FlaskApi, context # NOQA
from .apis.flask_api import FlaskApi # NOQA
from .apps.flask_app import FlaskApp
except ImportError as e: # pragma: no cover
_flask_not_installed_error = not_installed_error(e)

View File

@@ -1,7 +1,6 @@
"""
This module defines an AbstractAPI, which defines a standardized interface for a Connexion API.
"""
import abc
import logging
import pathlib
@@ -9,15 +8,18 @@ import sys
import typing as t
from enum import Enum
from ..datastructures import NoContent
from ..exceptions import ResolverError
from ..http_facts import METHODS
from ..jsonifier import Jsonifier
from ..lifecycle import ConnexionResponse
from ..operations import make_operation
from ..options import ConnexionOptions
from ..resolver import Resolver
from ..spec import Specification
from starlette.requests import Request
from connexion.datastructures import NoContent
from connexion.decorators.parameter import parameter_to_arg
from connexion.exceptions import ResolverError
from connexion.http_facts import METHODS
from connexion.jsonifier import Jsonifier
from connexion.lifecycle import ConnexionResponse
from connexion.operations import make_operation
from connexion.options import ConnexionOptions
from connexion.resolver import Resolver
from connexion.spec import Specification
MODULE_PATH = pathlib.Path(__file__).absolute().parent.parent
SWAGGER_UI_URL = "ui"
@@ -231,30 +233,19 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
self.resolver,
pythonic_params=self.pythonic_params,
uri_parser_class=self.options.uri_parser_class,
parameter_to_arg=parameter_to_arg,
)
self._add_operation_internal(method, path, operation)
@classmethod
@abc.abstractmethod
def get_request(self, *args, **kwargs):
def get_request(cls, uri_parser) -> Request:
"""
This method converts the user framework request to a ConnexionRequest.
"""
@classmethod
@abc.abstractmethod
def get_response(self, response, mimetype=None, request=None):
"""
This method converts a handler response to a framework response.
This method should just retrieve response from handler then call `cls._get_response`.
:param response: A response to cast (tuple, framework response, etc).
:param mimetype: The response mimetype.
:type mimetype: Union[None, str]
:param request: The request associated with this response (the user framework request).
"""
@classmethod
def _get_response(cls, response, mimetype=None, extra_context=None):
def _get_response(cls, response, mimetype=None):
"""
This method converts a handler response to a framework response.
The response can be a ConnexionResponse, an operation handler, a framework response or a tuple.
@@ -262,31 +253,24 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
:param response: A response to cast (tuple, framework response, etc).
:param mimetype: The response mimetype.
:type mimetype: Union[None, str]
:param extra_context: dict of extra details, like url, to include in logs
:type extra_context: Union[None, dict]
"""
if extra_context is None:
extra_context = {}
logger.debug(
"Getting data and status code",
extra={"data": response, "data_type": type(response), **extra_context},
extra={"data": response, "data_type": type(response)},
)
if isinstance(response, ConnexionResponse):
framework_response = cls._connexion_to_framework_response(
response, mimetype, extra_context
response, mimetype
)
else:
framework_response = cls._response_from_handler(
response, mimetype, extra_context
)
framework_response = cls._response_from_handler(response, mimetype)
logger.debug(
"Got framework response",
extra={
"response": framework_response,
"response_type": type(framework_response),
**extra_context,
},
)
return framework_response
@@ -298,7 +282,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
t.Any, str, t.Tuple[str], t.Tuple[str, int], t.Tuple[str, int, dict]
],
mimetype: str,
extra_context: t.Optional[dict] = None,
) -> t.Any:
"""
Create a framework response from the operation handler data.
@@ -311,7 +294,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
:param response: A response from an operation handler.
:param mimetype: The response mimetype.
:param extra_context: dict of extra details, like url, to include in logs
"""
if cls._is_framework_response(response):
return response
@@ -320,9 +302,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
len_response = len(response)
if len_response == 1:
(data,) = response
return cls._build_response(
mimetype=mimetype, data=data, extra_context=extra_context
)
return cls._build_response(mimetype=mimetype, data=data)
if len_response == 2:
if isinstance(response[1], (int, Enum)):
data, status_code = response
@@ -330,7 +310,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
mimetype=mimetype,
data=data,
status_code=status_code,
extra_context=extra_context,
)
else:
data, headers = response
@@ -338,7 +317,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
mimetype=mimetype,
data=data,
headers=headers,
extra_context=extra_context,
)
elif len_response == 3:
data, status_code, headers = response
@@ -347,7 +325,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
data=data,
status_code=status_code,
headers=headers,
extra_context=extra_context,
)
else:
raise TypeError(
@@ -356,9 +333,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
" (body, status), or (body, headers)."
)
else:
return cls._build_response(
mimetype=mimetype, data=response, extra_context=extra_context
)
return cls._build_response(mimetype=mimetype, data=response)
@classmethod
def get_connexion_response(cls, response, mimetype=None):
@@ -384,7 +359,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
@classmethod
@abc.abstractmethod
def _connexion_to_framework_response(cls, response, mimetype, extra_context=None):
def _connexion_to_framework_response(cls, response, mimetype):
"""Cast ConnexionResponse to framework response class"""
@classmethod
@@ -396,7 +371,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
content_type=None,
status_code=None,
headers=None,
extra_context=None,
):
"""
Create a framework response from the provided arguments.
@@ -407,16 +381,12 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
:type status_code: int
:param headers: The response status code.
:type headers: Union[Iterable[Tuple[str, str]], Dict[str, str]]
:param extra_context: dict of extra details, like url, to include in logs
:type extra_context: Union[None, dict]
:return A framework response.
:rtype Response
"""
@classmethod
def _prepare_body_and_status_code(
cls, data, mimetype, status_code=None, extra_context=None
):
def _prepare_body_and_status_code(cls, data, mimetype, status_code=None):
if data is NoContent:
data = None
@@ -435,12 +405,10 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
else:
body = data
if extra_context is None:
extra_context = {}
logger.debug(
"Prepared body and status code (%d)",
status_code,
extra={"body": body, **extra_context},
extra={"body": body},
)
return body, status_code, mimetype

View File

@@ -2,13 +2,9 @@
This module defines a Flask Connexion API which implements translations between Flask and
Connexion requests / responses.
"""
import logging
import warnings
from typing import Any
import flask
from werkzeug.local import LocalProxy
from connexion.apis import flask_utils
from connexion.apis.abstract import AbstractAPI
@@ -52,7 +48,7 @@ class FlaskApi(AbstractAPI):
)
@classmethod
def get_response(cls, response, mimetype=None, request=None):
def get_response(cls, response, mimetype=None):
"""Gets ConnexionResponse instance for the operation handler
result. Status Code and Headers for response. If only body
data is returned by the endpoint function, then the status
@@ -64,9 +60,7 @@ class FlaskApi(AbstractAPI):
:type response: flask.Response | (flask.Response,) | (flask.Response, int) | (flask.Response, dict) | (flask.Response, int, dict)
:rtype: ConnexionResponse
"""
return cls._get_response(
response, mimetype=mimetype, extra_context={"url": flask.request.url}
)
return cls._get_response(response, mimetype=mimetype)
@classmethod
def _is_framework_response(cls, response):
@@ -86,7 +80,7 @@ class FlaskApi(AbstractAPI):
)
@classmethod
def _connexion_to_framework_response(cls, response, mimetype, extra_context=None):
def _connexion_to_framework_response(cls, response, mimetype):
"""Cast ConnexionResponse to framework response class"""
flask_response = cls._build_response(
mimetype=response.mimetype or mimetype,
@@ -94,7 +88,6 @@ class FlaskApi(AbstractAPI):
headers=response.headers,
status_code=response.status_code,
data=response.body,
extra_context=extra_context,
)
return flask_response
@@ -107,7 +100,6 @@ class FlaskApi(AbstractAPI):
headers=None,
status_code=None,
data=None,
extra_context=None,
):
if cls._is_framework_response(data):
return flask.current_app.make_response((data, status_code, headers))
@@ -116,7 +108,6 @@ class FlaskApi(AbstractAPI):
data=data,
mimetype=mimetype,
status_code=status_code,
extra_context=extra_context,
)
kwargs = {
@@ -133,61 +124,14 @@ class FlaskApi(AbstractAPI):
def _serialize_data(cls, data, mimetype):
if isinstance(mimetype, str) and is_json_mimetype(mimetype):
body = cls.jsonifier.dumps(data)
elif not (isinstance(data, bytes) or isinstance(data, str)):
warnings.warn(
"Implicit (flask) JSON serialization will change in the next major version. "
"This is triggered because a response body is being serialized as JSON "
"even though the mimetype is not a JSON type. "
"This will be replaced by something that is mimetype-specific and may "
"raise an error instead of silently converting everything to JSON. "
"Please make sure to specify media/mime types in your specs.",
FutureWarning, # a Deprecation targeted at application users.
)
body = cls.jsonifier.dumps(data)
else:
body = data
return body, mimetype
@classmethod
def get_request(cls, *args, **params):
# type: (*Any, **Any) -> ConnexionRequest
"""Gets ConnexionRequest instance for the operation handler
result. Status Code and Headers for response. If only body
data is returned by the endpoint function, then the status
code will be set to 200 and no headers will be added.
If the returned object is a flask.Response then it will just
pass the information needed to recreate it.
:rtype: ConnexionRequest
"""
flask_request = flask.request
scope = flask_request.environ["asgi.scope"]
context_dict = scope.get("extensions", {}).get("connexion_context", {})
setattr(flask.globals.request_ctx, "connexion_context", context_dict)
request = ConnexionRequest(
flask_request.url,
flask_request.method,
headers=flask_request.headers,
form=flask_request.form,
query=flask_request.args,
body=flask_request.get_data(),
json_getter=lambda: flask_request.get_json(silent=True),
files=flask_request.files,
path_params=params,
context=context_dict,
cookies=flask_request.cookies,
)
logger.debug(
"Getting data and status code",
extra={
"data": request.body,
"data_type": type(request.body),
"url": request.url,
},
)
return request
def get_request(cls, uri_parser) -> ConnexionRequest:
return ConnexionRequest(flask.request, uri_parser=uri_parser)
@classmethod
def _set_jsonifier(cls):
@@ -195,10 +139,3 @@ class FlaskApi(AbstractAPI):
Use Flask specific JSON loader
"""
cls.jsonifier = Jsonifier(flask.json, indent=2)
def _get_context():
return getattr(flask.globals.request_ctx, "connexion_context")
context = LocalProxy(_get_context)

View File

@@ -7,6 +7,7 @@ import pathlib
from types import FunctionType # NOQA
import a2wsgi
import asgiref.wsgi
import flask
import werkzeug.exceptions
from flask import signals

9
connexion/context.py Normal file
View File

@@ -0,0 +1,9 @@
from asyncio import AbstractEventLoop
from contextvars import ContextVar
_context: ContextVar[AbstractEventLoop] = ContextVar("CONTEXT")
def __getattr__(name):
if name == "context":
return _context.get()

View File

@@ -22,7 +22,7 @@ class RequestResponseDecorator:
self.api = api
self.mimetype = mimetype
def __call__(self, function):
def __call__(self, function, uri_parser):
"""
:type function: types.FunctionType
:rtype: types.FunctionType
@@ -31,7 +31,7 @@ class RequestResponseDecorator:
@functools.wraps(function)
async def wrapper(*args, **kwargs):
connexion_request = self.api.get_request(*args, **kwargs)
connexion_request = self.api.get_request(uri_parser=uri_parser)
while asyncio.iscoroutine(connexion_request):
connexion_request = await connexion_request
@@ -40,7 +40,7 @@ class RequestResponseDecorator:
connexion_response = await connexion_response
framework_response = self.api.get_response(
connexion_response, self.mimetype, connexion_request
connexion_response, self.mimetype
)
while asyncio.iscoroutine(framework_response):
framework_response = await framework_response
@@ -51,8 +51,8 @@ class RequestResponseDecorator:
@functools.wraps(function)
def wrapper(*args, **kwargs):
request = self.api.get_request(*args, **kwargs)
request = self.api.get_request(uri_parser)
response = function(request)
return self.api.get_response(response, self.mimetype, request)
return self.api.get_response(response, self.mimetype)
return wrapper

View File

@@ -1,112 +1,69 @@
"""
This module defines a decorator to convert request parameters to arguments for the view function.
This module defines a utility functions to convert request parameters to arguments for the view
function.
"""
import builtins
import functools
import inspect
import keyword
import logging
import re
from typing import Any
import typing as t
from copy import copy, deepcopy
import inflection
from ..http_facts import FORM_CONTENT_TYPES
from ..lifecycle import ConnexionRequest # NOQA
from ..utils import all_json
from connexion.http_facts import FORM_CONTENT_TYPES
from connexion.lifecycle import ConnexionRequest
from connexion.operations import AbstractOperation, Swagger2Operation
from connexion.utils import (
deep_merge,
is_json_mimetype,
is_null,
is_nullable,
make_type,
)
logger = logging.getLogger(__name__)
CONTEXT_NAME = "context_"
def inspect_function_arguments(function): # pragma: no cover
"""
Returns the list of variables names of a function and if it
accepts keyword arguments.
:type function: Callable
:rtype: tuple[list[str], bool]
"""
parameters = inspect.signature(function).parameters
bound_arguments = [
name
for name, p in parameters.items()
if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
]
has_kwargs = any(p.kind == p.VAR_KEYWORD for p in parameters.values())
return list(bound_arguments), has_kwargs
def snake_and_shadow(name):
"""
Converts the given name into Pythonic form. Firstly it converts CamelCase names to snake_case. Secondly it looks to
see if the name matches a known built-in and if it does it appends an underscore to the name.
:param name: The parameter name
:type name: str
:return:
"""
snake = inflection.underscore(name)
if snake in builtins.__dict__ or keyword.iskeyword(snake):
return f"{snake}_"
return snake
def sanitized(name):
return name and re.sub(
"^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", re.sub(r"\[(?!])", "_", name))
)
def pythonic(name):
name = name and snake_and_shadow(name)
return sanitized(name)
def parameter_to_arg(operation, function, pythonic_params=False):
"""
Pass query and body parameters as keyword arguments to handler function.
See (https://github.com/zalando/connexion/issues/59)
:param operation: The operation being called
:type operation: connexion.operations.AbstractOperation
:param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to
any shadowed built-ins
:type pythonic_params: bool
"""
consumes = operation.consumes
def parameter_to_arg(
operation: AbstractOperation,
function: t.Callable,
pythonic_params: bool = False,
) -> t.Callable[[ConnexionRequest], t.Any]:
sanitize = pythonic if pythonic_params else sanitized
arguments, has_kwargs = inspect_function_arguments(function)
@functools.wraps(function)
def wrapper(request):
# type: (ConnexionRequest) -> Any
logger.debug("Function Arguments: %s", arguments)
def wrapper(request: ConnexionRequest) -> t.Any:
kwargs = {}
if all_json(consumes):
request_body = request.json
elif consumes[0] in FORM_CONTENT_TYPES:
request_body = request.form
body_name = sanitize(operation.body_name(request.content_type))
if body_name in arguments or has_kwargs:
request_body = get_body(request)
# Pass form contents separately for Swagger2 for backward compatibility with Connexion 2
# Checking for body_name is not enough
elif request.mimetype in FORM_CONTENT_TYPES and isinstance(
operation, Swagger2Operation
):
request_body = get_body(request)
else:
request_body = request.body
try:
query = request.query.to_dict(flat=False)
except AttributeError:
query = dict(request.query.items())
request_body = None
kwargs.update(
operation.get_arguments(
request.path_params,
query,
request_body,
request.files,
arguments,
has_kwargs,
sanitize,
get_arguments(
operation,
path_params=request.view_args,
query_params=request.args,
body=request_body,
files=request.files,
arguments=arguments,
has_kwargs=has_kwargs,
sanitize=sanitize,
content_type=request.content_type,
)
)
@@ -127,3 +84,348 @@ def parameter_to_arg(operation, function, pythonic_params=False):
return function(**kwargs)
return wrapper
def get_body(request: ConnexionRequest) -> t.Any:
"""Get body from the request based on the content type."""
if is_json_mimetype(request.content_type):
return request.get_json(silent=True)
elif request.mimetype in FORM_CONTENT_TYPES:
return request.form
else:
# Return explicit None instead of empty bytestring so it is handled as null downstream
return request.get_data() or None
def inspect_function_arguments(function: t.Callable) -> t.Tuple[t.List[str], bool]:
"""
Returns the list of variables names of a function and if it
accepts keyword arguments.
"""
parameters = inspect.signature(function).parameters
bound_arguments = [
name
for name, p in parameters.items()
if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
]
has_kwargs = any(p.kind == p.VAR_KEYWORD for p in parameters.values())
return list(bound_arguments), has_kwargs
def snake_and_shadow(name: str) -> str:
"""
Converts the given name into Pythonic form. Firstly it converts CamelCase names to snake_case. Secondly it looks to
see if the name matches a known built-in and if it does it appends an underscore to the name.
:param name: The parameter name
"""
snake = inflection.underscore(name)
if snake in builtins.__dict__ or keyword.iskeyword(snake):
return f"{snake}_"
return snake
def sanitized(name: str) -> str:
return name and re.sub(
"^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", re.sub(r"\[(?!])", "_", name))
)
def pythonic(name: str) -> str:
name = name and snake_and_shadow(name)
return sanitized(name)
def get_arguments(
operation: AbstractOperation,
*,
path_params: dict,
query_params: dict,
body: t.Any,
files: dict,
arguments: t.List[str],
has_kwargs: bool,
sanitize: t.Callable,
content_type: str,
) -> t.Dict[str, t.Any]:
"""
get arguments for handler function
"""
ret = {}
ret.update(_get_path_arguments(path_params, operation=operation, sanitize=sanitize))
ret.update(
_get_query_arguments(
query_params,
operation=operation,
arguments=arguments,
has_kwargs=has_kwargs,
sanitize=sanitize,
)
)
if operation.method.upper() in ["PATCH", "POST", "PUT"]:
ret.update(
_get_body_argument(
body,
operation=operation,
arguments=arguments,
has_kwargs=has_kwargs,
sanitize=sanitize,
content_type=content_type,
)
)
ret.update(_get_file_arguments(files, arguments, has_kwargs))
return ret
def _get_path_arguments(
path_params: dict, *, operation: AbstractOperation, sanitize: t.Callable
) -> dict:
"""
Extract handler function arguments from path parameters
"""
kwargs = {}
path_definitions = {
parameter["name"]: parameter
for parameter in operation.parameters
if parameter["in"] == "path"
}
for name, value in path_params.items():
sanitized_key = sanitize(name)
if name in path_definitions:
kwargs[sanitized_key] = _get_val_from_param(value, path_definitions[name])
else: # Assume path params mechanism used for injection
kwargs[sanitized_key] = value
return kwargs
def _get_val_from_param(value: t.Any, param_definitions: t.Dict[str, dict]) -> t.Any:
"""Cast a value according to its definition in the specification."""
param_schema = param_definitions.get("schema", param_definitions)
if is_nullable(param_schema) and is_null(value):
return None
if param_schema["type"] == "array":
type_ = param_schema["items"]["type"]
format_ = param_schema["items"].get("format")
return [make_type(part, type_, format_) for part in value]
else:
type_ = param_schema["type"]
format_ = param_schema.get("format")
return make_type(value, type_, format_)
def _get_query_arguments(
query_params: dict,
*,
operation: AbstractOperation,
arguments: t.List[str],
has_kwargs: bool,
sanitize: t.Callable,
) -> dict:
"""
extract handler function arguments from the query parameters
"""
query_definitions = {
parameter["name"]: parameter
for parameter in operation.parameters
if parameter["in"] == "query"
}
default_query_params = _get_query_defaults(query_definitions)
query_arguments = deepcopy(default_query_params)
query_arguments = deep_merge(query_arguments, query_params)
return _query_args_helper(
query_definitions, query_arguments, arguments, has_kwargs, sanitize
)
def _get_query_defaults(query_definitions: t.Dict[str, dict]) -> t.Dict[str, t.Any]:
"""Get the default values for the query parameter from the parameter definition."""
defaults = {}
for k, v in query_definitions.items():
try:
if "default" in v:
defaults[k] = v["default"]
elif v["schema"]["type"] == "object":
defaults[k] = _get_default_obj(v["schema"])
else:
defaults[k] = v["schema"]["default"]
except KeyError:
pass
return defaults
def _get_default_obj(schema: dict) -> dict:
try:
return deepcopy(schema["default"])
except KeyError:
properties = schema.get("properties", {})
return _build_default_obj_recursive(properties, {})
def _build_default_obj_recursive(properties: dict, default_object: dict) -> dict:
"""takes disparate and nested default keys, and builds up a default object"""
for name, property_ in properties.items():
if "default" in property_ and name not in default_object:
default_object[name] = copy(property_["default"])
elif property_.get("type") == "object" and "properties" in property_:
default_object.setdefault(name, {})
default_object[name] = _build_default_obj_recursive(
property_["properties"], default_object[name]
)
return default_object
def _query_args_helper(
query_definitions: dict,
query_arguments: dict,
function_arguments: t.List[str],
has_kwargs: bool,
sanitize: t.Callable,
) -> dict:
result = {}
for key, value in query_arguments.items():
sanitized_key = sanitize(key)
if not has_kwargs and sanitized_key not in function_arguments:
logger.debug(
"Query Parameter '%s' (sanitized: '%s') not in function arguments",
key,
sanitized_key,
)
else:
logger.debug(
"Query Parameter '%s' (sanitized: '%s') in function arguments",
key,
sanitized_key,
)
try:
query_defn = query_definitions[key]
except KeyError: # pragma: no cover
logger.error(
"Function argument '%s' (non-sanitized: %s) not defined in specification",
sanitized_key,
key,
)
else:
logger.debug("%s is a %s", key, query_defn)
result.update({sanitized_key: _get_val_from_param(value, query_defn)})
return result
def _get_body_argument(
body: t.Any,
*,
operation: AbstractOperation,
arguments: t.List[str],
has_kwargs: bool,
sanitize: t.Callable,
content_type: str,
) -> dict:
if len(arguments) <= 0 and not has_kwargs:
return {}
body_name = sanitize(operation.body_name(content_type))
if content_type in FORM_CONTENT_TYPES:
result = _get_body_argument_form(
body, operation=operation, content_type=content_type
)
# Unpack form values for Swagger for compatibility with Connexion 2 behavior
if content_type in FORM_CONTENT_TYPES and isinstance(
operation, Swagger2Operation
):
if has_kwargs:
return result
else:
return {
sanitize(name): value
for name, value in result.items()
if sanitize(name) in arguments
}
else:
result = _get_body_argument_json(
body, operation=operation, content_type=content_type
)
if body_name in arguments or has_kwargs:
return {body_name: result}
return {}
def _get_body_argument_json(
body: t.Any, *, operation: AbstractOperation, content_type: str
) -> t.Any:
# if the body came in null, and the schema says it can be null, we decide
# to include no value for the body argument, rather than the default body
if is_nullable(operation.body_schema(content_type)) and is_null(body):
return None
if body is None:
default_body = operation.body_schema(content_type).get("default", {})
return deepcopy(default_body)
return body
def _get_body_argument_form(
body: dict, *, operation: AbstractOperation, content_type: str
) -> dict:
# now determine the actual value for the body (whether it came in or is default)
default_body = operation.body_schema(content_type).get("default", {})
body_props = {
k: {"schema": v}
for k, v in operation.body_schema(content_type).get("properties", {}).items()
}
# by OpenAPI specification `additionalProperties` defaults to `true`
# see: https://github.com/OAI/OpenAPI-Specification/blame/3.0.2/versions/3.0.2.md#L2305
additional_props = operation.body_schema().get("additionalProperties", True)
body_arg = deepcopy(default_body)
body_arg.update(body or {})
if body_props or additional_props:
return _get_typed_body_values(body_arg, body_props, additional_props)
return {}
def _get_typed_body_values(body_arg, body_props, additional_props):
"""
Return a copy of the provided body_arg dictionary
whose values will have the appropriate types
as defined in the provided schemas.
:type body_arg: type dict
:type body_props: dict
:type additional_props: dict|bool
:rtype: dict
"""
additional_props_defn = (
{"schema": additional_props} if isinstance(additional_props, dict) else None
)
res = {}
for key, value in body_arg.items():
try:
prop_defn = body_props[key]
res[key] = _get_val_from_param(value, prop_defn)
except KeyError: # pragma: no cover
if not additional_props:
logger.error(f"Body property '{key}' not defined in body schema")
continue
if additional_props_defn is not None:
value = _get_val_from_param(value, additional_props_defn)
res[key] = value
return res
def _get_file_arguments(files, arguments, has_kwargs=False):
return {k: v for k, v in files.items() if k in arguments or has_kwargs}

View File

@@ -2,46 +2,13 @@
This module defines interfaces for requests and responses used in Connexion for authentication,
validation, serialization, etc.
"""
import typing as t
from flask import Request as FlaskRequest
from starlette.requests import Request as StarletteRequest
from starlette.responses import StreamingResponse as StarletteStreamingResponse
class ConnexionRequest:
"""Connexion interface for a request."""
def __init__(
self,
url,
method,
path_params=None,
query=None,
headers=None,
form=None,
body=None,
json_getter=None,
files=None,
context=None,
cookies=None,
):
self.url = url
self.method = method
self.path_params = path_params or {}
self.query = query or {}
self.headers = headers or {}
self.form = form or {}
self.body = body
self.json_getter = json_getter
self.files = files
self.context = context if context is not None else {}
self.cookies = cookies or {}
@property
def json(self):
if not hasattr(self, "_json"):
self._json = self.json_getter()
return self._json
class ConnexionResponse:
"""Connexion interface for a response."""
@@ -62,6 +29,41 @@ class ConnexionResponse:
self.is_streamed = is_streamed
class ConnexionRequest:
def __init__(self, flask_request: FlaskRequest, uri_parser=None):
self._flask_request = flask_request
self.uri_parser = uri_parser
self._context = None
@property
def context(self):
if self._context is None:
scope = self._flask_request.environ["asgi.scope"]
extensions = scope.setdefault("extensions", {})
self._context = extensions.setdefault("connexion_context", {})
return self._context
@property
def view_args(self) -> t.Dict[str, t.Any]:
return self.uri_parser.resolve_path(self._flask_request.view_args)
@property
def args(self):
query_params = self._flask_request.args
query_params = {k: query_params.getlist(k) for k in query_params}
return self.uri_parser.resolve_query(query_params)
@property
def form(self):
form = self._flask_request.form.to_dict(flat=False)
form_data = self.uri_parser.resolve_form(form)
return form_data
def __getattr__(self, item):
return getattr(self._flask_request, item)
class MiddlewareRequest(StarletteRequest):
"""Wraps starlette Request so it can easily be extended."""

View File

@@ -0,0 +1,15 @@
"""The ContextMiddleware creates a global context based the scope. It should be last in the
middleware stack, so it exposes the scope passed to the application"""
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.context import _context
class ContextMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
context = scope.get("extensions", {}).get("connexion_context", {})
_context.set(context)
await self.app(scope, receive, send)

View File

@@ -4,6 +4,7 @@ import typing as t
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.middleware.abstract import AppMiddleware
from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
@@ -21,6 +22,7 @@ class ConnexionMiddleware:
SecurityMiddleware,
RequestValidationMiddleware,
ResponseValidationMiddleware,
ContextMiddleware,
]
def __init__(

View File

@@ -1,7 +1,10 @@
import functools
import pathlib
import re
import typing as t
from contextvars import ContextVar
import starlette.convertors
from starlette.routing import Router
from starlette.types import ASGIApp, Receive, Scope, Send
@@ -53,7 +56,7 @@ class RoutingAPI(AbstractRoutingAPI):
resolver: t.Optional[Resolver] = None,
resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False,
**kwargs
**kwargs,
) -> None:
"""API implementation on top of Starlette Router for Connexion middleware."""
self.next_app = next_app
@@ -76,6 +79,8 @@ class RoutingAPI(AbstractRoutingAPI):
routing_operation = RoutingOperation.from_operation(
operation, next_app=self.next_app
)
types = operation.get_path_parameter_types()
path = starlettify_path(path, types)
self._add_operation_internal(method, path, routing_operation)
def _add_operation_internal(
@@ -94,13 +99,15 @@ class RoutingMiddleware(AppMiddleware):
self.app = app
# Pass unknown routes to next app
self.router = Router(default=RoutingOperation(None, self.app))
starlette.convertors.register_url_convertor("float", FloatConverter())
starlette.convertors.register_url_convertor("int", IntegerConverter())
def add_api(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
arguments: t.Optional[dict] = None,
**kwargs
**kwargs,
) -> None:
"""Add an API to the router based on a OpenAPI spec.
@@ -113,7 +120,7 @@ class RoutingMiddleware(AppMiddleware):
base_path=base_path,
arguments=arguments,
next_app=self.app,
**kwargs
**kwargs,
)
self.router.mount(api.base_path, app=api.router)
@@ -129,3 +136,46 @@ class RoutingMiddleware(AppMiddleware):
# Needs to be set so starlette router throws exceptions instead of returning error responses
scope["app"] = self
await self.router(scope, receive, send)
PATH_PARAMETER = re.compile(r"\{([^}]*)\}")
PATH_PARAMETER_CONVERTERS = {"integer": "int", "number": "float", "path": "path"}
def convert_path_parameter(match, types):
name = match.group(1)
swagger_type = types.get(name)
converter = PATH_PARAMETER_CONVERTERS.get(swagger_type)
return f'{{{name.replace("-", "_")}{":" if converter else ""}{converter or ""}}}'
def starlettify_path(swagger_path, types=None):
"""
Convert swagger path templates to flask path templates
:type swagger_path: str
:type types: dict
:rtype: str
>>> starlettify_path('/foo-bar/{my-param}')
'/foo-bar/{my_param}'
>>> starlettify_path('/foo/{someint}', {'someint': 'int'})
'/foo/{someint:int}'
"""
if types is None:
types = {}
convert_match = functools.partial(convert_path_parameter, types=types)
return PATH_PARAMETER.sub(convert_match, swagger_path)
class FloatConverter(starlette.convertors.FloatConvertor):
"""Starlette converter for OpenAPI number type"""
regex = r"[+-]?[0-9]*(\.[0-9]*)?"
class IntegerConverter(starlette.convertors.IntegerConvertor):
"""Starlette converter for OpenAPI integer type"""
regex = r"[+-]?[0-9]+"

View File

@@ -6,9 +6,8 @@ and functionality shared between Swagger 2 and OpenAPI 3 specifications.
import abc
import logging
from ..decorators.lifecycle import RequestResponseDecorator
from ..decorators.parameter import parameter_to_arg
from ..utils import all_json
from connexion.decorators.lifecycle import RequestResponseDecorator
from connexion.utils import all_json
logger = logging.getLogger("connexion.operations.abstract")
@@ -46,6 +45,7 @@ class AbstractOperation(metaclass=abc.ABCMeta):
randomize_endpoint=None,
pythonic_params=False,
uri_parser_class=None,
parameter_to_arg=None,
):
"""
:param api: api that this operation is attached to
@@ -87,6 +87,8 @@ class AbstractOperation(metaclass=abc.ABCMeta):
self._responses = self._operation.get("responses", {})
self.parameter_to_arg = parameter_to_arg
@property
def api(self):
return self._api
@@ -148,75 +150,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
"""
return self._pythonic_params
@staticmethod
def _get_file_arguments(files, arguments, has_kwargs=False):
return {k: v for k, v in files.items() if k in arguments or has_kwargs}
@abc.abstractmethod
def _get_val_from_param(self, value, query_defn):
"""
Convert input parameters into the correct type
"""
def _query_args_helper(
self, query_defns, query_arguments, function_arguments, has_kwargs, sanitize
):
res = {}
for key, value in query_arguments.items():
sanitized_key = sanitize(key)
if not has_kwargs and sanitized_key not in function_arguments:
logger.debug(
"Query Parameter '%s' (sanitized: '%s') not in function arguments",
key,
sanitized_key,
)
else:
logger.debug(
"Query Parameter '%s' (sanitized: '%s') in function arguments",
key,
sanitized_key,
)
try:
query_defn = query_defns[key]
except KeyError: # pragma: no cover
logger.error(
"Function argument '%s' (non-sanitized: %s) not defined in specification",
sanitized_key,
key,
)
else:
logger.debug("%s is a %s", key, query_defn)
res.update(
{sanitized_key: self._get_val_from_param(value, query_defn)}
)
return res
@abc.abstractmethod
def _get_query_arguments(self, query, arguments, has_kwargs, sanitize):
"""
extract handler function arguments from the query parameters
"""
@abc.abstractmethod
def _get_body_argument(self, body, arguments, has_kwargs, sanitize):
"""
extract handler function arguments from the request body
"""
def _get_path_arguments(self, path_params, sanitize):
"""
extract handler function arguments from path parameters
"""
kwargs = {}
path_defns = {p["name"]: p for p in self.parameters if p["in"] == "path"}
for key, value in path_params.items():
sanitized_key = sanitize(key)
if key in path_defns:
kwargs[sanitized_key] = self._get_val_from_param(value, path_defns[key])
else: # Assume path params mechanism used for injection
kwargs[sanitized_key] = value
return kwargs
@property
@abc.abstractmethod
def parameters(self):
@@ -238,6 +171,12 @@ class AbstractOperation(metaclass=abc.ABCMeta):
Content-Types that the operation consumes
"""
@abc.abstractmethod
def body_name(self, content_type: str) -> str:
"""
Name of the body in the spec.
"""
@abc.abstractmethod
def body_schema(self, content_type: str = None) -> dict:
"""
@@ -251,23 +190,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
:rtype: dict
"""
def get_arguments(
self, path_params, query_params, body, files, arguments, has_kwargs, sanitize
):
"""
get arguments for handler function
"""
ret = {}
ret.update(self._get_path_arguments(path_params, sanitize))
ret.update(
self._get_query_arguments(query_params, arguments, has_kwargs, sanitize)
)
if self.method.upper() in ["PATCH", "POST", "PUT"]:
ret.update(self._get_body_argument(body, arguments, has_kwargs, sanitize))
ret.update(self._get_file_arguments(files, arguments, has_kwargs))
return ret
def response_definition(self, status_code=None, content_type=None):
"""
response definition for this endpoint
@@ -335,17 +257,19 @@ class AbstractOperation(metaclass=abc.ABCMeta):
:rtype: types.FunctionType
"""
function = parameter_to_arg(
self,
self._resolution.function,
self.pythonic_params,
function = self._resolution.function
if self.parameter_to_arg:
function = self.parameter_to_arg(
self,
function,
self.pythonic_params,
)
function = self._request_response_decorator(
function, self._uri_parsing_decorator
)
uri_parsing_decorator = self._uri_parsing_decorator
function = uri_parsing_decorator(function)
function = self._request_response_decorator(function)
return function
@property

View File

@@ -3,14 +3,11 @@ This module defines an OpenAPIOperation class, a Connexion operation specific fo
"""
import logging
from copy import copy, deepcopy
from connexion.datastructures import MediaTypeDict
from connexion.operations.abstract import AbstractOperation
from ..decorators.uri_parsing import OpenAPIURIParser
from ..http_facts import FORM_CONTENT_TYPES
from ..utils import deep_get, deep_merge, is_null, is_nullable, make_type
from connexion.uri_parsing import OpenAPIURIParser
from connexion.utils import deep_get
logger = logging.getLogger("connexion.operations.openapi3")
@@ -35,6 +32,7 @@ class OpenAPIOperation(AbstractOperation):
randomize_endpoint=None,
pythonic_params=False,
uri_parser_class=None,
parameter_to_arg=None,
):
"""
This class uses the OperationID identify the module and function that will handle the operation
@@ -88,6 +86,7 @@ class OpenAPIOperation(AbstractOperation):
randomize_endpoint=randomize_endpoint,
pythonic_params=pythonic_params,
uri_parser_class=uri_parser_class,
parameter_to_arg=parameter_to_arg,
)
self._request_body = operation.get("requestBody", {})
@@ -144,8 +143,9 @@ class OpenAPIOperation(AbstractOperation):
def produces(self):
return self._produces
def with_definitions(self, schema):
def with_definitions(self, schema: dict):
if self.components:
schema.setdefault("schema", {})
schema["schema"]["components"] = self.components
return schema
@@ -240,6 +240,9 @@ class OpenAPIOperation(AbstractOperation):
types[path_defn["name"]] = path_type
return types
def body_name(self, _content_type: str) -> str:
return self.request_body.get("x-body-name", "body")
def body_schema(self, content_type: str = None) -> dict:
"""
The body schema definition for this operation.
@@ -267,131 +270,3 @@ class OpenAPIOperation(AbstractOperation):
res = content_type_dict.get(content_type, {})
return self.with_definitions(res)
return {}
def _get_body_argument(self, body, arguments, has_kwargs, sanitize):
if len(arguments) <= 0 and not has_kwargs:
return {}
x_body_name = sanitize(self.request_body.get("x-body-name", "body"))
if self.consumes[0] in FORM_CONTENT_TYPES:
result = self._get_body_argument_form(body)
else:
result = self._get_body_argument_json(body)
if x_body_name in arguments or has_kwargs:
return {x_body_name: result}
return {}
def _get_body_argument_json(self, body):
# if the body came in null, and the schema says it can be null, we decide
# to include no value for the body argument, rather than the default body
if is_nullable(self.body_schema()) and is_null(body):
return None
if body is None:
default_body = self.body_schema().get("default", {})
return deepcopy(default_body)
return body
def _get_body_argument_form(self, body):
# now determine the actual value for the body (whether it came in or is default)
default_body = self.body_schema().get("default", {})
body_props = {
k: {"schema": v}
for k, v in self.body_schema().get("properties", {}).items()
}
# by OpenAPI specification `additionalProperties` defaults to `true`
# see: https://github.com/OAI/OpenAPI-Specification/blame/3.0.2/versions/3.0.2.md#L2305
additional_props = self.body_schema().get("additionalProperties", True)
body_arg = deepcopy(default_body)
body_arg.update(body or {})
if body_props or additional_props:
return self._get_typed_body_values(body_arg, body_props, additional_props)
return {}
def _get_typed_body_values(self, body_arg, body_props, additional_props):
"""
Return a copy of the provided body_arg dictionary
whose values will have the appropriate types
as defined in the provided schemas.
:type body_arg: type dict
:type body_props: dict
:type additional_props: dict|bool
:rtype: dict
"""
additional_props_defn = (
{"schema": additional_props} if isinstance(additional_props, dict) else None
)
res = {}
for key, value in body_arg.items():
try:
prop_defn = body_props[key]
res[key] = self._get_val_from_param(value, prop_defn)
except KeyError: # pragma: no cover
if not additional_props:
logger.error(f"Body property '{key}' not defined in body schema")
continue
if additional_props_defn is not None:
value = self._get_val_from_param(value, additional_props_defn)
res[key] = value
return res
def _build_default_obj_recursive(self, _properties, res):
"""takes disparate and nested default keys, and builds up a default object"""
for key, prop in _properties.items():
if "default" in prop and key not in res:
res[key] = copy(prop["default"])
elif prop.get("type") == "object" and "properties" in prop:
res.setdefault(key, {})
res[key] = self._build_default_obj_recursive(
prop["properties"], res[key]
)
return res
def _get_default_obj(self, schema):
try:
return deepcopy(schema["default"])
except KeyError:
_properties = schema.get("properties", {})
return self._build_default_obj_recursive(_properties, {})
def _get_query_defaults(self, query_defns):
defaults = {}
for k, v in query_defns.items():
try:
if v["schema"]["type"] == "object":
defaults[k] = self._get_default_obj(v["schema"])
else:
defaults[k] = v["schema"]["default"]
except KeyError:
pass
return defaults
def _get_query_arguments(self, query, arguments, has_kwargs, sanitize):
query_defns = {p["name"]: p for p in self.parameters if p["in"] == "query"}
default_query_params = self._get_query_defaults(query_defns)
query_arguments = deepcopy(default_query_params)
query_arguments = deep_merge(query_arguments, query)
return self._query_args_helper(
query_defns, query_arguments, arguments, has_kwargs, sanitize
)
def _get_val_from_param(self, value, query_defn):
query_schema = query_defn["schema"]
if is_nullable(query_schema) and is_null(value):
return None
if query_schema["type"] == "array":
return [make_type(part, query_schema["items"]["type"]) for part in value]
else:
return make_type(value, query_schema["type"])

View File

@@ -4,14 +4,12 @@ This module defines a Swagger2Operation class, a Connexion operation specific fo
import logging
import typing as t
from copy import deepcopy
from connexion.exceptions import InvalidSpecification
from connexion.http_facts import FORM_CONTENT_TYPES
from connexion.operations.abstract import AbstractOperation
from ..decorators.uri_parsing import Swagger2URIParser
from ..exceptions import InvalidSpecification
from ..http_facts import FORM_CONTENT_TYPES
from ..utils import deep_get, is_null, is_nullable, make_type
from connexion.uri_parsing import Swagger2URIParser
from connexion.utils import deep_get
logger = logging.getLogger("connexion.operations.swagger2")
@@ -52,6 +50,7 @@ class Swagger2Operation(AbstractOperation):
randomize_endpoint=None,
pythonic_params=False,
uri_parser_class=None,
parameter_to_arg=None,
):
"""
:param api: api that this operation is attached to
@@ -101,6 +100,7 @@ class Swagger2Operation(AbstractOperation):
randomize_endpoint=randomize_endpoint,
pythonic_params=pythonic_params,
uri_parser_class=uri_parser_class,
parameter_to_arg=parameter_to_arg,
)
self._produces = operation.get("produces", app_produces)
@@ -222,6 +222,9 @@ class Swagger2Operation(AbstractOperation):
except KeyError:
raise
def body_name(self, content_type: str = None) -> str:
return self.body_definition(content_type).get("name", "body")
def body_schema(self, content_type: str = None) -> dict:
"""
The body schema definition for this operation.
@@ -237,6 +240,7 @@ class Swagger2Operation(AbstractOperation):
:rtype: dict
"""
# TODO: cache
if content_type in FORM_CONTENT_TYPES:
form_parameters = [p for p in self.parameters if p["in"] == "formData"]
body_definition = self._transform_form(form_parameters)
@@ -248,12 +252,21 @@ class Swagger2Operation(AbstractOperation):
method=self.method, path=self.path
)
)
body_definition = body_parameters[0] if body_parameters else {}
body_parameter = body_parameters[0] if body_parameters else {}
body_definition = self._transform_json(body_parameter)
return body_definition
def _transform_json(self, body_parameter: dict) -> dict:
"""Translate Swagger2 json parameters into OpenAPI 3 jsonschema spec."""
nullable = body_parameter.get("x-nullable")
if nullable is not None:
body_parameter["schema"]["nullable"] = nullable
return body_parameter
def _transform_form(self, form_parameters: t.List[dict]) -> dict:
"""Translate Swagger2 form parameters into OpenAPI 3 jsonschema spec."""
properties = {}
defaults = {}
required = []
encoding = {}
@@ -276,7 +289,7 @@ class Swagger2Operation(AbstractOperation):
default = param.get("default")
if default is not None:
prop["default"] = default
defaults[param["name"]] = default
nullable = param.get("x-nullable")
if nullable is not None:
@@ -305,6 +318,7 @@ class Swagger2Operation(AbstractOperation):
"schema": {
"type": "object",
"properties": properties,
"default": defaults,
"required": required,
}
}
@@ -313,76 +327,3 @@ class Swagger2Operation(AbstractOperation):
definition["encoding"] = encoding
return definition
def _get_query_arguments(self, query, arguments, has_kwargs, sanitize):
query_defns = {p["name"]: p for p in self.parameters if p["in"] == "query"}
default_query_params = {
k: v["default"] for k, v in query_defns.items() if "default" in v
}
query_arguments = deepcopy(default_query_params)
query_arguments.update(query)
return self._query_args_helper(
query_defns, query_arguments, arguments, has_kwargs, sanitize
)
def _get_body_argument(self, body, arguments, has_kwargs, sanitize):
kwargs = {}
body_parameters = [p for p in self.parameters if p["in"] == "body"] or [{}]
if body is None:
body = deepcopy(body_parameters[0].get("schema", {}).get("default"))
body_name = sanitize(body_parameters[0].get("name"))
form_defns = {p["name"]: p for p in self.parameters if p["in"] == "formData"}
default_form_params = {
k: v["default"] for k, v in form_defns.items() if "default" in v
}
# Add body parameters
if body_name:
if not has_kwargs and body_name not in arguments:
logger.debug("Body parameter '%s' not in function arguments", body_name)
else:
logger.debug("Body parameter '%s' in function arguments", body_name)
kwargs[body_name] = body
# Add formData parameters
form_arguments = deepcopy(default_form_params)
if form_defns and body:
form_arguments.update(body)
for key, value in form_arguments.items():
sanitized_key = sanitize(key)
if not has_kwargs and sanitized_key not in arguments:
logger.debug(
"FormData parameter '%s' (sanitized: '%s') not in function arguments",
key,
sanitized_key,
)
else:
logger.debug(
"FormData parameter '%s' (sanitized: '%s') in function arguments",
key,
sanitized_key,
)
try:
form_defn = form_defns[key]
except KeyError: # pragma: no cover
logger.error(
"Function argument '%s' (non-sanitized: %s) not defined in specification",
key,
sanitized_key,
)
else:
kwargs[sanitized_key] = self._get_val_from_param(value, form_defn)
return kwargs
def _get_val_from_param(self, value, query_defn):
if is_nullable(query_defn) and is_null(value):
return None
query_schema = query_defn
if query_schema["type"] == "array":
return [make_type(part, query_defn["items"]["type"]) for part in value]
else:
return make_type(value, query_defn["type"])

View File

@@ -10,7 +10,7 @@ try:
except ImportError:
swagger_ui_2_path = swagger_ui_3_path = None
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.uri_parsing import AbstractURIParser
NO_UI_MSG = """The swagger_ui directory could not be found.
Please install connexion with extra install: pip install connexion[swagger-ui]

View File

@@ -13,14 +13,14 @@ import typing as t
import httpx
from ..decorators.parameter import inspect_function_arguments
from ..exceptions import (
from connexion.decorators.parameter import inspect_function_arguments
from connexion.exceptions import (
ConnexionException,
OAuthProblem,
OAuthResponseProblem,
OAuthScopeProblem,
)
from ..utils import get_function_from_name
from connexion.utils import get_function_from_name
logger = logging.getLogger("connexion.api.security")

View File

@@ -1,8 +0,0 @@
"""
This module defines SecurityHandlerFactories which support the creation of security
handlers for operations.
isort:skip_file
"""
from .security_handler_factory import SecurityHandlerFactory # NOQA

View File

@@ -1,9 +1,9 @@
"""
This module defines view function decorators to split query and path parameters.
This module defines URIParsers which parse query and path parameters according to OpenAPI
serialization rules.
"""
import abc
import functools
import json
import logging
import re
@@ -130,33 +130,6 @@ class AbstractURIParser(metaclass=abc.ABCMeta):
return resolved_param
def __call__(self, function):
"""
:type function: types.FunctionType
:rtype: types.FunctionType
"""
@functools.wraps(function)
def wrapper(request):
def coerce_dict(md):
"""MultiDict -> dict of lists"""
try:
return md.to_dict(flat=False)
except AttributeError:
return dict(md.items())
query = coerce_dict(request.query)
path_params = coerce_dict(request.path_params)
form = coerce_dict(request.form)
request.query = self.resolve_query(query)
request.path_params = self.resolve_path(path_params)
request.form = self.resolve_form(form)
response = function(request)
return response
return wrapper
class OpenAPIURIParser(AbstractURIParser):
style_defaults = {
@@ -281,7 +254,7 @@ class OpenAPIURIParser(AbstractURIParser):
class Swagger2URIParser(AbstractURIParser):
"""
Adheres to the Swagger2 spec,
Assumes the the last defined query parameter should be used.
Assumes that the last defined query parameter should be used.
"""
parsable_parameters = ["query", "path", "formData"]

View File

@@ -35,18 +35,24 @@ def boolean(s):
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#data-types
TYPE_MAP = {
TYPE_MAP: t.Dict[str, t.Any] = {
"integer": int,
"number": float,
"string": str,
"boolean": boolean,
"array": list,
"object": dict,
"file": lambda x: x, # Don't cast files
} # map of swagger types to python types
def make_type(value, _type):
type_func = TYPE_MAP[_type] # convert value to right type
def make_type(value: t.Any, type_: str, format_: t.Optional[str]) -> t.Any:
"""Cast a value to the type defined in the specification."""
# In OpenAPI, files are represented with string type and binary format
if type_ == "string" and format_ == "binary":
type_ = "file"
type_func = TYPE_MAP[type_]
return type_func(value)
@@ -141,6 +147,8 @@ def is_json_mimetype(mimetype):
:type mimetype: str
:rtype: bool
"""
if mimetype is None:
return False
maintype, subtype = mimetype.split("/") # type: str, str
if ";" in subtype:
@@ -209,9 +217,7 @@ def has_coroutine(function, api=None):
return iscorofunc(function)
else:
return any(
iscorofunc(func) for func in (function, api.get_request, api.get_response)
)
return any(iscorofunc(func) for func in (function, api.get_response))
def yamldumper(openapi):

View File

@@ -6,13 +6,13 @@ from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import FormParser, MultiPartParser
from starlette.types import Receive, Scope
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import (
BadRequestProblem,
ExtraParameterProblem,
TypeValidationError,
)
from connexion.json_schema import Draft4RequestValidator
from connexion.uri_parsing import AbstractURIParser
from connexion.utils import coerce_type, is_null
logger = logging.getLogger("connexion.validators.form_data")

View File

@@ -102,6 +102,9 @@ class ParameterValidator:
return self.validate_parameter("query", val, param)
def validate_path_parameter(self, param, request):
# TODO: activate
# path_params = self.uri_parser.resolve_path(request.path_params)
# val = path_params.get(param["name"].replace("-", "_"))
val = request.path_params.get(param["name"].replace("-", "_"))
return self.validate_parameter("path", val, param)

View File

@@ -1,2 +1,2 @@
[bdist_wheel]
universal=1
universal=1

View File

@@ -65,7 +65,7 @@ def test_app_with_different_server_option(simple_api_spec_dir, spec):
def test_app_with_different_uri_parser(simple_api_spec_dir):
from connexion.decorators.uri_parsing import FirstValueURIParser
from connexion.uri_parsing import FirstValueURIParser
app = App(
__name__,

View File

@@ -172,7 +172,7 @@ def test_path_parameter_someint__bad(simple_app):
# non-integer values will not match Flask route
app_client = simple_app.app.test_client()
resp = app_client.get("/v1.0/test-int-path/foo") # type: flask.Response
assert resp.status_code == 400, resp.text
assert resp.status_code == 404, resp.text
@pytest.mark.parametrize(
@@ -205,7 +205,7 @@ def test_path_parameter_somefloat__bad(simple_app):
# non-float values will not match Flask route
app_client = simple_app.app.test_client()
resp = app_client.get("/v1.0/test-float-path/123,45") # type: flask.Response
assert resp.status_code == 400, resp.text
assert resp.status_code == 404, resp.text
def test_default_param(strict_app):
@@ -280,7 +280,7 @@ def test_formdata_file_upload(simple_app):
app_client = simple_app.app.test_client()
resp = app_client.post(
"/v1.0/test-formData-file-upload",
data={"formData": (BytesIO(b"file contents"), "filename.txt")},
data={"fileData": (BytesIO(b"file contents"), "filename.txt")},
)
assert resp.status_code == 200
response = json.loads(resp.data.decode("utf-8", "replace"))
@@ -293,8 +293,8 @@ def test_formdata_file_upload_bad_request(simple_app):
assert resp.status_code == 400
response = json.loads(resp.data.decode("utf-8", "replace"))
assert response["detail"] in [
"Missing formdata parameter 'formData'",
"'formData' is a required property", # OAS3
"Missing formdata parameter 'fileData'",
"'fileData' is a required property", # OAS3
]
@@ -302,7 +302,7 @@ def test_formdata_file_upload_missing_param(simple_app):
app_client = simple_app.app.test_client()
resp = app_client.post(
"/v1.0/test-formData-file-upload-missing-param",
data={"missing_formData": (BytesIO(b"file contents"), "example.txt")},
data={"missing_fileData": (BytesIO(b"file contents"), "example.txt")},
)
assert resp.status_code == 200

View File

@@ -3,11 +3,19 @@ from unittest.mock import MagicMock
from connexion.decorators.parameter import parameter_to_arg, pythonic
def test_injection():
request = MagicMock(name="request", path_params={"p1": "123"})
request.args = {}
async def test_injection():
request = MagicMock(name="request")
request.query_params = {}
request.path_params = {"p1": "123"}
request.headers = {}
request.params = {}
request.content_type = "application/json"
async def coro():
return
request.json = coro
request.loop = None
request.context = {}
func = MagicMock()
@@ -16,17 +24,29 @@ def test_injection():
class Op:
consumes = ["application/json"]
parameters = []
method = "GET"
def get_arguments(self, *args, **kwargs):
return {"p1": "123"}
def body_name(self, *args, **kwargs):
return "body"
parameter_to_arg(Op(), handler)(request)
parameter_decorator = parameter_to_arg(Op(), handler)
await parameter_decorator(request)
func.assert_called_with(p1="123")
def test_injection_with_context():
async def test_injection_with_context():
request = MagicMock(name="request")
async def coro():
return
request.json = coro
request.loop = None
request.context = {}
request.content_type = "application/json"
request.path_params = {"p1": "123"}
func = MagicMock()
def handler(context_, **kwargs):
@@ -34,11 +54,14 @@ def test_injection_with_context():
class Op2:
consumes = ["application/json"]
parameters = []
method = "GET"
def get_arguments(self, *args, **kwargs):
return {"p1": "123"}
def body_name(self, *args, **kwargs):
return "body"
parameter_to_arg(Op2(), handler)(request)
parameter_decorator = parameter_to_arg(Op2(), handler)
await parameter_decorator(request)
func.assert_called_with(request.context, p1="123")

View File

@@ -1,5 +1,5 @@
import pytest
from connexion.decorators.uri_parsing import (
from connexion.uri_parsing import (
AlwaysMultiURIParser,
FirstValueURIParser,
OpenAPIURIParser,
@@ -46,7 +46,9 @@ MULTI = "multi"
(AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY2, PIPES),
],
)
def test_uri_parser_query_params(parser_class, expected, query_in, collection_format):
async def test_uri_parser_query_params(
parser_class, expected, query_in, collection_format
):
class Request:
query = query_in
path_params = {}
@@ -63,9 +65,9 @@ def test_uri_parser_query_params(parser_class, expected, query_in, collection_fo
}
]
body_defn = {}
p = parser_class(parameters, body_defn)
res = p(lambda x: x)(request)
assert res.query["letters"] == expected
parser = parser_class(parameters, body_defn)
res = parser.resolve_query(request.query.to_dict(flat=False))
assert res["letters"] == expected
@pytest.mark.parametrize(
@@ -82,7 +84,9 @@ def test_uri_parser_query_params(parser_class, expected, query_in, collection_fo
(AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY2, PIPES),
],
)
def test_uri_parser_form_params(parser_class, expected, query_in, collection_format):
async def test_uri_parser_form_params(
parser_class, expected, query_in, collection_format
):
class Request:
query = {}
form = query_in
@@ -99,9 +103,9 @@ def test_uri_parser_form_params(parser_class, expected, query_in, collection_for
}
]
body_defn = {}
p = parser_class(parameters, body_defn)
res = p(lambda x: x)(request)
assert res.form["letters"] == expected
parser = parser_class(parameters, body_defn)
res = parser.resolve_form(request.form.to_dict(flat=False))
assert res["letters"] == expected
@pytest.mark.parametrize(
@@ -115,7 +119,9 @@ def test_uri_parser_form_params(parser_class, expected, query_in, collection_for
(AlwaysMultiURIParser, ["d", "e", "f"], PATH2, PIPES),
],
)
def test_uri_parser_path_params(parser_class, expected, query_in, collection_format):
async def test_uri_parser_path_params(
parser_class, expected, query_in, collection_format
):
class Request:
query = {}
form = {}
@@ -132,9 +138,9 @@ def test_uri_parser_path_params(parser_class, expected, query_in, collection_for
}
]
body_defn = {}
p = parser_class(parameters, body_defn)
res = p(lambda x: x)(request)
assert res.path_params["letters"] == expected
parser = parser_class(parameters, body_defn)
res = parser.resolve_path(request.path_params)
assert res["letters"] == expected
@pytest.mark.parametrize(
@@ -149,7 +155,7 @@ def test_uri_parser_path_params(parser_class, expected, query_in, collection_for
(AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY4, PIPES),
],
)
def test_uri_parser_query_params_with_square_brackets(
async def test_uri_parser_query_params_with_square_brackets(
parser_class, expected, query_in, collection_format
):
class Request:
@@ -168,9 +174,9 @@ def test_uri_parser_query_params_with_square_brackets(
}
]
body_defn = {}
p = parser_class(parameters, body_defn)
res = p(lambda x: x)(request)
assert res.query["letters[eq]"] == expected
parser = parser_class(parameters, body_defn)
res = parser.resolve_query(request.query.to_dict(flat=False))
assert res["letters[eq]"] == expected
@pytest.mark.parametrize(
@@ -188,7 +194,7 @@ def test_uri_parser_query_params_with_square_brackets(
(AlwaysMultiURIParser, ["a"], QUERY6, PIPES),
],
)
def test_uri_parser_query_params_with_underscores(
async def test_uri_parser_query_params_with_underscores(
parser_class, expected, query_in, collection_format
):
class Request:
@@ -207,9 +213,9 @@ def test_uri_parser_query_params_with_underscores(
}
]
body_defn = {}
p = parser_class(parameters, body_defn)
res = p(lambda x: x)(request)
assert res.query["letters_eq"] == expected
parser = parser_class(parameters, body_defn)
res = parser.resolve_query(request.query.to_dict(flat=False))
assert res["letters_eq"] == expected
@pytest.mark.parametrize(
@@ -231,7 +237,7 @@ def test_uri_parser_query_params_with_underscores(
),
],
)
def test_uri_parser_query_params_with_malformed_names(
async def test_uri_parser_query_params_with_malformed_names(
parser_class, query_in, collection_format, explode, expected
):
class Request:
@@ -253,6 +259,6 @@ def test_uri_parser_query_params_with_malformed_names(
}
]
body_defn = {}
p = parser_class(parameters, body_defn)
res = p(lambda x: x)(request)
assert res.query == expected
parser = parser_class(parameters, body_defn)
res = parser.resolve_query(request.query.to_dict(flat=False))
assert res == expected

View File

@@ -84,7 +84,7 @@ def get_bye_secure(name, user, token_info):
def get_bye_secure_from_flask():
return "Goodbye {user} (Secure!)".format(user=context["user"])
return "Goodbye {user} (Secure!)".format(user=context.context["user"])
def get_bye_secure_from_connexion(context_):
@@ -314,9 +314,11 @@ def test_formdata_missing_param():
return ""
def test_formdata_file_upload(formData, **kwargs):
filename = formData.filename
contents = formData.read().decode("utf-8", "replace")
def test_formdata_file_upload(fileData, **kwargs):
"""In Swagger, form paramaeters and files are passed separately"""
filename = fileData.filename
contents = fileData.read()
contents = contents.decode("utf-8", "replace")
return {filename: contents}

View File

@@ -3,25 +3,6 @@ info:
title: '{{title}}'
version: '1.0'
paths:
'/greeting/{name}':
post:
summary: Generate greeting
description: Generates a greeting message.
operationId: fakeapi.hello.post_greeting
responses:
'200':
description: greeting response
content:
'application/json':
schema:
type: object
parameters:
- name: name
in: path
description: Name of the person to greet.
required: true
schema:
type: string
'/greeting/{name}/{remainder}':
post:
summary: Generate greeting and collect the remainder of the url
@@ -48,6 +29,25 @@ paths:
schema:
type: string
format: path
'/greeting/{name}':
post:
summary: Generate greeting
description: Generates a greeting message.
operationId: fakeapi.hello.post_greeting
responses:
'200':
description: greeting response
content:
'application/json':
schema:
type: object
parameters:
- name: name
in: path
description: Name of the person to greet.
required: true
schema:
type: string
'/greetings/{name}':
get:
summary: Generate greeting
@@ -600,11 +600,11 @@ paths:
schema:
type: object
properties:
formData:
fileData:
type: string
format: binary
required:
- formData
- fileData
/test-formData-file-upload-missing-param:
post:
summary: 'Test formData with file type, missing parameter in handler'
@@ -618,11 +618,11 @@ paths:
schema:
type: object
properties:
missing_formData:
missing_fileData:
type: string
format: binary
required:
- missing_formData
- missing_fileData
/test-bool-param:
get:
summary: Test usage of boolean default value
@@ -690,6 +690,24 @@ paths:
responses:
'200':
description: OK
/goodday/noheader:
post:
summary: Generate good day greeting
description: Generates a good day message.
operationId: fakeapi.hello.post_goodday_no_header
responses:
'201':
description: goodday response
headers:
Location:
description: The URI of the created resource
schema:
type: string
required: true
content:
'application/json':
schema:
type: object
'/goodday/{name}':
post:
summary: Generate good day greeting
@@ -715,24 +733,6 @@ paths:
required: true
schema:
type: string
/goodday/noheader:
post:
summary: Generate good day greeting
description: Generates a good day message.
operationId: fakeapi.hello.post_goodday_no_header
responses:
'201':
description: goodday response
headers:
Location:
description: The URI of the created resource
schema:
type: string
required: true
content:
'application/json':
schema:
type: object
'/goodevening/{name}':
post:
summary: Generate good evening

View File

@@ -7,22 +7,6 @@ info:
basePath: /v1.0
paths:
/greeting/{name}:
post:
summary: Generate greeting
description: Generates a greeting message.
operationId: fakeapi.hello.post_greeting
responses:
'200':
description: greeting response
schema:
type: object
parameters:
- name: name
in: path
description: Name of the person to greet.
required: true
type: string
/greeting/{name}/{remainder}:
post:
summary: Generate greeting and collect the remainder of the url
@@ -45,6 +29,22 @@ paths:
required: true
type: string
format: path
/greeting/{name}:
post:
summary: Generate greeting
description: Generates a greeting message.
operationId: fakeapi.hello.post_greeting
responses:
'200':
description: greeting response
schema:
type: object
parameters:
- name: name
in: path
description: Name of the person to greet.
required: true
type: string
/greetings/{name}:
get:
summary: Generate greeting
@@ -457,7 +457,7 @@ paths:
consumes:
- multipart/form-data
parameters:
- name: formData
- name: fileData
type: file
in: formData
required: true
@@ -472,7 +472,7 @@ paths:
consumes:
- multipart/form-data
parameters:
- name: missing_formData
- name: missing_fileData
type: file
in: formData
required: true

View File

@@ -0,0 +1,22 @@
from connexion.datastructures import MediaTypeDict
def test_media_type_dict():
d = MediaTypeDict(
{
"*/*": "*/*",
"*/json": "*/json",
"*/*json": "*/*json",
"multipart/*": "multipart/*",
"multipart/form-data": "multipart/form-data",
}
)
assert d["application/json"] == "*/json"
assert d["application/problem+json"] == "*/*json"
assert d["application/x-www-form-urlencoded"] == "*/*"
assert d["multipart/form-data"] == "multipart/form-data"
assert d["multipart/byteranges"] == "multipart/*"
# Test __contains__
assert "application/json" in d

View File

@@ -735,7 +735,6 @@ def test_form_transformation(api):
"properties": {
"param": {
"type": "string",
"default": "foo@bar.com",
"format": "email",
},
"array_param": {
@@ -746,6 +745,7 @@ def test_form_transformation(api):
"nullable": True,
},
},
"default": {"param": "foo@bar.com"},
"required": ["param"],
},
"encoding": {

View File

@@ -2,8 +2,8 @@ from unittest.mock import MagicMock
from urllib.parse import quote_plus
import pytest
from connexion.decorators.uri_parsing import Swagger2URIParser
from connexion.exceptions import BadRequestProblem
from connexion.uri_parsing import Swagger2URIParser
from connexion.validators.parameter import ParameterValidator
from starlette.datastructures import QueryParams