mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 04:19:26 +00:00
Move parameter decorator related methods out of operation classes
This commit is contained in:
@@ -20,7 +20,7 @@ from .utils import not_installed_error # NOQA
|
||||
try:
|
||||
from flask import request # NOQA
|
||||
|
||||
from .apis.flask_api import FlaskApi, context # NOQA
|
||||
from .apis.flask_api import FlaskApi # NOQA
|
||||
from .apps.flask_app import FlaskApp
|
||||
except ImportError as e: # pragma: no cover
|
||||
_flask_not_installed_error = not_installed_error(e)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
This module defines an AbstractAPI, which defines a standardized interface for a Connexion API.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import pathlib
|
||||
@@ -9,15 +8,18 @@ import sys
|
||||
import typing as t
|
||||
from enum import Enum
|
||||
|
||||
from ..datastructures import NoContent
|
||||
from ..exceptions import ResolverError
|
||||
from ..http_facts import METHODS
|
||||
from ..jsonifier import Jsonifier
|
||||
from ..lifecycle import ConnexionResponse
|
||||
from ..operations import make_operation
|
||||
from ..options import ConnexionOptions
|
||||
from ..resolver import Resolver
|
||||
from ..spec import Specification
|
||||
from starlette.requests import Request
|
||||
|
||||
from connexion.datastructures import NoContent
|
||||
from connexion.decorators.parameter import parameter_to_arg
|
||||
from connexion.exceptions import ResolverError
|
||||
from connexion.http_facts import METHODS
|
||||
from connexion.jsonifier import Jsonifier
|
||||
from connexion.lifecycle import ConnexionResponse
|
||||
from connexion.operations import make_operation
|
||||
from connexion.options import ConnexionOptions
|
||||
from connexion.resolver import Resolver
|
||||
from connexion.spec import Specification
|
||||
|
||||
MODULE_PATH = pathlib.Path(__file__).absolute().parent.parent
|
||||
SWAGGER_UI_URL = "ui"
|
||||
@@ -231,30 +233,19 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
self.resolver,
|
||||
pythonic_params=self.pythonic_params,
|
||||
uri_parser_class=self.options.uri_parser_class,
|
||||
parameter_to_arg=parameter_to_arg,
|
||||
)
|
||||
self._add_operation_internal(method, path, operation)
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_request(self, *args, **kwargs):
|
||||
def get_request(cls, uri_parser) -> Request:
|
||||
"""
|
||||
This method converts the user framework request to a ConnexionRequest.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def get_response(self, response, mimetype=None, request=None):
|
||||
"""
|
||||
This method converts a handler response to a framework response.
|
||||
This method should just retrieve response from handler then call `cls._get_response`.
|
||||
:param response: A response to cast (tuple, framework response, etc).
|
||||
:param mimetype: The response mimetype.
|
||||
:type mimetype: Union[None, str]
|
||||
:param request: The request associated with this response (the user framework request).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _get_response(cls, response, mimetype=None, extra_context=None):
|
||||
def _get_response(cls, response, mimetype=None):
|
||||
"""
|
||||
This method converts a handler response to a framework response.
|
||||
The response can be a ConnexionResponse, an operation handler, a framework response or a tuple.
|
||||
@@ -262,31 +253,24 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
:param response: A response to cast (tuple, framework response, etc).
|
||||
:param mimetype: The response mimetype.
|
||||
:type mimetype: Union[None, str]
|
||||
:param extra_context: dict of extra details, like url, to include in logs
|
||||
:type extra_context: Union[None, dict]
|
||||
"""
|
||||
if extra_context is None:
|
||||
extra_context = {}
|
||||
logger.debug(
|
||||
"Getting data and status code",
|
||||
extra={"data": response, "data_type": type(response), **extra_context},
|
||||
extra={"data": response, "data_type": type(response)},
|
||||
)
|
||||
|
||||
if isinstance(response, ConnexionResponse):
|
||||
framework_response = cls._connexion_to_framework_response(
|
||||
response, mimetype, extra_context
|
||||
response, mimetype
|
||||
)
|
||||
else:
|
||||
framework_response = cls._response_from_handler(
|
||||
response, mimetype, extra_context
|
||||
)
|
||||
framework_response = cls._response_from_handler(response, mimetype)
|
||||
|
||||
logger.debug(
|
||||
"Got framework response",
|
||||
extra={
|
||||
"response": framework_response,
|
||||
"response_type": type(framework_response),
|
||||
**extra_context,
|
||||
},
|
||||
)
|
||||
return framework_response
|
||||
@@ -298,7 +282,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
t.Any, str, t.Tuple[str], t.Tuple[str, int], t.Tuple[str, int, dict]
|
||||
],
|
||||
mimetype: str,
|
||||
extra_context: t.Optional[dict] = None,
|
||||
) -> t.Any:
|
||||
"""
|
||||
Create a framework response from the operation handler data.
|
||||
@@ -311,7 +294,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
|
||||
:param response: A response from an operation handler.
|
||||
:param mimetype: The response mimetype.
|
||||
:param extra_context: dict of extra details, like url, to include in logs
|
||||
"""
|
||||
if cls._is_framework_response(response):
|
||||
return response
|
||||
@@ -320,9 +302,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
len_response = len(response)
|
||||
if len_response == 1:
|
||||
(data,) = response
|
||||
return cls._build_response(
|
||||
mimetype=mimetype, data=data, extra_context=extra_context
|
||||
)
|
||||
return cls._build_response(mimetype=mimetype, data=data)
|
||||
if len_response == 2:
|
||||
if isinstance(response[1], (int, Enum)):
|
||||
data, status_code = response
|
||||
@@ -330,7 +310,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
mimetype=mimetype,
|
||||
data=data,
|
||||
status_code=status_code,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
else:
|
||||
data, headers = response
|
||||
@@ -338,7 +317,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
mimetype=mimetype,
|
||||
data=data,
|
||||
headers=headers,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
elif len_response == 3:
|
||||
data, status_code, headers = response
|
||||
@@ -347,7 +325,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
data=data,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
@@ -356,9 +333,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
" (body, status), or (body, headers)."
|
||||
)
|
||||
else:
|
||||
return cls._build_response(
|
||||
mimetype=mimetype, data=response, extra_context=extra_context
|
||||
)
|
||||
return cls._build_response(mimetype=mimetype, data=response)
|
||||
|
||||
@classmethod
|
||||
def get_connexion_response(cls, response, mimetype=None):
|
||||
@@ -384,7 +359,7 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def _connexion_to_framework_response(cls, response, mimetype, extra_context=None):
|
||||
def _connexion_to_framework_response(cls, response, mimetype):
|
||||
"""Cast ConnexionResponse to framework response class"""
|
||||
|
||||
@classmethod
|
||||
@@ -396,7 +371,6 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
content_type=None,
|
||||
status_code=None,
|
||||
headers=None,
|
||||
extra_context=None,
|
||||
):
|
||||
"""
|
||||
Create a framework response from the provided arguments.
|
||||
@@ -407,16 +381,12 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
:type status_code: int
|
||||
:param headers: The response status code.
|
||||
:type headers: Union[Iterable[Tuple[str, str]], Dict[str, str]]
|
||||
:param extra_context: dict of extra details, like url, to include in logs
|
||||
:type extra_context: Union[None, dict]
|
||||
:return A framework response.
|
||||
:rtype Response
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _prepare_body_and_status_code(
|
||||
cls, data, mimetype, status_code=None, extra_context=None
|
||||
):
|
||||
def _prepare_body_and_status_code(cls, data, mimetype, status_code=None):
|
||||
if data is NoContent:
|
||||
data = None
|
||||
|
||||
@@ -435,12 +405,10 @@ class AbstractAPI(AbstractRoutingAPI, metaclass=AbstractAPIMeta):
|
||||
else:
|
||||
body = data
|
||||
|
||||
if extra_context is None:
|
||||
extra_context = {}
|
||||
logger.debug(
|
||||
"Prepared body and status code (%d)",
|
||||
status_code,
|
||||
extra={"body": body, **extra_context},
|
||||
extra={"body": body},
|
||||
)
|
||||
|
||||
return body, status_code, mimetype
|
||||
|
||||
@@ -2,13 +2,9 @@
|
||||
This module defines a Flask Connexion API which implements translations between Flask and
|
||||
Connexion requests / responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import flask
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from connexion.apis import flask_utils
|
||||
from connexion.apis.abstract import AbstractAPI
|
||||
@@ -52,7 +48,7 @@ class FlaskApi(AbstractAPI):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_response(cls, response, mimetype=None, request=None):
|
||||
def get_response(cls, response, mimetype=None):
|
||||
"""Gets ConnexionResponse instance for the operation handler
|
||||
result. Status Code and Headers for response. If only body
|
||||
data is returned by the endpoint function, then the status
|
||||
@@ -64,9 +60,7 @@ class FlaskApi(AbstractAPI):
|
||||
:type response: flask.Response | (flask.Response,) | (flask.Response, int) | (flask.Response, dict) | (flask.Response, int, dict)
|
||||
:rtype: ConnexionResponse
|
||||
"""
|
||||
return cls._get_response(
|
||||
response, mimetype=mimetype, extra_context={"url": flask.request.url}
|
||||
)
|
||||
return cls._get_response(response, mimetype=mimetype)
|
||||
|
||||
@classmethod
|
||||
def _is_framework_response(cls, response):
|
||||
@@ -86,7 +80,7 @@ class FlaskApi(AbstractAPI):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _connexion_to_framework_response(cls, response, mimetype, extra_context=None):
|
||||
def _connexion_to_framework_response(cls, response, mimetype):
|
||||
"""Cast ConnexionResponse to framework response class"""
|
||||
flask_response = cls._build_response(
|
||||
mimetype=response.mimetype or mimetype,
|
||||
@@ -94,7 +88,6 @@ class FlaskApi(AbstractAPI):
|
||||
headers=response.headers,
|
||||
status_code=response.status_code,
|
||||
data=response.body,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
|
||||
return flask_response
|
||||
@@ -107,7 +100,6 @@ class FlaskApi(AbstractAPI):
|
||||
headers=None,
|
||||
status_code=None,
|
||||
data=None,
|
||||
extra_context=None,
|
||||
):
|
||||
if cls._is_framework_response(data):
|
||||
return flask.current_app.make_response((data, status_code, headers))
|
||||
@@ -116,7 +108,6 @@ class FlaskApi(AbstractAPI):
|
||||
data=data,
|
||||
mimetype=mimetype,
|
||||
status_code=status_code,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
@@ -133,61 +124,14 @@ class FlaskApi(AbstractAPI):
|
||||
def _serialize_data(cls, data, mimetype):
|
||||
if isinstance(mimetype, str) and is_json_mimetype(mimetype):
|
||||
body = cls.jsonifier.dumps(data)
|
||||
elif not (isinstance(data, bytes) or isinstance(data, str)):
|
||||
warnings.warn(
|
||||
"Implicit (flask) JSON serialization will change in the next major version. "
|
||||
"This is triggered because a response body is being serialized as JSON "
|
||||
"even though the mimetype is not a JSON type. "
|
||||
"This will be replaced by something that is mimetype-specific and may "
|
||||
"raise an error instead of silently converting everything to JSON. "
|
||||
"Please make sure to specify media/mime types in your specs.",
|
||||
FutureWarning, # a Deprecation targeted at application users.
|
||||
)
|
||||
body = cls.jsonifier.dumps(data)
|
||||
else:
|
||||
body = data
|
||||
|
||||
return body, mimetype
|
||||
|
||||
@classmethod
|
||||
def get_request(cls, *args, **params):
|
||||
# type: (*Any, **Any) -> ConnexionRequest
|
||||
"""Gets ConnexionRequest instance for the operation handler
|
||||
result. Status Code and Headers for response. If only body
|
||||
data is returned by the endpoint function, then the status
|
||||
code will be set to 200 and no headers will be added.
|
||||
|
||||
If the returned object is a flask.Response then it will just
|
||||
pass the information needed to recreate it.
|
||||
|
||||
:rtype: ConnexionRequest
|
||||
"""
|
||||
flask_request = flask.request
|
||||
scope = flask_request.environ["asgi.scope"]
|
||||
context_dict = scope.get("extensions", {}).get("connexion_context", {})
|
||||
setattr(flask.globals.request_ctx, "connexion_context", context_dict)
|
||||
request = ConnexionRequest(
|
||||
flask_request.url,
|
||||
flask_request.method,
|
||||
headers=flask_request.headers,
|
||||
form=flask_request.form,
|
||||
query=flask_request.args,
|
||||
body=flask_request.get_data(),
|
||||
json_getter=lambda: flask_request.get_json(silent=True),
|
||||
files=flask_request.files,
|
||||
path_params=params,
|
||||
context=context_dict,
|
||||
cookies=flask_request.cookies,
|
||||
)
|
||||
logger.debug(
|
||||
"Getting data and status code",
|
||||
extra={
|
||||
"data": request.body,
|
||||
"data_type": type(request.body),
|
||||
"url": request.url,
|
||||
},
|
||||
)
|
||||
return request
|
||||
def get_request(cls, uri_parser) -> ConnexionRequest:
|
||||
return ConnexionRequest(flask.request, uri_parser=uri_parser)
|
||||
|
||||
@classmethod
|
||||
def _set_jsonifier(cls):
|
||||
@@ -195,10 +139,3 @@ class FlaskApi(AbstractAPI):
|
||||
Use Flask specific JSON loader
|
||||
"""
|
||||
cls.jsonifier = Jsonifier(flask.json, indent=2)
|
||||
|
||||
|
||||
def _get_context():
|
||||
return getattr(flask.globals.request_ctx, "connexion_context")
|
||||
|
||||
|
||||
context = LocalProxy(_get_context)
|
||||
|
||||
@@ -7,6 +7,7 @@ import pathlib
|
||||
from types import FunctionType # NOQA
|
||||
|
||||
import a2wsgi
|
||||
import asgiref.wsgi
|
||||
import flask
|
||||
import werkzeug.exceptions
|
||||
from flask import signals
|
||||
|
||||
9
connexion/context.py
Normal file
9
connexion/context.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from asyncio import AbstractEventLoop
|
||||
from contextvars import ContextVar
|
||||
|
||||
_context: ContextVar[AbstractEventLoop] = ContextVar("CONTEXT")
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name == "context":
|
||||
return _context.get()
|
||||
@@ -22,7 +22,7 @@ class RequestResponseDecorator:
|
||||
self.api = api
|
||||
self.mimetype = mimetype
|
||||
|
||||
def __call__(self, function):
|
||||
def __call__(self, function, uri_parser):
|
||||
"""
|
||||
:type function: types.FunctionType
|
||||
:rtype: types.FunctionType
|
||||
@@ -31,7 +31,7 @@ class RequestResponseDecorator:
|
||||
|
||||
@functools.wraps(function)
|
||||
async def wrapper(*args, **kwargs):
|
||||
connexion_request = self.api.get_request(*args, **kwargs)
|
||||
connexion_request = self.api.get_request(uri_parser=uri_parser)
|
||||
while asyncio.iscoroutine(connexion_request):
|
||||
connexion_request = await connexion_request
|
||||
|
||||
@@ -40,7 +40,7 @@ class RequestResponseDecorator:
|
||||
connexion_response = await connexion_response
|
||||
|
||||
framework_response = self.api.get_response(
|
||||
connexion_response, self.mimetype, connexion_request
|
||||
connexion_response, self.mimetype
|
||||
)
|
||||
while asyncio.iscoroutine(framework_response):
|
||||
framework_response = await framework_response
|
||||
@@ -51,8 +51,8 @@ class RequestResponseDecorator:
|
||||
|
||||
@functools.wraps(function)
|
||||
def wrapper(*args, **kwargs):
|
||||
request = self.api.get_request(*args, **kwargs)
|
||||
request = self.api.get_request(uri_parser)
|
||||
response = function(request)
|
||||
return self.api.get_response(response, self.mimetype, request)
|
||||
return self.api.get_response(response, self.mimetype)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -1,112 +1,69 @@
|
||||
"""
|
||||
This module defines a decorator to convert request parameters to arguments for the view function.
|
||||
This module defines a utility functions to convert request parameters to arguments for the view
|
||||
function.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import functools
|
||||
import inspect
|
||||
import keyword
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
import typing as t
|
||||
from copy import copy, deepcopy
|
||||
|
||||
import inflection
|
||||
|
||||
from ..http_facts import FORM_CONTENT_TYPES
|
||||
from ..lifecycle import ConnexionRequest # NOQA
|
||||
from ..utils import all_json
|
||||
from connexion.http_facts import FORM_CONTENT_TYPES
|
||||
from connexion.lifecycle import ConnexionRequest
|
||||
from connexion.operations import AbstractOperation, Swagger2Operation
|
||||
from connexion.utils import (
|
||||
deep_merge,
|
||||
is_json_mimetype,
|
||||
is_null,
|
||||
is_nullable,
|
||||
make_type,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CONTEXT_NAME = "context_"
|
||||
|
||||
|
||||
def inspect_function_arguments(function): # pragma: no cover
|
||||
"""
|
||||
Returns the list of variables names of a function and if it
|
||||
accepts keyword arguments.
|
||||
|
||||
:type function: Callable
|
||||
:rtype: tuple[list[str], bool]
|
||||
"""
|
||||
parameters = inspect.signature(function).parameters
|
||||
bound_arguments = [
|
||||
name
|
||||
for name, p in parameters.items()
|
||||
if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
|
||||
]
|
||||
has_kwargs = any(p.kind == p.VAR_KEYWORD for p in parameters.values())
|
||||
return list(bound_arguments), has_kwargs
|
||||
|
||||
|
||||
def snake_and_shadow(name):
|
||||
"""
|
||||
Converts the given name into Pythonic form. Firstly it converts CamelCase names to snake_case. Secondly it looks to
|
||||
see if the name matches a known built-in and if it does it appends an underscore to the name.
|
||||
:param name: The parameter name
|
||||
:type name: str
|
||||
:return:
|
||||
"""
|
||||
snake = inflection.underscore(name)
|
||||
if snake in builtins.__dict__ or keyword.iskeyword(snake):
|
||||
return f"{snake}_"
|
||||
return snake
|
||||
|
||||
|
||||
def sanitized(name):
|
||||
return name and re.sub(
|
||||
"^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", re.sub(r"\[(?!])", "_", name))
|
||||
)
|
||||
|
||||
|
||||
def pythonic(name):
|
||||
name = name and snake_and_shadow(name)
|
||||
return sanitized(name)
|
||||
|
||||
|
||||
def parameter_to_arg(operation, function, pythonic_params=False):
|
||||
"""
|
||||
Pass query and body parameters as keyword arguments to handler function.
|
||||
|
||||
See (https://github.com/zalando/connexion/issues/59)
|
||||
:param operation: The operation being called
|
||||
:type operation: connexion.operations.AbstractOperation
|
||||
:param pythonic_params: When True CamelCase parameters are converted to snake_case and an underscore is appended to
|
||||
any shadowed built-ins
|
||||
:type pythonic_params: bool
|
||||
"""
|
||||
consumes = operation.consumes
|
||||
|
||||
def parameter_to_arg(
|
||||
operation: AbstractOperation,
|
||||
function: t.Callable,
|
||||
pythonic_params: bool = False,
|
||||
) -> t.Callable[[ConnexionRequest], t.Any]:
|
||||
sanitize = pythonic if pythonic_params else sanitized
|
||||
arguments, has_kwargs = inspect_function_arguments(function)
|
||||
|
||||
@functools.wraps(function)
|
||||
def wrapper(request):
|
||||
# type: (ConnexionRequest) -> Any
|
||||
logger.debug("Function Arguments: %s", arguments)
|
||||
def wrapper(request: ConnexionRequest) -> t.Any:
|
||||
kwargs = {}
|
||||
|
||||
if all_json(consumes):
|
||||
request_body = request.json
|
||||
elif consumes[0] in FORM_CONTENT_TYPES:
|
||||
request_body = request.form
|
||||
body_name = sanitize(operation.body_name(request.content_type))
|
||||
if body_name in arguments or has_kwargs:
|
||||
request_body = get_body(request)
|
||||
# Pass form contents separately for Swagger2 for backward compatibility with Connexion 2
|
||||
# Checking for body_name is not enough
|
||||
elif request.mimetype in FORM_CONTENT_TYPES and isinstance(
|
||||
operation, Swagger2Operation
|
||||
):
|
||||
request_body = get_body(request)
|
||||
else:
|
||||
request_body = request.body
|
||||
|
||||
try:
|
||||
query = request.query.to_dict(flat=False)
|
||||
except AttributeError:
|
||||
query = dict(request.query.items())
|
||||
request_body = None
|
||||
|
||||
kwargs.update(
|
||||
operation.get_arguments(
|
||||
request.path_params,
|
||||
query,
|
||||
request_body,
|
||||
request.files,
|
||||
arguments,
|
||||
has_kwargs,
|
||||
sanitize,
|
||||
get_arguments(
|
||||
operation,
|
||||
path_params=request.view_args,
|
||||
query_params=request.args,
|
||||
body=request_body,
|
||||
files=request.files,
|
||||
arguments=arguments,
|
||||
has_kwargs=has_kwargs,
|
||||
sanitize=sanitize,
|
||||
content_type=request.content_type,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -127,3 +84,348 @@ def parameter_to_arg(operation, function, pythonic_params=False):
|
||||
return function(**kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_body(request: ConnexionRequest) -> t.Any:
|
||||
"""Get body from the request based on the content type."""
|
||||
if is_json_mimetype(request.content_type):
|
||||
return request.get_json(silent=True)
|
||||
elif request.mimetype in FORM_CONTENT_TYPES:
|
||||
return request.form
|
||||
else:
|
||||
# Return explicit None instead of empty bytestring so it is handled as null downstream
|
||||
return request.get_data() or None
|
||||
|
||||
|
||||
def inspect_function_arguments(function: t.Callable) -> t.Tuple[t.List[str], bool]:
|
||||
"""
|
||||
Returns the list of variables names of a function and if it
|
||||
accepts keyword arguments.
|
||||
"""
|
||||
parameters = inspect.signature(function).parameters
|
||||
bound_arguments = [
|
||||
name
|
||||
for name, p in parameters.items()
|
||||
if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
|
||||
]
|
||||
has_kwargs = any(p.kind == p.VAR_KEYWORD for p in parameters.values())
|
||||
return list(bound_arguments), has_kwargs
|
||||
|
||||
|
||||
def snake_and_shadow(name: str) -> str:
|
||||
"""
|
||||
Converts the given name into Pythonic form. Firstly it converts CamelCase names to snake_case. Secondly it looks to
|
||||
see if the name matches a known built-in and if it does it appends an underscore to the name.
|
||||
:param name: The parameter name
|
||||
"""
|
||||
snake = inflection.underscore(name)
|
||||
if snake in builtins.__dict__ or keyword.iskeyword(snake):
|
||||
return f"{snake}_"
|
||||
return snake
|
||||
|
||||
|
||||
def sanitized(name: str) -> str:
|
||||
return name and re.sub(
|
||||
"^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", re.sub(r"\[(?!])", "_", name))
|
||||
)
|
||||
|
||||
|
||||
def pythonic(name: str) -> str:
|
||||
name = name and snake_and_shadow(name)
|
||||
return sanitized(name)
|
||||
|
||||
|
||||
def get_arguments(
|
||||
operation: AbstractOperation,
|
||||
*,
|
||||
path_params: dict,
|
||||
query_params: dict,
|
||||
body: t.Any,
|
||||
files: dict,
|
||||
arguments: t.List[str],
|
||||
has_kwargs: bool,
|
||||
sanitize: t.Callable,
|
||||
content_type: str,
|
||||
) -> t.Dict[str, t.Any]:
|
||||
"""
|
||||
get arguments for handler function
|
||||
"""
|
||||
ret = {}
|
||||
ret.update(_get_path_arguments(path_params, operation=operation, sanitize=sanitize))
|
||||
ret.update(
|
||||
_get_query_arguments(
|
||||
query_params,
|
||||
operation=operation,
|
||||
arguments=arguments,
|
||||
has_kwargs=has_kwargs,
|
||||
sanitize=sanitize,
|
||||
)
|
||||
)
|
||||
|
||||
if operation.method.upper() in ["PATCH", "POST", "PUT"]:
|
||||
ret.update(
|
||||
_get_body_argument(
|
||||
body,
|
||||
operation=operation,
|
||||
arguments=arguments,
|
||||
has_kwargs=has_kwargs,
|
||||
sanitize=sanitize,
|
||||
content_type=content_type,
|
||||
)
|
||||
)
|
||||
ret.update(_get_file_arguments(files, arguments, has_kwargs))
|
||||
return ret
|
||||
|
||||
|
||||
def _get_path_arguments(
|
||||
path_params: dict, *, operation: AbstractOperation, sanitize: t.Callable
|
||||
) -> dict:
|
||||
"""
|
||||
Extract handler function arguments from path parameters
|
||||
"""
|
||||
kwargs = {}
|
||||
|
||||
path_definitions = {
|
||||
parameter["name"]: parameter
|
||||
for parameter in operation.parameters
|
||||
if parameter["in"] == "path"
|
||||
}
|
||||
|
||||
for name, value in path_params.items():
|
||||
sanitized_key = sanitize(name)
|
||||
if name in path_definitions:
|
||||
kwargs[sanitized_key] = _get_val_from_param(value, path_definitions[name])
|
||||
else: # Assume path params mechanism used for injection
|
||||
kwargs[sanitized_key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def _get_val_from_param(value: t.Any, param_definitions: t.Dict[str, dict]) -> t.Any:
|
||||
"""Cast a value according to its definition in the specification."""
|
||||
param_schema = param_definitions.get("schema", param_definitions)
|
||||
|
||||
if is_nullable(param_schema) and is_null(value):
|
||||
return None
|
||||
|
||||
if param_schema["type"] == "array":
|
||||
type_ = param_schema["items"]["type"]
|
||||
format_ = param_schema["items"].get("format")
|
||||
return [make_type(part, type_, format_) for part in value]
|
||||
else:
|
||||
type_ = param_schema["type"]
|
||||
format_ = param_schema.get("format")
|
||||
return make_type(value, type_, format_)
|
||||
|
||||
|
||||
def _get_query_arguments(
|
||||
query_params: dict,
|
||||
*,
|
||||
operation: AbstractOperation,
|
||||
arguments: t.List[str],
|
||||
has_kwargs: bool,
|
||||
sanitize: t.Callable,
|
||||
) -> dict:
|
||||
"""
|
||||
extract handler function arguments from the query parameters
|
||||
"""
|
||||
query_definitions = {
|
||||
parameter["name"]: parameter
|
||||
for parameter in operation.parameters
|
||||
if parameter["in"] == "query"
|
||||
}
|
||||
|
||||
default_query_params = _get_query_defaults(query_definitions)
|
||||
|
||||
query_arguments = deepcopy(default_query_params)
|
||||
query_arguments = deep_merge(query_arguments, query_params)
|
||||
return _query_args_helper(
|
||||
query_definitions, query_arguments, arguments, has_kwargs, sanitize
|
||||
)
|
||||
|
||||
|
||||
def _get_query_defaults(query_definitions: t.Dict[str, dict]) -> t.Dict[str, t.Any]:
|
||||
"""Get the default values for the query parameter from the parameter definition."""
|
||||
defaults = {}
|
||||
for k, v in query_definitions.items():
|
||||
try:
|
||||
if "default" in v:
|
||||
defaults[k] = v["default"]
|
||||
elif v["schema"]["type"] == "object":
|
||||
defaults[k] = _get_default_obj(v["schema"])
|
||||
else:
|
||||
defaults[k] = v["schema"]["default"]
|
||||
except KeyError:
|
||||
pass
|
||||
return defaults
|
||||
|
||||
|
||||
def _get_default_obj(schema: dict) -> dict:
|
||||
try:
|
||||
return deepcopy(schema["default"])
|
||||
except KeyError:
|
||||
properties = schema.get("properties", {})
|
||||
return _build_default_obj_recursive(properties, {})
|
||||
|
||||
|
||||
def _build_default_obj_recursive(properties: dict, default_object: dict) -> dict:
|
||||
"""takes disparate and nested default keys, and builds up a default object"""
|
||||
for name, property_ in properties.items():
|
||||
if "default" in property_ and name not in default_object:
|
||||
default_object[name] = copy(property_["default"])
|
||||
elif property_.get("type") == "object" and "properties" in property_:
|
||||
default_object.setdefault(name, {})
|
||||
default_object[name] = _build_default_obj_recursive(
|
||||
property_["properties"], default_object[name]
|
||||
)
|
||||
return default_object
|
||||
|
||||
|
||||
def _query_args_helper(
|
||||
query_definitions: dict,
|
||||
query_arguments: dict,
|
||||
function_arguments: t.List[str],
|
||||
has_kwargs: bool,
|
||||
sanitize: t.Callable,
|
||||
) -> dict:
|
||||
result = {}
|
||||
for key, value in query_arguments.items():
|
||||
sanitized_key = sanitize(key)
|
||||
if not has_kwargs and sanitized_key not in function_arguments:
|
||||
logger.debug(
|
||||
"Query Parameter '%s' (sanitized: '%s') not in function arguments",
|
||||
key,
|
||||
sanitized_key,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Query Parameter '%s' (sanitized: '%s') in function arguments",
|
||||
key,
|
||||
sanitized_key,
|
||||
)
|
||||
try:
|
||||
query_defn = query_definitions[key]
|
||||
except KeyError: # pragma: no cover
|
||||
logger.error(
|
||||
"Function argument '%s' (non-sanitized: %s) not defined in specification",
|
||||
sanitized_key,
|
||||
key,
|
||||
)
|
||||
else:
|
||||
logger.debug("%s is a %s", key, query_defn)
|
||||
result.update({sanitized_key: _get_val_from_param(value, query_defn)})
|
||||
return result
|
||||
|
||||
|
||||
def _get_body_argument(
|
||||
body: t.Any,
|
||||
*,
|
||||
operation: AbstractOperation,
|
||||
arguments: t.List[str],
|
||||
has_kwargs: bool,
|
||||
sanitize: t.Callable,
|
||||
content_type: str,
|
||||
) -> dict:
|
||||
if len(arguments) <= 0 and not has_kwargs:
|
||||
return {}
|
||||
|
||||
body_name = sanitize(operation.body_name(content_type))
|
||||
|
||||
if content_type in FORM_CONTENT_TYPES:
|
||||
result = _get_body_argument_form(
|
||||
body, operation=operation, content_type=content_type
|
||||
)
|
||||
|
||||
# Unpack form values for Swagger for compatibility with Connexion 2 behavior
|
||||
if content_type in FORM_CONTENT_TYPES and isinstance(
|
||||
operation, Swagger2Operation
|
||||
):
|
||||
if has_kwargs:
|
||||
return result
|
||||
else:
|
||||
return {
|
||||
sanitize(name): value
|
||||
for name, value in result.items()
|
||||
if sanitize(name) in arguments
|
||||
}
|
||||
else:
|
||||
result = _get_body_argument_json(
|
||||
body, operation=operation, content_type=content_type
|
||||
)
|
||||
|
||||
if body_name in arguments or has_kwargs:
|
||||
return {body_name: result}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _get_body_argument_json(
|
||||
body: t.Any, *, operation: AbstractOperation, content_type: str
|
||||
) -> t.Any:
|
||||
# if the body came in null, and the schema says it can be null, we decide
|
||||
# to include no value for the body argument, rather than the default body
|
||||
if is_nullable(operation.body_schema(content_type)) and is_null(body):
|
||||
return None
|
||||
|
||||
if body is None:
|
||||
default_body = operation.body_schema(content_type).get("default", {})
|
||||
return deepcopy(default_body)
|
||||
|
||||
return body
|
||||
|
||||
|
||||
def _get_body_argument_form(
|
||||
body: dict, *, operation: AbstractOperation, content_type: str
|
||||
) -> dict:
|
||||
# now determine the actual value for the body (whether it came in or is default)
|
||||
default_body = operation.body_schema(content_type).get("default", {})
|
||||
body_props = {
|
||||
k: {"schema": v}
|
||||
for k, v in operation.body_schema(content_type).get("properties", {}).items()
|
||||
}
|
||||
|
||||
# by OpenAPI specification `additionalProperties` defaults to `true`
|
||||
# see: https://github.com/OAI/OpenAPI-Specification/blame/3.0.2/versions/3.0.2.md#L2305
|
||||
additional_props = operation.body_schema().get("additionalProperties", True)
|
||||
|
||||
body_arg = deepcopy(default_body)
|
||||
body_arg.update(body or {})
|
||||
|
||||
if body_props or additional_props:
|
||||
return _get_typed_body_values(body_arg, body_props, additional_props)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _get_typed_body_values(body_arg, body_props, additional_props):
|
||||
"""
|
||||
Return a copy of the provided body_arg dictionary
|
||||
whose values will have the appropriate types
|
||||
as defined in the provided schemas.
|
||||
|
||||
:type body_arg: type dict
|
||||
:type body_props: dict
|
||||
:type additional_props: dict|bool
|
||||
:rtype: dict
|
||||
"""
|
||||
additional_props_defn = (
|
||||
{"schema": additional_props} if isinstance(additional_props, dict) else None
|
||||
)
|
||||
res = {}
|
||||
|
||||
for key, value in body_arg.items():
|
||||
try:
|
||||
prop_defn = body_props[key]
|
||||
res[key] = _get_val_from_param(value, prop_defn)
|
||||
except KeyError: # pragma: no cover
|
||||
if not additional_props:
|
||||
logger.error(f"Body property '{key}' not defined in body schema")
|
||||
continue
|
||||
if additional_props_defn is not None:
|
||||
value = _get_val_from_param(value, additional_props_defn)
|
||||
res[key] = value
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _get_file_arguments(files, arguments, has_kwargs=False):
|
||||
return {k: v for k, v in files.items() if k in arguments or has_kwargs}
|
||||
|
||||
@@ -2,46 +2,13 @@
|
||||
This module defines interfaces for requests and responses used in Connexion for authentication,
|
||||
validation, serialization, etc.
|
||||
"""
|
||||
import typing as t
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import StreamingResponse as StarletteStreamingResponse
|
||||
|
||||
|
||||
class ConnexionRequest:
|
||||
"""Connexion interface for a request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url,
|
||||
method,
|
||||
path_params=None,
|
||||
query=None,
|
||||
headers=None,
|
||||
form=None,
|
||||
body=None,
|
||||
json_getter=None,
|
||||
files=None,
|
||||
context=None,
|
||||
cookies=None,
|
||||
):
|
||||
self.url = url
|
||||
self.method = method
|
||||
self.path_params = path_params or {}
|
||||
self.query = query or {}
|
||||
self.headers = headers or {}
|
||||
self.form = form or {}
|
||||
self.body = body
|
||||
self.json_getter = json_getter
|
||||
self.files = files
|
||||
self.context = context if context is not None else {}
|
||||
self.cookies = cookies or {}
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
if not hasattr(self, "_json"):
|
||||
self._json = self.json_getter()
|
||||
return self._json
|
||||
|
||||
|
||||
class ConnexionResponse:
|
||||
"""Connexion interface for a response."""
|
||||
|
||||
@@ -62,6 +29,41 @@ class ConnexionResponse:
|
||||
self.is_streamed = is_streamed
|
||||
|
||||
|
||||
class ConnexionRequest:
|
||||
def __init__(self, flask_request: FlaskRequest, uri_parser=None):
|
||||
self._flask_request = flask_request
|
||||
self.uri_parser = uri_parser
|
||||
self._context = None
|
||||
|
||||
@property
|
||||
def context(self):
|
||||
if self._context is None:
|
||||
scope = self._flask_request.environ["asgi.scope"]
|
||||
extensions = scope.setdefault("extensions", {})
|
||||
self._context = extensions.setdefault("connexion_context", {})
|
||||
|
||||
return self._context
|
||||
|
||||
@property
|
||||
def view_args(self) -> t.Dict[str, t.Any]:
|
||||
return self.uri_parser.resolve_path(self._flask_request.view_args)
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
query_params = self._flask_request.args
|
||||
query_params = {k: query_params.getlist(k) for k in query_params}
|
||||
return self.uri_parser.resolve_query(query_params)
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
form = self._flask_request.form.to_dict(flat=False)
|
||||
form_data = self.uri_parser.resolve_form(form)
|
||||
return form_data
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._flask_request, item)
|
||||
|
||||
|
||||
class MiddlewareRequest(StarletteRequest):
|
||||
"""Wraps starlette Request so it can easily be extended."""
|
||||
|
||||
|
||||
15
connexion/middleware/context.py
Normal file
15
connexion/middleware/context.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""The ContextMiddleware creates a global context based the scope. It should be last in the
|
||||
middleware stack, so it exposes the scope passed to the application"""
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.context import _context
|
||||
|
||||
|
||||
class ContextMiddleware:
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
context = scope.get("extensions", {}).get("connexion_context", {})
|
||||
_context.set(context)
|
||||
await self.app(scope, receive, send)
|
||||
@@ -4,6 +4,7 @@ import typing as t
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.middleware.abstract import AppMiddleware
|
||||
from connexion.middleware.context import ContextMiddleware
|
||||
from connexion.middleware.exceptions import ExceptionMiddleware
|
||||
from connexion.middleware.request_validation import RequestValidationMiddleware
|
||||
from connexion.middleware.response_validation import ResponseValidationMiddleware
|
||||
@@ -21,6 +22,7 @@ class ConnexionMiddleware:
|
||||
SecurityMiddleware,
|
||||
RequestValidationMiddleware,
|
||||
ResponseValidationMiddleware,
|
||||
ContextMiddleware,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import functools
|
||||
import pathlib
|
||||
import re
|
||||
import typing as t
|
||||
from contextvars import ContextVar
|
||||
|
||||
import starlette.convertors
|
||||
from starlette.routing import Router
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
@@ -53,7 +56,7 @@ class RoutingAPI(AbstractRoutingAPI):
|
||||
resolver: t.Optional[Resolver] = None,
|
||||
resolver_error_handler: t.Optional[t.Callable] = None,
|
||||
debug: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""API implementation on top of Starlette Router for Connexion middleware."""
|
||||
self.next_app = next_app
|
||||
@@ -76,6 +79,8 @@ class RoutingAPI(AbstractRoutingAPI):
|
||||
routing_operation = RoutingOperation.from_operation(
|
||||
operation, next_app=self.next_app
|
||||
)
|
||||
types = operation.get_path_parameter_types()
|
||||
path = starlettify_path(path, types)
|
||||
self._add_operation_internal(method, path, routing_operation)
|
||||
|
||||
def _add_operation_internal(
|
||||
@@ -94,13 +99,15 @@ class RoutingMiddleware(AppMiddleware):
|
||||
self.app = app
|
||||
# Pass unknown routes to next app
|
||||
self.router = Router(default=RoutingOperation(None, self.app))
|
||||
starlette.convertors.register_url_convertor("float", FloatConverter())
|
||||
starlette.convertors.register_url_convertor("int", IntegerConverter())
|
||||
|
||||
def add_api(
|
||||
self,
|
||||
specification: t.Union[pathlib.Path, str, dict],
|
||||
base_path: t.Optional[str] = None,
|
||||
arguments: t.Optional[dict] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Add an API to the router based on a OpenAPI spec.
|
||||
|
||||
@@ -113,7 +120,7 @@ class RoutingMiddleware(AppMiddleware):
|
||||
base_path=base_path,
|
||||
arguments=arguments,
|
||||
next_app=self.app,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
self.router.mount(api.base_path, app=api.router)
|
||||
|
||||
@@ -129,3 +136,46 @@ class RoutingMiddleware(AppMiddleware):
|
||||
# Needs to be set so starlette router throws exceptions instead of returning error responses
|
||||
scope["app"] = self
|
||||
await self.router(scope, receive, send)
|
||||
|
||||
|
||||
PATH_PARAMETER = re.compile(r"\{([^}]*)\}")
|
||||
PATH_PARAMETER_CONVERTERS = {"integer": "int", "number": "float", "path": "path"}
|
||||
|
||||
|
||||
def convert_path_parameter(match, types):
|
||||
name = match.group(1)
|
||||
swagger_type = types.get(name)
|
||||
converter = PATH_PARAMETER_CONVERTERS.get(swagger_type)
|
||||
return f'{{{name.replace("-", "_")}{":" if converter else ""}{converter or ""}}}'
|
||||
|
||||
|
||||
def starlettify_path(swagger_path, types=None):
|
||||
"""
|
||||
Convert swagger path templates to flask path templates
|
||||
|
||||
:type swagger_path: str
|
||||
:type types: dict
|
||||
:rtype: str
|
||||
|
||||
>>> starlettify_path('/foo-bar/{my-param}')
|
||||
'/foo-bar/{my_param}'
|
||||
|
||||
>>> starlettify_path('/foo/{someint}', {'someint': 'int'})
|
||||
'/foo/{someint:int}'
|
||||
"""
|
||||
if types is None:
|
||||
types = {}
|
||||
convert_match = functools.partial(convert_path_parameter, types=types)
|
||||
return PATH_PARAMETER.sub(convert_match, swagger_path)
|
||||
|
||||
|
||||
class FloatConverter(starlette.convertors.FloatConvertor):
|
||||
"""Starlette converter for OpenAPI number type"""
|
||||
|
||||
regex = r"[+-]?[0-9]*(\.[0-9]*)?"
|
||||
|
||||
|
||||
class IntegerConverter(starlette.convertors.IntegerConvertor):
|
||||
"""Starlette converter for OpenAPI integer type"""
|
||||
|
||||
regex = r"[+-]?[0-9]+"
|
||||
|
||||
@@ -6,9 +6,8 @@ and functionality shared between Swagger 2 and OpenAPI 3 specifications.
|
||||
import abc
|
||||
import logging
|
||||
|
||||
from ..decorators.lifecycle import RequestResponseDecorator
|
||||
from ..decorators.parameter import parameter_to_arg
|
||||
from ..utils import all_json
|
||||
from connexion.decorators.lifecycle import RequestResponseDecorator
|
||||
from connexion.utils import all_json
|
||||
|
||||
logger = logging.getLogger("connexion.operations.abstract")
|
||||
|
||||
@@ -46,6 +45,7 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
randomize_endpoint=None,
|
||||
pythonic_params=False,
|
||||
uri_parser_class=None,
|
||||
parameter_to_arg=None,
|
||||
):
|
||||
"""
|
||||
:param api: api that this operation is attached to
|
||||
@@ -87,6 +87,8 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
|
||||
self._responses = self._operation.get("responses", {})
|
||||
|
||||
self.parameter_to_arg = parameter_to_arg
|
||||
|
||||
@property
|
||||
def api(self):
|
||||
return self._api
|
||||
@@ -148,75 +150,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
return self._pythonic_params
|
||||
|
||||
@staticmethod
|
||||
def _get_file_arguments(files, arguments, has_kwargs=False):
|
||||
return {k: v for k, v in files.items() if k in arguments or has_kwargs}
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_val_from_param(self, value, query_defn):
|
||||
"""
|
||||
Convert input parameters into the correct type
|
||||
"""
|
||||
|
||||
def _query_args_helper(
|
||||
self, query_defns, query_arguments, function_arguments, has_kwargs, sanitize
|
||||
):
|
||||
res = {}
|
||||
for key, value in query_arguments.items():
|
||||
sanitized_key = sanitize(key)
|
||||
if not has_kwargs and sanitized_key not in function_arguments:
|
||||
logger.debug(
|
||||
"Query Parameter '%s' (sanitized: '%s') not in function arguments",
|
||||
key,
|
||||
sanitized_key,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Query Parameter '%s' (sanitized: '%s') in function arguments",
|
||||
key,
|
||||
sanitized_key,
|
||||
)
|
||||
try:
|
||||
query_defn = query_defns[key]
|
||||
except KeyError: # pragma: no cover
|
||||
logger.error(
|
||||
"Function argument '%s' (non-sanitized: %s) not defined in specification",
|
||||
sanitized_key,
|
||||
key,
|
||||
)
|
||||
else:
|
||||
logger.debug("%s is a %s", key, query_defn)
|
||||
res.update(
|
||||
{sanitized_key: self._get_val_from_param(value, query_defn)}
|
||||
)
|
||||
return res
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_query_arguments(self, query, arguments, has_kwargs, sanitize):
|
||||
"""
|
||||
extract handler function arguments from the query parameters
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_body_argument(self, body, arguments, has_kwargs, sanitize):
|
||||
"""
|
||||
extract handler function arguments from the request body
|
||||
"""
|
||||
|
||||
def _get_path_arguments(self, path_params, sanitize):
|
||||
"""
|
||||
extract handler function arguments from path parameters
|
||||
"""
|
||||
kwargs = {}
|
||||
path_defns = {p["name"]: p for p in self.parameters if p["in"] == "path"}
|
||||
for key, value in path_params.items():
|
||||
sanitized_key = sanitize(key)
|
||||
if key in path_defns:
|
||||
kwargs[sanitized_key] = self._get_val_from_param(value, path_defns[key])
|
||||
else: # Assume path params mechanism used for injection
|
||||
kwargs[sanitized_key] = value
|
||||
return kwargs
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def parameters(self):
|
||||
@@ -238,6 +171,12 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
Content-Types that the operation consumes
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def body_name(self, content_type: str) -> str:
|
||||
"""
|
||||
Name of the body in the spec.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def body_schema(self, content_type: str = None) -> dict:
|
||||
"""
|
||||
@@ -251,23 +190,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
def get_arguments(
|
||||
self, path_params, query_params, body, files, arguments, has_kwargs, sanitize
|
||||
):
|
||||
"""
|
||||
get arguments for handler function
|
||||
"""
|
||||
ret = {}
|
||||
ret.update(self._get_path_arguments(path_params, sanitize))
|
||||
ret.update(
|
||||
self._get_query_arguments(query_params, arguments, has_kwargs, sanitize)
|
||||
)
|
||||
|
||||
if self.method.upper() in ["PATCH", "POST", "PUT"]:
|
||||
ret.update(self._get_body_argument(body, arguments, has_kwargs, sanitize))
|
||||
ret.update(self._get_file_arguments(files, arguments, has_kwargs))
|
||||
return ret
|
||||
|
||||
def response_definition(self, status_code=None, content_type=None):
|
||||
"""
|
||||
response definition for this endpoint
|
||||
@@ -335,16 +257,18 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
function = parameter_to_arg(
|
||||
function = self._resolution.function
|
||||
|
||||
if self.parameter_to_arg:
|
||||
function = self.parameter_to_arg(
|
||||
self,
|
||||
self._resolution.function,
|
||||
function,
|
||||
self.pythonic_params,
|
||||
)
|
||||
|
||||
uri_parsing_decorator = self._uri_parsing_decorator
|
||||
function = uri_parsing_decorator(function)
|
||||
|
||||
function = self._request_response_decorator(function)
|
||||
function = self._request_response_decorator(
|
||||
function, self._uri_parsing_decorator
|
||||
)
|
||||
|
||||
return function
|
||||
|
||||
|
||||
@@ -3,14 +3,11 @@ This module defines an OpenAPIOperation class, a Connexion operation specific fo
|
||||
"""
|
||||
|
||||
import logging
|
||||
from copy import copy, deepcopy
|
||||
|
||||
from connexion.datastructures import MediaTypeDict
|
||||
from connexion.operations.abstract import AbstractOperation
|
||||
|
||||
from ..decorators.uri_parsing import OpenAPIURIParser
|
||||
from ..http_facts import FORM_CONTENT_TYPES
|
||||
from ..utils import deep_get, deep_merge, is_null, is_nullable, make_type
|
||||
from connexion.uri_parsing import OpenAPIURIParser
|
||||
from connexion.utils import deep_get
|
||||
|
||||
logger = logging.getLogger("connexion.operations.openapi3")
|
||||
|
||||
@@ -35,6 +32,7 @@ class OpenAPIOperation(AbstractOperation):
|
||||
randomize_endpoint=None,
|
||||
pythonic_params=False,
|
||||
uri_parser_class=None,
|
||||
parameter_to_arg=None,
|
||||
):
|
||||
"""
|
||||
This class uses the OperationID identify the module and function that will handle the operation
|
||||
@@ -88,6 +86,7 @@ class OpenAPIOperation(AbstractOperation):
|
||||
randomize_endpoint=randomize_endpoint,
|
||||
pythonic_params=pythonic_params,
|
||||
uri_parser_class=uri_parser_class,
|
||||
parameter_to_arg=parameter_to_arg,
|
||||
)
|
||||
|
||||
self._request_body = operation.get("requestBody", {})
|
||||
@@ -144,8 +143,9 @@ class OpenAPIOperation(AbstractOperation):
|
||||
def produces(self):
|
||||
return self._produces
|
||||
|
||||
def with_definitions(self, schema):
|
||||
def with_definitions(self, schema: dict):
|
||||
if self.components:
|
||||
schema.setdefault("schema", {})
|
||||
schema["schema"]["components"] = self.components
|
||||
return schema
|
||||
|
||||
@@ -240,6 +240,9 @@ class OpenAPIOperation(AbstractOperation):
|
||||
types[path_defn["name"]] = path_type
|
||||
return types
|
||||
|
||||
def body_name(self, _content_type: str) -> str:
|
||||
return self.request_body.get("x-body-name", "body")
|
||||
|
||||
def body_schema(self, content_type: str = None) -> dict:
|
||||
"""
|
||||
The body schema definition for this operation.
|
||||
@@ -267,131 +270,3 @@ class OpenAPIOperation(AbstractOperation):
|
||||
res = content_type_dict.get(content_type, {})
|
||||
return self.with_definitions(res)
|
||||
return {}
|
||||
|
||||
def _get_body_argument(self, body, arguments, has_kwargs, sanitize):
|
||||
if len(arguments) <= 0 and not has_kwargs:
|
||||
return {}
|
||||
|
||||
x_body_name = sanitize(self.request_body.get("x-body-name", "body"))
|
||||
|
||||
if self.consumes[0] in FORM_CONTENT_TYPES:
|
||||
result = self._get_body_argument_form(body)
|
||||
else:
|
||||
result = self._get_body_argument_json(body)
|
||||
|
||||
if x_body_name in arguments or has_kwargs:
|
||||
return {x_body_name: result}
|
||||
return {}
|
||||
|
||||
def _get_body_argument_json(self, body):
|
||||
# if the body came in null, and the schema says it can be null, we decide
|
||||
# to include no value for the body argument, rather than the default body
|
||||
if is_nullable(self.body_schema()) and is_null(body):
|
||||
return None
|
||||
|
||||
if body is None:
|
||||
default_body = self.body_schema().get("default", {})
|
||||
return deepcopy(default_body)
|
||||
|
||||
return body
|
||||
|
||||
def _get_body_argument_form(self, body):
|
||||
# now determine the actual value for the body (whether it came in or is default)
|
||||
default_body = self.body_schema().get("default", {})
|
||||
body_props = {
|
||||
k: {"schema": v}
|
||||
for k, v in self.body_schema().get("properties", {}).items()
|
||||
}
|
||||
|
||||
# by OpenAPI specification `additionalProperties` defaults to `true`
|
||||
# see: https://github.com/OAI/OpenAPI-Specification/blame/3.0.2/versions/3.0.2.md#L2305
|
||||
additional_props = self.body_schema().get("additionalProperties", True)
|
||||
|
||||
body_arg = deepcopy(default_body)
|
||||
body_arg.update(body or {})
|
||||
|
||||
if body_props or additional_props:
|
||||
return self._get_typed_body_values(body_arg, body_props, additional_props)
|
||||
return {}
|
||||
|
||||
def _get_typed_body_values(self, body_arg, body_props, additional_props):
|
||||
"""
|
||||
Return a copy of the provided body_arg dictionary
|
||||
whose values will have the appropriate types
|
||||
as defined in the provided schemas.
|
||||
|
||||
:type body_arg: type dict
|
||||
:type body_props: dict
|
||||
:type additional_props: dict|bool
|
||||
:rtype: dict
|
||||
"""
|
||||
additional_props_defn = (
|
||||
{"schema": additional_props} if isinstance(additional_props, dict) else None
|
||||
)
|
||||
res = {}
|
||||
|
||||
for key, value in body_arg.items():
|
||||
try:
|
||||
prop_defn = body_props[key]
|
||||
res[key] = self._get_val_from_param(value, prop_defn)
|
||||
except KeyError: # pragma: no cover
|
||||
if not additional_props:
|
||||
logger.error(f"Body property '{key}' not defined in body schema")
|
||||
continue
|
||||
if additional_props_defn is not None:
|
||||
value = self._get_val_from_param(value, additional_props_defn)
|
||||
res[key] = value
|
||||
|
||||
return res
|
||||
|
||||
def _build_default_obj_recursive(self, _properties, res):
|
||||
"""takes disparate and nested default keys, and builds up a default object"""
|
||||
for key, prop in _properties.items():
|
||||
if "default" in prop and key not in res:
|
||||
res[key] = copy(prop["default"])
|
||||
elif prop.get("type") == "object" and "properties" in prop:
|
||||
res.setdefault(key, {})
|
||||
res[key] = self._build_default_obj_recursive(
|
||||
prop["properties"], res[key]
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_default_obj(self, schema):
|
||||
try:
|
||||
return deepcopy(schema["default"])
|
||||
except KeyError:
|
||||
_properties = schema.get("properties", {})
|
||||
return self._build_default_obj_recursive(_properties, {})
|
||||
|
||||
def _get_query_defaults(self, query_defns):
|
||||
defaults = {}
|
||||
for k, v in query_defns.items():
|
||||
try:
|
||||
if v["schema"]["type"] == "object":
|
||||
defaults[k] = self._get_default_obj(v["schema"])
|
||||
else:
|
||||
defaults[k] = v["schema"]["default"]
|
||||
except KeyError:
|
||||
pass
|
||||
return defaults
|
||||
|
||||
def _get_query_arguments(self, query, arguments, has_kwargs, sanitize):
|
||||
query_defns = {p["name"]: p for p in self.parameters if p["in"] == "query"}
|
||||
default_query_params = self._get_query_defaults(query_defns)
|
||||
|
||||
query_arguments = deepcopy(default_query_params)
|
||||
query_arguments = deep_merge(query_arguments, query)
|
||||
return self._query_args_helper(
|
||||
query_defns, query_arguments, arguments, has_kwargs, sanitize
|
||||
)
|
||||
|
||||
def _get_val_from_param(self, value, query_defn):
|
||||
query_schema = query_defn["schema"]
|
||||
|
||||
if is_nullable(query_schema) and is_null(value):
|
||||
return None
|
||||
|
||||
if query_schema["type"] == "array":
|
||||
return [make_type(part, query_schema["items"]["type"]) for part in value]
|
||||
else:
|
||||
return make_type(value, query_schema["type"])
|
||||
|
||||
@@ -4,14 +4,12 @@ This module defines a Swagger2Operation class, a Connexion operation specific fo
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
from copy import deepcopy
|
||||
|
||||
from connexion.exceptions import InvalidSpecification
|
||||
from connexion.http_facts import FORM_CONTENT_TYPES
|
||||
from connexion.operations.abstract import AbstractOperation
|
||||
|
||||
from ..decorators.uri_parsing import Swagger2URIParser
|
||||
from ..exceptions import InvalidSpecification
|
||||
from ..http_facts import FORM_CONTENT_TYPES
|
||||
from ..utils import deep_get, is_null, is_nullable, make_type
|
||||
from connexion.uri_parsing import Swagger2URIParser
|
||||
from connexion.utils import deep_get
|
||||
|
||||
logger = logging.getLogger("connexion.operations.swagger2")
|
||||
|
||||
@@ -52,6 +50,7 @@ class Swagger2Operation(AbstractOperation):
|
||||
randomize_endpoint=None,
|
||||
pythonic_params=False,
|
||||
uri_parser_class=None,
|
||||
parameter_to_arg=None,
|
||||
):
|
||||
"""
|
||||
:param api: api that this operation is attached to
|
||||
@@ -101,6 +100,7 @@ class Swagger2Operation(AbstractOperation):
|
||||
randomize_endpoint=randomize_endpoint,
|
||||
pythonic_params=pythonic_params,
|
||||
uri_parser_class=uri_parser_class,
|
||||
parameter_to_arg=parameter_to_arg,
|
||||
)
|
||||
|
||||
self._produces = operation.get("produces", app_produces)
|
||||
@@ -222,6 +222,9 @@ class Swagger2Operation(AbstractOperation):
|
||||
except KeyError:
|
||||
raise
|
||||
|
||||
def body_name(self, content_type: str = None) -> str:
|
||||
return self.body_definition(content_type).get("name", "body")
|
||||
|
||||
def body_schema(self, content_type: str = None) -> dict:
|
||||
"""
|
||||
The body schema definition for this operation.
|
||||
@@ -237,6 +240,7 @@ class Swagger2Operation(AbstractOperation):
|
||||
|
||||
:rtype: dict
|
||||
"""
|
||||
# TODO: cache
|
||||
if content_type in FORM_CONTENT_TYPES:
|
||||
form_parameters = [p for p in self.parameters if p["in"] == "formData"]
|
||||
body_definition = self._transform_form(form_parameters)
|
||||
@@ -248,12 +252,21 @@ class Swagger2Operation(AbstractOperation):
|
||||
method=self.method, path=self.path
|
||||
)
|
||||
)
|
||||
body_definition = body_parameters[0] if body_parameters else {}
|
||||
body_parameter = body_parameters[0] if body_parameters else {}
|
||||
body_definition = self._transform_json(body_parameter)
|
||||
return body_definition
|
||||
|
||||
def _transform_json(self, body_parameter: dict) -> dict:
|
||||
"""Translate Swagger2 json parameters into OpenAPI 3 jsonschema spec."""
|
||||
nullable = body_parameter.get("x-nullable")
|
||||
if nullable is not None:
|
||||
body_parameter["schema"]["nullable"] = nullable
|
||||
return body_parameter
|
||||
|
||||
def _transform_form(self, form_parameters: t.List[dict]) -> dict:
|
||||
"""Translate Swagger2 form parameters into OpenAPI 3 jsonschema spec."""
|
||||
properties = {}
|
||||
defaults = {}
|
||||
required = []
|
||||
encoding = {}
|
||||
|
||||
@@ -276,7 +289,7 @@ class Swagger2Operation(AbstractOperation):
|
||||
|
||||
default = param.get("default")
|
||||
if default is not None:
|
||||
prop["default"] = default
|
||||
defaults[param["name"]] = default
|
||||
|
||||
nullable = param.get("x-nullable")
|
||||
if nullable is not None:
|
||||
@@ -305,6 +318,7 @@ class Swagger2Operation(AbstractOperation):
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"default": defaults,
|
||||
"required": required,
|
||||
}
|
||||
}
|
||||
@@ -313,76 +327,3 @@ class Swagger2Operation(AbstractOperation):
|
||||
definition["encoding"] = encoding
|
||||
|
||||
return definition
|
||||
|
||||
def _get_query_arguments(self, query, arguments, has_kwargs, sanitize):
|
||||
query_defns = {p["name"]: p for p in self.parameters if p["in"] == "query"}
|
||||
default_query_params = {
|
||||
k: v["default"] for k, v in query_defns.items() if "default" in v
|
||||
}
|
||||
query_arguments = deepcopy(default_query_params)
|
||||
query_arguments.update(query)
|
||||
return self._query_args_helper(
|
||||
query_defns, query_arguments, arguments, has_kwargs, sanitize
|
||||
)
|
||||
|
||||
def _get_body_argument(self, body, arguments, has_kwargs, sanitize):
|
||||
kwargs = {}
|
||||
body_parameters = [p for p in self.parameters if p["in"] == "body"] or [{}]
|
||||
if body is None:
|
||||
body = deepcopy(body_parameters[0].get("schema", {}).get("default"))
|
||||
body_name = sanitize(body_parameters[0].get("name"))
|
||||
|
||||
form_defns = {p["name"]: p for p in self.parameters if p["in"] == "formData"}
|
||||
|
||||
default_form_params = {
|
||||
k: v["default"] for k, v in form_defns.items() if "default" in v
|
||||
}
|
||||
|
||||
# Add body parameters
|
||||
if body_name:
|
||||
if not has_kwargs and body_name not in arguments:
|
||||
logger.debug("Body parameter '%s' not in function arguments", body_name)
|
||||
else:
|
||||
logger.debug("Body parameter '%s' in function arguments", body_name)
|
||||
kwargs[body_name] = body
|
||||
|
||||
# Add formData parameters
|
||||
form_arguments = deepcopy(default_form_params)
|
||||
if form_defns and body:
|
||||
form_arguments.update(body)
|
||||
for key, value in form_arguments.items():
|
||||
sanitized_key = sanitize(key)
|
||||
if not has_kwargs and sanitized_key not in arguments:
|
||||
logger.debug(
|
||||
"FormData parameter '%s' (sanitized: '%s') not in function arguments",
|
||||
key,
|
||||
sanitized_key,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"FormData parameter '%s' (sanitized: '%s') in function arguments",
|
||||
key,
|
||||
sanitized_key,
|
||||
)
|
||||
try:
|
||||
form_defn = form_defns[key]
|
||||
except KeyError: # pragma: no cover
|
||||
logger.error(
|
||||
"Function argument '%s' (non-sanitized: %s) not defined in specification",
|
||||
key,
|
||||
sanitized_key,
|
||||
)
|
||||
else:
|
||||
kwargs[sanitized_key] = self._get_val_from_param(value, form_defn)
|
||||
return kwargs
|
||||
|
||||
def _get_val_from_param(self, value, query_defn):
|
||||
if is_nullable(query_defn) and is_null(value):
|
||||
return None
|
||||
|
||||
query_schema = query_defn
|
||||
|
||||
if query_schema["type"] == "array":
|
||||
return [make_type(part, query_defn["items"]["type"]) for part in value]
|
||||
else:
|
||||
return make_type(value, query_defn["type"])
|
||||
|
||||
@@ -10,7 +10,7 @@ try:
|
||||
except ImportError:
|
||||
swagger_ui_2_path = swagger_ui_3_path = None
|
||||
|
||||
from connexion.decorators.uri_parsing import AbstractURIParser
|
||||
from connexion.uri_parsing import AbstractURIParser
|
||||
|
||||
NO_UI_MSG = """The swagger_ui directory could not be found.
|
||||
Please install connexion with extra install: pip install connexion[swagger-ui]
|
||||
|
||||
@@ -13,14 +13,14 @@ import typing as t
|
||||
|
||||
import httpx
|
||||
|
||||
from ..decorators.parameter import inspect_function_arguments
|
||||
from ..exceptions import (
|
||||
from connexion.decorators.parameter import inspect_function_arguments
|
||||
from connexion.exceptions import (
|
||||
ConnexionException,
|
||||
OAuthProblem,
|
||||
OAuthResponseProblem,
|
||||
OAuthScopeProblem,
|
||||
)
|
||||
from ..utils import get_function_from_name
|
||||
from connexion.utils import get_function_from_name
|
||||
|
||||
logger = logging.getLogger("connexion.api.security")
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
This module defines SecurityHandlerFactories which support the creation of security
|
||||
handlers for operations.
|
||||
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
from .security_handler_factory import SecurityHandlerFactory # NOQA
|
||||
@@ -1,9 +1,9 @@
|
||||
"""
|
||||
This module defines view function decorators to split query and path parameters.
|
||||
This module defines URIParsers which parse query and path parameters according to OpenAPI
|
||||
serialization rules.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -130,33 +130,6 @@ class AbstractURIParser(metaclass=abc.ABCMeta):
|
||||
|
||||
return resolved_param
|
||||
|
||||
def __call__(self, function):
|
||||
"""
|
||||
:type function: types.FunctionType
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
|
||||
@functools.wraps(function)
|
||||
def wrapper(request):
|
||||
def coerce_dict(md):
|
||||
"""MultiDict -> dict of lists"""
|
||||
try:
|
||||
return md.to_dict(flat=False)
|
||||
except AttributeError:
|
||||
return dict(md.items())
|
||||
|
||||
query = coerce_dict(request.query)
|
||||
path_params = coerce_dict(request.path_params)
|
||||
form = coerce_dict(request.form)
|
||||
|
||||
request.query = self.resolve_query(query)
|
||||
request.path_params = self.resolve_path(path_params)
|
||||
request.form = self.resolve_form(form)
|
||||
response = function(request)
|
||||
return response
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class OpenAPIURIParser(AbstractURIParser):
|
||||
style_defaults = {
|
||||
@@ -281,7 +254,7 @@ class OpenAPIURIParser(AbstractURIParser):
|
||||
class Swagger2URIParser(AbstractURIParser):
|
||||
"""
|
||||
Adheres to the Swagger2 spec,
|
||||
Assumes the the last defined query parameter should be used.
|
||||
Assumes that the last defined query parameter should be used.
|
||||
"""
|
||||
|
||||
parsable_parameters = ["query", "path", "formData"]
|
||||
|
||||
@@ -35,18 +35,24 @@ def boolean(s):
|
||||
|
||||
|
||||
# https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#data-types
|
||||
TYPE_MAP = {
|
||||
TYPE_MAP: t.Dict[str, t.Any] = {
|
||||
"integer": int,
|
||||
"number": float,
|
||||
"string": str,
|
||||
"boolean": boolean,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
"file": lambda x: x, # Don't cast files
|
||||
} # map of swagger types to python types
|
||||
|
||||
|
||||
def make_type(value, _type):
|
||||
type_func = TYPE_MAP[_type] # convert value to right type
|
||||
def make_type(value: t.Any, type_: str, format_: t.Optional[str]) -> t.Any:
|
||||
"""Cast a value to the type defined in the specification."""
|
||||
# In OpenAPI, files are represented with string type and binary format
|
||||
if type_ == "string" and format_ == "binary":
|
||||
type_ = "file"
|
||||
|
||||
type_func = TYPE_MAP[type_]
|
||||
return type_func(value)
|
||||
|
||||
|
||||
@@ -141,6 +147,8 @@ def is_json_mimetype(mimetype):
|
||||
:type mimetype: str
|
||||
:rtype: bool
|
||||
"""
|
||||
if mimetype is None:
|
||||
return False
|
||||
|
||||
maintype, subtype = mimetype.split("/") # type: str, str
|
||||
if ";" in subtype:
|
||||
@@ -209,9 +217,7 @@ def has_coroutine(function, api=None):
|
||||
return iscorofunc(function)
|
||||
|
||||
else:
|
||||
return any(
|
||||
iscorofunc(func) for func in (function, api.get_request, api.get_response)
|
||||
)
|
||||
return any(iscorofunc(func) for func in (function, api.get_response))
|
||||
|
||||
|
||||
def yamldumper(openapi):
|
||||
|
||||
@@ -6,13 +6,13 @@ from starlette.datastructures import FormData, Headers, UploadFile
|
||||
from starlette.formparsers import FormParser, MultiPartParser
|
||||
from starlette.types import Receive, Scope
|
||||
|
||||
from connexion.decorators.uri_parsing import AbstractURIParser
|
||||
from connexion.exceptions import (
|
||||
BadRequestProblem,
|
||||
ExtraParameterProblem,
|
||||
TypeValidationError,
|
||||
)
|
||||
from connexion.json_schema import Draft4RequestValidator
|
||||
from connexion.uri_parsing import AbstractURIParser
|
||||
from connexion.utils import coerce_type, is_null
|
||||
|
||||
logger = logging.getLogger("connexion.validators.form_data")
|
||||
|
||||
@@ -102,6 +102,9 @@ class ParameterValidator:
|
||||
return self.validate_parameter("query", val, param)
|
||||
|
||||
def validate_path_parameter(self, param, request):
|
||||
# TODO: activate
|
||||
# path_params = self.uri_parser.resolve_path(request.path_params)
|
||||
# val = path_params.get(param["name"].replace("-", "_"))
|
||||
val = request.path_params.get(param["name"].replace("-", "_"))
|
||||
return self.validate_parameter("path", val, param)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ def test_app_with_different_server_option(simple_api_spec_dir, spec):
|
||||
|
||||
|
||||
def test_app_with_different_uri_parser(simple_api_spec_dir):
|
||||
from connexion.decorators.uri_parsing import FirstValueURIParser
|
||||
from connexion.uri_parsing import FirstValueURIParser
|
||||
|
||||
app = App(
|
||||
__name__,
|
||||
|
||||
@@ -172,7 +172,7 @@ def test_path_parameter_someint__bad(simple_app):
|
||||
# non-integer values will not match Flask route
|
||||
app_client = simple_app.app.test_client()
|
||||
resp = app_client.get("/v1.0/test-int-path/foo") # type: flask.Response
|
||||
assert resp.status_code == 400, resp.text
|
||||
assert resp.status_code == 404, resp.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -205,7 +205,7 @@ def test_path_parameter_somefloat__bad(simple_app):
|
||||
# non-float values will not match Flask route
|
||||
app_client = simple_app.app.test_client()
|
||||
resp = app_client.get("/v1.0/test-float-path/123,45") # type: flask.Response
|
||||
assert resp.status_code == 400, resp.text
|
||||
assert resp.status_code == 404, resp.text
|
||||
|
||||
|
||||
def test_default_param(strict_app):
|
||||
@@ -280,7 +280,7 @@ def test_formdata_file_upload(simple_app):
|
||||
app_client = simple_app.app.test_client()
|
||||
resp = app_client.post(
|
||||
"/v1.0/test-formData-file-upload",
|
||||
data={"formData": (BytesIO(b"file contents"), "filename.txt")},
|
||||
data={"fileData": (BytesIO(b"file contents"), "filename.txt")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
response = json.loads(resp.data.decode("utf-8", "replace"))
|
||||
@@ -293,8 +293,8 @@ def test_formdata_file_upload_bad_request(simple_app):
|
||||
assert resp.status_code == 400
|
||||
response = json.loads(resp.data.decode("utf-8", "replace"))
|
||||
assert response["detail"] in [
|
||||
"Missing formdata parameter 'formData'",
|
||||
"'formData' is a required property", # OAS3
|
||||
"Missing formdata parameter 'fileData'",
|
||||
"'fileData' is a required property", # OAS3
|
||||
]
|
||||
|
||||
|
||||
@@ -302,7 +302,7 @@ def test_formdata_file_upload_missing_param(simple_app):
|
||||
app_client = simple_app.app.test_client()
|
||||
resp = app_client.post(
|
||||
"/v1.0/test-formData-file-upload-missing-param",
|
||||
data={"missing_formData": (BytesIO(b"file contents"), "example.txt")},
|
||||
data={"missing_fileData": (BytesIO(b"file contents"), "example.txt")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@@ -3,11 +3,19 @@ from unittest.mock import MagicMock
|
||||
from connexion.decorators.parameter import parameter_to_arg, pythonic
|
||||
|
||||
|
||||
def test_injection():
|
||||
request = MagicMock(name="request", path_params={"p1": "123"})
|
||||
request.args = {}
|
||||
async def test_injection():
|
||||
request = MagicMock(name="request")
|
||||
request.query_params = {}
|
||||
request.path_params = {"p1": "123"}
|
||||
request.headers = {}
|
||||
request.params = {}
|
||||
request.content_type = "application/json"
|
||||
|
||||
async def coro():
|
||||
return
|
||||
|
||||
request.json = coro
|
||||
request.loop = None
|
||||
request.context = {}
|
||||
|
||||
func = MagicMock()
|
||||
|
||||
@@ -16,17 +24,29 @@ def test_injection():
|
||||
|
||||
class Op:
|
||||
consumes = ["application/json"]
|
||||
parameters = []
|
||||
method = "GET"
|
||||
|
||||
def get_arguments(self, *args, **kwargs):
|
||||
return {"p1": "123"}
|
||||
def body_name(self, *args, **kwargs):
|
||||
return "body"
|
||||
|
||||
parameter_to_arg(Op(), handler)(request)
|
||||
parameter_decorator = parameter_to_arg(Op(), handler)
|
||||
await parameter_decorator(request)
|
||||
func.assert_called_with(p1="123")
|
||||
|
||||
|
||||
def test_injection_with_context():
|
||||
async def test_injection_with_context():
|
||||
request = MagicMock(name="request")
|
||||
|
||||
async def coro():
|
||||
return
|
||||
|
||||
request.json = coro
|
||||
request.loop = None
|
||||
request.context = {}
|
||||
request.content_type = "application/json"
|
||||
request.path_params = {"p1": "123"}
|
||||
|
||||
func = MagicMock()
|
||||
|
||||
def handler(context_, **kwargs):
|
||||
@@ -34,11 +54,14 @@ def test_injection_with_context():
|
||||
|
||||
class Op2:
|
||||
consumes = ["application/json"]
|
||||
parameters = []
|
||||
method = "GET"
|
||||
|
||||
def get_arguments(self, *args, **kwargs):
|
||||
return {"p1": "123"}
|
||||
def body_name(self, *args, **kwargs):
|
||||
return "body"
|
||||
|
||||
parameter_to_arg(Op2(), handler)(request)
|
||||
parameter_decorator = parameter_to_arg(Op2(), handler)
|
||||
await parameter_decorator(request)
|
||||
func.assert_called_with(request.context, p1="123")
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from connexion.decorators.uri_parsing import (
|
||||
from connexion.uri_parsing import (
|
||||
AlwaysMultiURIParser,
|
||||
FirstValueURIParser,
|
||||
OpenAPIURIParser,
|
||||
@@ -46,7 +46,9 @@ MULTI = "multi"
|
||||
(AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY2, PIPES),
|
||||
],
|
||||
)
|
||||
def test_uri_parser_query_params(parser_class, expected, query_in, collection_format):
|
||||
async def test_uri_parser_query_params(
|
||||
parser_class, expected, query_in, collection_format
|
||||
):
|
||||
class Request:
|
||||
query = query_in
|
||||
path_params = {}
|
||||
@@ -63,9 +65,9 @@ def test_uri_parser_query_params(parser_class, expected, query_in, collection_fo
|
||||
}
|
||||
]
|
||||
body_defn = {}
|
||||
p = parser_class(parameters, body_defn)
|
||||
res = p(lambda x: x)(request)
|
||||
assert res.query["letters"] == expected
|
||||
parser = parser_class(parameters, body_defn)
|
||||
res = parser.resolve_query(request.query.to_dict(flat=False))
|
||||
assert res["letters"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -82,7 +84,9 @@ def test_uri_parser_query_params(parser_class, expected, query_in, collection_fo
|
||||
(AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY2, PIPES),
|
||||
],
|
||||
)
|
||||
def test_uri_parser_form_params(parser_class, expected, query_in, collection_format):
|
||||
async def test_uri_parser_form_params(
|
||||
parser_class, expected, query_in, collection_format
|
||||
):
|
||||
class Request:
|
||||
query = {}
|
||||
form = query_in
|
||||
@@ -99,9 +103,9 @@ def test_uri_parser_form_params(parser_class, expected, query_in, collection_for
|
||||
}
|
||||
]
|
||||
body_defn = {}
|
||||
p = parser_class(parameters, body_defn)
|
||||
res = p(lambda x: x)(request)
|
||||
assert res.form["letters"] == expected
|
||||
parser = parser_class(parameters, body_defn)
|
||||
res = parser.resolve_form(request.form.to_dict(flat=False))
|
||||
assert res["letters"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -115,7 +119,9 @@ def test_uri_parser_form_params(parser_class, expected, query_in, collection_for
|
||||
(AlwaysMultiURIParser, ["d", "e", "f"], PATH2, PIPES),
|
||||
],
|
||||
)
|
||||
def test_uri_parser_path_params(parser_class, expected, query_in, collection_format):
|
||||
async def test_uri_parser_path_params(
|
||||
parser_class, expected, query_in, collection_format
|
||||
):
|
||||
class Request:
|
||||
query = {}
|
||||
form = {}
|
||||
@@ -132,9 +138,9 @@ def test_uri_parser_path_params(parser_class, expected, query_in, collection_for
|
||||
}
|
||||
]
|
||||
body_defn = {}
|
||||
p = parser_class(parameters, body_defn)
|
||||
res = p(lambda x: x)(request)
|
||||
assert res.path_params["letters"] == expected
|
||||
parser = parser_class(parameters, body_defn)
|
||||
res = parser.resolve_path(request.path_params)
|
||||
assert res["letters"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -149,7 +155,7 @@ def test_uri_parser_path_params(parser_class, expected, query_in, collection_for
|
||||
(AlwaysMultiURIParser, ["a", "b", "c", "d", "e", "f"], QUERY4, PIPES),
|
||||
],
|
||||
)
|
||||
def test_uri_parser_query_params_with_square_brackets(
|
||||
async def test_uri_parser_query_params_with_square_brackets(
|
||||
parser_class, expected, query_in, collection_format
|
||||
):
|
||||
class Request:
|
||||
@@ -168,9 +174,9 @@ def test_uri_parser_query_params_with_square_brackets(
|
||||
}
|
||||
]
|
||||
body_defn = {}
|
||||
p = parser_class(parameters, body_defn)
|
||||
res = p(lambda x: x)(request)
|
||||
assert res.query["letters[eq]"] == expected
|
||||
parser = parser_class(parameters, body_defn)
|
||||
res = parser.resolve_query(request.query.to_dict(flat=False))
|
||||
assert res["letters[eq]"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -188,7 +194,7 @@ def test_uri_parser_query_params_with_square_brackets(
|
||||
(AlwaysMultiURIParser, ["a"], QUERY6, PIPES),
|
||||
],
|
||||
)
|
||||
def test_uri_parser_query_params_with_underscores(
|
||||
async def test_uri_parser_query_params_with_underscores(
|
||||
parser_class, expected, query_in, collection_format
|
||||
):
|
||||
class Request:
|
||||
@@ -207,9 +213,9 @@ def test_uri_parser_query_params_with_underscores(
|
||||
}
|
||||
]
|
||||
body_defn = {}
|
||||
p = parser_class(parameters, body_defn)
|
||||
res = p(lambda x: x)(request)
|
||||
assert res.query["letters_eq"] == expected
|
||||
parser = parser_class(parameters, body_defn)
|
||||
res = parser.resolve_query(request.query.to_dict(flat=False))
|
||||
assert res["letters_eq"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -231,7 +237,7 @@ def test_uri_parser_query_params_with_underscores(
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_uri_parser_query_params_with_malformed_names(
|
||||
async def test_uri_parser_query_params_with_malformed_names(
|
||||
parser_class, query_in, collection_format, explode, expected
|
||||
):
|
||||
class Request:
|
||||
@@ -253,6 +259,6 @@ def test_uri_parser_query_params_with_malformed_names(
|
||||
}
|
||||
]
|
||||
body_defn = {}
|
||||
p = parser_class(parameters, body_defn)
|
||||
res = p(lambda x: x)(request)
|
||||
assert res.query == expected
|
||||
parser = parser_class(parameters, body_defn)
|
||||
res = parser.resolve_query(request.query.to_dict(flat=False))
|
||||
assert res == expected
|
||||
|
||||
@@ -84,7 +84,7 @@ def get_bye_secure(name, user, token_info):
|
||||
|
||||
|
||||
def get_bye_secure_from_flask():
|
||||
return "Goodbye {user} (Secure!)".format(user=context["user"])
|
||||
return "Goodbye {user} (Secure!)".format(user=context.context["user"])
|
||||
|
||||
|
||||
def get_bye_secure_from_connexion(context_):
|
||||
@@ -314,9 +314,11 @@ def test_formdata_missing_param():
|
||||
return ""
|
||||
|
||||
|
||||
def test_formdata_file_upload(formData, **kwargs):
|
||||
filename = formData.filename
|
||||
contents = formData.read().decode("utf-8", "replace")
|
||||
def test_formdata_file_upload(fileData, **kwargs):
|
||||
"""In Swagger, form paramaeters and files are passed separately"""
|
||||
filename = fileData.filename
|
||||
contents = fileData.read()
|
||||
contents = contents.decode("utf-8", "replace")
|
||||
return {filename: contents}
|
||||
|
||||
|
||||
|
||||
82
tests/fixtures/simple/openapi.yaml
vendored
82
tests/fixtures/simple/openapi.yaml
vendored
@@ -3,25 +3,6 @@ info:
|
||||
title: '{{title}}'
|
||||
version: '1.0'
|
||||
paths:
|
||||
'/greeting/{name}':
|
||||
post:
|
||||
summary: Generate greeting
|
||||
description: Generates a greeting message.
|
||||
operationId: fakeapi.hello.post_greeting
|
||||
responses:
|
||||
'200':
|
||||
description: greeting response
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
type: object
|
||||
parameters:
|
||||
- name: name
|
||||
in: path
|
||||
description: Name of the person to greet.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
'/greeting/{name}/{remainder}':
|
||||
post:
|
||||
summary: Generate greeting and collect the remainder of the url
|
||||
@@ -48,6 +29,25 @@ paths:
|
||||
schema:
|
||||
type: string
|
||||
format: path
|
||||
'/greeting/{name}':
|
||||
post:
|
||||
summary: Generate greeting
|
||||
description: Generates a greeting message.
|
||||
operationId: fakeapi.hello.post_greeting
|
||||
responses:
|
||||
'200':
|
||||
description: greeting response
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
type: object
|
||||
parameters:
|
||||
- name: name
|
||||
in: path
|
||||
description: Name of the person to greet.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
'/greetings/{name}':
|
||||
get:
|
||||
summary: Generate greeting
|
||||
@@ -600,11 +600,11 @@ paths:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
formData:
|
||||
fileData:
|
||||
type: string
|
||||
format: binary
|
||||
required:
|
||||
- formData
|
||||
- fileData
|
||||
/test-formData-file-upload-missing-param:
|
||||
post:
|
||||
summary: 'Test formData with file type, missing parameter in handler'
|
||||
@@ -618,11 +618,11 @@ paths:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
missing_formData:
|
||||
missing_fileData:
|
||||
type: string
|
||||
format: binary
|
||||
required:
|
||||
- missing_formData
|
||||
- missing_fileData
|
||||
/test-bool-param:
|
||||
get:
|
||||
summary: Test usage of boolean default value
|
||||
@@ -690,6 +690,24 @@ paths:
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
/goodday/noheader:
|
||||
post:
|
||||
summary: Generate good day greeting
|
||||
description: Generates a good day message.
|
||||
operationId: fakeapi.hello.post_goodday_no_header
|
||||
responses:
|
||||
'201':
|
||||
description: goodday response
|
||||
headers:
|
||||
Location:
|
||||
description: The URI of the created resource
|
||||
schema:
|
||||
type: string
|
||||
required: true
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
type: object
|
||||
'/goodday/{name}':
|
||||
post:
|
||||
summary: Generate good day greeting
|
||||
@@ -715,24 +733,6 @@ paths:
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/goodday/noheader:
|
||||
post:
|
||||
summary: Generate good day greeting
|
||||
description: Generates a good day message.
|
||||
operationId: fakeapi.hello.post_goodday_no_header
|
||||
responses:
|
||||
'201':
|
||||
description: goodday response
|
||||
headers:
|
||||
Location:
|
||||
description: The URI of the created resource
|
||||
schema:
|
||||
type: string
|
||||
required: true
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
type: object
|
||||
'/goodevening/{name}':
|
||||
post:
|
||||
summary: Generate good evening
|
||||
|
||||
36
tests/fixtures/simple/swagger.yaml
vendored
36
tests/fixtures/simple/swagger.yaml
vendored
@@ -7,22 +7,6 @@ info:
|
||||
basePath: /v1.0
|
||||
|
||||
paths:
|
||||
/greeting/{name}:
|
||||
post:
|
||||
summary: Generate greeting
|
||||
description: Generates a greeting message.
|
||||
operationId: fakeapi.hello.post_greeting
|
||||
responses:
|
||||
'200':
|
||||
description: greeting response
|
||||
schema:
|
||||
type: object
|
||||
parameters:
|
||||
- name: name
|
||||
in: path
|
||||
description: Name of the person to greet.
|
||||
required: true
|
||||
type: string
|
||||
/greeting/{name}/{remainder}:
|
||||
post:
|
||||
summary: Generate greeting and collect the remainder of the url
|
||||
@@ -45,6 +29,22 @@ paths:
|
||||
required: true
|
||||
type: string
|
||||
format: path
|
||||
/greeting/{name}:
|
||||
post:
|
||||
summary: Generate greeting
|
||||
description: Generates a greeting message.
|
||||
operationId: fakeapi.hello.post_greeting
|
||||
responses:
|
||||
'200':
|
||||
description: greeting response
|
||||
schema:
|
||||
type: object
|
||||
parameters:
|
||||
- name: name
|
||||
in: path
|
||||
description: Name of the person to greet.
|
||||
required: true
|
||||
type: string
|
||||
/greetings/{name}:
|
||||
get:
|
||||
summary: Generate greeting
|
||||
@@ -457,7 +457,7 @@ paths:
|
||||
consumes:
|
||||
- multipart/form-data
|
||||
parameters:
|
||||
- name: formData
|
||||
- name: fileData
|
||||
type: file
|
||||
in: formData
|
||||
required: true
|
||||
@@ -472,7 +472,7 @@ paths:
|
||||
consumes:
|
||||
- multipart/form-data
|
||||
parameters:
|
||||
- name: missing_formData
|
||||
- name: missing_fileData
|
||||
type: file
|
||||
in: formData
|
||||
required: true
|
||||
|
||||
22
tests/test_datastructures.py
Normal file
22
tests/test_datastructures.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from connexion.datastructures import MediaTypeDict
|
||||
|
||||
|
||||
def test_media_type_dict():
|
||||
d = MediaTypeDict(
|
||||
{
|
||||
"*/*": "*/*",
|
||||
"*/json": "*/json",
|
||||
"*/*json": "*/*json",
|
||||
"multipart/*": "multipart/*",
|
||||
"multipart/form-data": "multipart/form-data",
|
||||
}
|
||||
)
|
||||
|
||||
assert d["application/json"] == "*/json"
|
||||
assert d["application/problem+json"] == "*/*json"
|
||||
assert d["application/x-www-form-urlencoded"] == "*/*"
|
||||
assert d["multipart/form-data"] == "multipart/form-data"
|
||||
assert d["multipart/byteranges"] == "multipart/*"
|
||||
|
||||
# Test __contains__
|
||||
assert "application/json" in d
|
||||
@@ -735,7 +735,6 @@ def test_form_transformation(api):
|
||||
"properties": {
|
||||
"param": {
|
||||
"type": "string",
|
||||
"default": "foo@bar.com",
|
||||
"format": "email",
|
||||
},
|
||||
"array_param": {
|
||||
@@ -746,6 +745,7 @@ def test_form_transformation(api):
|
||||
"nullable": True,
|
||||
},
|
||||
},
|
||||
"default": {"param": "foo@bar.com"},
|
||||
"required": ["param"],
|
||||
},
|
||||
"encoding": {
|
||||
|
||||
@@ -2,8 +2,8 @@ from unittest.mock import MagicMock
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pytest
|
||||
from connexion.decorators.uri_parsing import Swagger2URIParser
|
||||
from connexion.exceptions import BadRequestProblem
|
||||
from connexion.uri_parsing import Swagger2URIParser
|
||||
from connexion.validators.parameter import ParameterValidator
|
||||
from starlette.datastructures import QueryParams
|
||||
|
||||
|
||||
Reference in New Issue
Block a user