mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 04:19:26 +00:00
Move JSON response body validation to middleware (#1591)
* Extract boilerplate code into Routed base classes * Use typing_extensions for Python 3.7 Protocol support * Use Mock instead of AsyncMock * Extract response validation to middleware * Refactor Request validation to match Response validation * Factor out shared functionality * Fix typo in TextResponseBodyValidator class name * Fix string formatting * Use correct schema to check nullability in response validation
This commit is contained in:
@@ -1,135 +0,0 @@
|
||||
"""
|
||||
This module defines a view function decorator to validate its responses.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
|
||||
from jsonschema import ValidationError
|
||||
|
||||
from ..exceptions import NonConformingResponseBody, NonConformingResponseHeaders
|
||||
from ..utils import all_json, has_coroutine
|
||||
from .decorator import BaseDecorator
|
||||
from .validation import ResponseBodyValidator
|
||||
|
||||
logger = logging.getLogger("connexion.decorators.response")
|
||||
|
||||
|
||||
class ResponseValidator(BaseDecorator):
|
||||
def __init__(self, operation, mimetype, validator=None):
|
||||
"""
|
||||
:type operation: Operation
|
||||
:type mimetype: str
|
||||
:param validator: Validator class that should be used to validate passed data
|
||||
against API schema.
|
||||
:type validator: jsonschema.IValidator
|
||||
"""
|
||||
self.operation = operation
|
||||
self.mimetype = mimetype
|
||||
self.validator = validator
|
||||
|
||||
def validate_response(self, data, status_code, headers, url):
|
||||
"""
|
||||
Validates the Response object based on what has been declared in the specification.
|
||||
Ensures the response body matches the declared schema.
|
||||
:type data: dict
|
||||
:type status_code: int
|
||||
:type headers: dict
|
||||
:rtype bool | None
|
||||
"""
|
||||
# check against returned header, fall back to expected mimetype
|
||||
content_type = headers.get("Content-Type", self.mimetype)
|
||||
content_type = content_type.rsplit(";", 1)[
|
||||
0
|
||||
] # remove things like utf8 metadata
|
||||
|
||||
response_definition = self.operation.response_definition(
|
||||
str(status_code), content_type
|
||||
)
|
||||
response_schema = self.operation.response_schema(str(status_code), content_type)
|
||||
|
||||
if self.is_json_schema_compatible(response_schema):
|
||||
v = ResponseBodyValidator(response_schema, validator=self.validator)
|
||||
try:
|
||||
data = self.operation.json_loads(data)
|
||||
v.validate_schema(data, url)
|
||||
except ValidationError as e:
|
||||
raise NonConformingResponseBody(message=str(e))
|
||||
|
||||
if response_definition and response_definition.get("headers"):
|
||||
required_header_keys = {
|
||||
k
|
||||
for (k, v) in response_definition.get("headers").items()
|
||||
if v.get("required", False)
|
||||
}
|
||||
header_keys = set(headers.keys())
|
||||
missing_keys = required_header_keys - header_keys
|
||||
if missing_keys:
|
||||
pretty_list = ", ".join(missing_keys)
|
||||
msg = (
|
||||
"Keys in header don't match response specification. "
|
||||
"Difference: {}"
|
||||
).format(pretty_list)
|
||||
raise NonConformingResponseHeaders(message=msg)
|
||||
return True
|
||||
|
||||
def is_json_schema_compatible(self, response_schema: dict) -> bool:
|
||||
"""
|
||||
Verify if the specified operation responses are JSON schema
|
||||
compatible.
|
||||
|
||||
All operations that specify a JSON schema and have content
|
||||
type "application/json" or "text/plain" can be validated using
|
||||
json_schema package.
|
||||
"""
|
||||
if not response_schema:
|
||||
return False
|
||||
return all_json([self.mimetype]) or self.mimetype == "text/plain"
|
||||
|
||||
def __call__(self, function):
|
||||
"""
|
||||
:type function: types.FunctionType
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
|
||||
def _wrapper(request, response):
|
||||
connexion_response = self.operation.api.get_connexion_response(
|
||||
response, self.mimetype
|
||||
)
|
||||
if not connexion_response.is_streamed:
|
||||
self.validate_response(
|
||||
connexion_response.body,
|
||||
connexion_response.status_code,
|
||||
connexion_response.headers,
|
||||
request.url,
|
||||
)
|
||||
else:
|
||||
logger.warning("Skipping response validation for streamed response.")
|
||||
|
||||
return response
|
||||
|
||||
if has_coroutine(function):
|
||||
|
||||
@functools.wraps(function)
|
||||
async def wrapper(request):
|
||||
response = function(request)
|
||||
while asyncio.iscoroutine(response):
|
||||
response = await response
|
||||
|
||||
return _wrapper(request, response)
|
||||
|
||||
else: # pragma: no cover
|
||||
|
||||
@functools.wraps(function)
|
||||
def wrapper(request):
|
||||
response = function(request)
|
||||
return _wrapper(request, response)
|
||||
|
||||
return wrapper
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
:rtype: str
|
||||
"""
|
||||
return "<ResponseValidator>" # pragma: no cover
|
||||
@@ -14,7 +14,7 @@ from werkzeug.datastructures import FileStorage
|
||||
|
||||
from ..exceptions import BadRequestProblem, ExtraParameterProblem
|
||||
from ..http_facts import FORM_CONTENT_TYPES
|
||||
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator
|
||||
from ..json_schema import Draft4RequestValidator
|
||||
from ..lifecycle import ConnexionResponse
|
||||
from ..utils import boolean, is_null, is_nullable
|
||||
|
||||
@@ -196,29 +196,6 @@ class RequestBodyValidator:
|
||||
return None
|
||||
|
||||
|
||||
class ResponseBodyValidator:
|
||||
def __init__(self, schema, validator=None):
|
||||
"""
|
||||
:param schema: The schema of the response body
|
||||
:param validator: Validator class that should be used to validate passed data
|
||||
against API schema. Default is Draft4ResponseValidator.
|
||||
:type validator: jsonschema.IValidator
|
||||
"""
|
||||
ValidatorClass = validator or Draft4ResponseValidator
|
||||
self.validator = ValidatorClass(schema, format_checker=draft4_format_checker)
|
||||
|
||||
def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]:
|
||||
try:
|
||||
self.validator.validate(data)
|
||||
except ValidationError as exception:
|
||||
logger.error(
|
||||
f"{url} validation error: {exception}", extra={"validator": "response"}
|
||||
)
|
||||
raise exception
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ParameterValidator:
|
||||
def __init__(self, parameters, api, strict_validation=False):
|
||||
"""
|
||||
|
||||
@@ -5,10 +5,11 @@ from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.middleware.abstract import AppMiddleware
|
||||
from connexion.middleware.exceptions import ExceptionMiddleware
|
||||
from connexion.middleware.request_validation import RequestValidationMiddleware
|
||||
from connexion.middleware.response_validation import ResponseValidationMiddleware
|
||||
from connexion.middleware.routing import RoutingMiddleware
|
||||
from connexion.middleware.security import SecurityMiddleware
|
||||
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
||||
from connexion.middleware.validation import ValidationMiddleware
|
||||
|
||||
|
||||
class ConnexionMiddleware:
|
||||
@@ -18,7 +19,8 @@ class ConnexionMiddleware:
|
||||
SwaggerUIMiddleware,
|
||||
RoutingMiddleware,
|
||||
SecurityMiddleware,
|
||||
ValidationMiddleware,
|
||||
RequestValidationMiddleware,
|
||||
ResponseValidationMiddleware,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -6,71 +6,51 @@ import typing as t
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion import utils
|
||||
from connexion.decorators.uri_parsing import AbstractURIParser
|
||||
from connexion.exceptions import UnsupportedMediaTypeProblem
|
||||
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
|
||||
from connexion.operations import AbstractOperation
|
||||
from connexion.utils import is_nullable
|
||||
from connexion.validators import JSONBodyValidator
|
||||
|
||||
from ..decorators.response import ResponseValidator
|
||||
from ..decorators.validation import ParameterValidator
|
||||
from connexion.validators import VALIDATOR_MAP
|
||||
|
||||
logger = logging.getLogger("connexion.middleware.validation")
|
||||
|
||||
VALIDATOR_MAP = {
|
||||
"parameter": ParameterValidator,
|
||||
"body": {"application/json": JSONBodyValidator},
|
||||
"response": ResponseValidator,
|
||||
}
|
||||
|
||||
|
||||
class ValidationOperation:
|
||||
class RequestValidationOperation:
|
||||
def __init__(
|
||||
self,
|
||||
next_app: ASGIApp,
|
||||
*,
|
||||
operation: AbstractOperation,
|
||||
validate_responses: bool = False,
|
||||
strict_validation: bool = False,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None,
|
||||
) -> None:
|
||||
self.next_app = next_app
|
||||
self._operation = operation
|
||||
self.validate_responses = validate_responses
|
||||
self.strict_validation = strict_validation
|
||||
self._validator_map = VALIDATOR_MAP
|
||||
self._validator_map.update(validator_map or {})
|
||||
self.uri_parser_class = uri_parser_class
|
||||
|
||||
def extract_content_type(self, headers: dict) -> t.Tuple[str, str]:
|
||||
def extract_content_type(
|
||||
self, headers: t.List[t.Tuple[bytes, bytes]]
|
||||
) -> t.Tuple[str, str]:
|
||||
"""Extract the mime type and encoding from the content type headers.
|
||||
|
||||
:param headers: Header dict from ASGI scope
|
||||
:param headers: Headers from ASGI scope
|
||||
|
||||
:return: A tuple of mime type, encoding
|
||||
"""
|
||||
encoding = "utf-8"
|
||||
for key, value in headers:
|
||||
# Headers can always be decoded using latin-1:
|
||||
# https://stackoverflow.com/a/27357138/4098821
|
||||
key = key.decode("latin-1")
|
||||
if key.lower() == "content-type":
|
||||
content_type = value.decode("latin-1")
|
||||
if ";" in content_type:
|
||||
mime_type, parameters = content_type.split(";", maxsplit=1)
|
||||
|
||||
prefix = "charset="
|
||||
for parameter in parameters.split(";"):
|
||||
if parameter.startswith(prefix):
|
||||
encoding = parameter[len(prefix) :]
|
||||
else:
|
||||
mime_type = content_type
|
||||
break
|
||||
else:
|
||||
mime_type, encoding = utils.extract_content_type(headers)
|
||||
if mime_type is None:
|
||||
# Content-type header is not required. Take a best guess.
|
||||
try:
|
||||
mime_type = self._operation.consumes[0]
|
||||
except IndexError:
|
||||
mime_type = "application/octet-stream"
|
||||
if encoding is None:
|
||||
encoding = "utf-8"
|
||||
|
||||
return mime_type, encoding
|
||||
|
||||
@@ -86,6 +66,8 @@ class ValidationOperation:
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
receive_fn = receive
|
||||
|
||||
headers = scope["headers"]
|
||||
mime_type, encoding = self.extract_content_type(headers)
|
||||
self.validate_mime_type(mime_type)
|
||||
@@ -102,25 +84,25 @@ class ValidationOperation:
|
||||
)
|
||||
else:
|
||||
validator = body_validator(
|
||||
self.next_app,
|
||||
scope,
|
||||
receive,
|
||||
schema=self._operation.body_schema,
|
||||
nullable=is_nullable(self._operation.body_definition),
|
||||
nullable=utils.is_nullable(self._operation.body_definition),
|
||||
encoding=encoding,
|
||||
)
|
||||
return await validator(scope, receive, send)
|
||||
receive_fn = validator.receive
|
||||
|
||||
await self.next_app(scope, receive, send)
|
||||
await self.next_app(scope, receive_fn, send)
|
||||
|
||||
|
||||
class ValidationAPI(RoutedAPI[ValidationOperation]):
|
||||
class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
|
||||
"""Validation API."""
|
||||
|
||||
operation_cls = ValidationOperation
|
||||
operation_cls = RequestValidationOperation
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
validate_responses=False,
|
||||
strict_validation=False,
|
||||
validator_map=None,
|
||||
uri_parser_class=None,
|
||||
@@ -129,9 +111,6 @@ class ValidationAPI(RoutedAPI[ValidationOperation]):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.validator_map = validator_map
|
||||
|
||||
logger.debug("Validate Responses: %s", str(validate_responses))
|
||||
self.validate_responses = validate_responses
|
||||
|
||||
logger.debug("Strict Request Validation: %s", str(strict_validation))
|
||||
self.strict_validation = strict_validation
|
||||
|
||||
@@ -139,21 +118,22 @@ class ValidationAPI(RoutedAPI[ValidationOperation]):
|
||||
|
||||
self.add_paths()
|
||||
|
||||
def make_operation(self, operation: AbstractOperation) -> ValidationOperation:
|
||||
return ValidationOperation(
|
||||
def make_operation(
|
||||
self, operation: AbstractOperation
|
||||
) -> RequestValidationOperation:
|
||||
return RequestValidationOperation(
|
||||
self.next_app,
|
||||
operation=operation,
|
||||
validate_responses=self.validate_responses,
|
||||
strict_validation=self.strict_validation,
|
||||
validator_map=self.validator_map,
|
||||
uri_parser_class=self.uri_parser_class,
|
||||
)
|
||||
|
||||
|
||||
class ValidationMiddleware(RoutedMiddleware[ValidationAPI]):
|
||||
class RequestValidationMiddleware(RoutedMiddleware[RequestValidationAPI]):
|
||||
"""Middleware for validating requests according to the API contract."""
|
||||
|
||||
api_cls = ValidationAPI
|
||||
api_cls = RequestValidationAPI
|
||||
|
||||
|
||||
class MissingValidationOperation(Exception):
|
||||
158
connexion/middleware/response_validation.py
Normal file
158
connexion/middleware/response_validation.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Validation Middleware.
|
||||
"""
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion import utils
|
||||
from connexion.exceptions import NonConformingResponseHeaders
|
||||
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
|
||||
from connexion.operations import AbstractOperation
|
||||
from connexion.validators import VALIDATOR_MAP
|
||||
|
||||
logger = logging.getLogger("connexion.middleware.validation")
|
||||
|
||||
|
||||
class ResponseValidationOperation:
|
||||
def __init__(
|
||||
self,
|
||||
next_app: ASGIApp,
|
||||
*,
|
||||
operation: AbstractOperation,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
) -> None:
|
||||
self.next_app = next_app
|
||||
self._operation = operation
|
||||
self._validator_map = VALIDATOR_MAP
|
||||
self._validator_map.update(validator_map or {})
|
||||
|
||||
def extract_content_type(
|
||||
self, headers: t.List[t.Tuple[bytes, bytes]]
|
||||
) -> t.Tuple[str, str]:
|
||||
"""Extract the mime type and encoding from the content type headers.
|
||||
|
||||
:param headers: Headers from ASGI scope
|
||||
|
||||
:return: A tuple of mime type, encoding
|
||||
"""
|
||||
mime_type, encoding = utils.extract_content_type(headers)
|
||||
if mime_type is None:
|
||||
# Content-type header is not required. Take a best guess.
|
||||
try:
|
||||
mime_type = self._operation.produces[0]
|
||||
except IndexError:
|
||||
mime_type = "application/octet-stream"
|
||||
if encoding is None:
|
||||
encoding = "utf-8"
|
||||
|
||||
return mime_type, encoding
|
||||
|
||||
def validate_mime_type(self, mime_type: str) -> None:
|
||||
"""Validate the mime type against the spec.
|
||||
|
||||
:param mime_type: mime type from content type header
|
||||
"""
|
||||
if mime_type.lower() not in [c.lower() for c in self._operation.produces]:
|
||||
raise NonConformingResponseHeaders(
|
||||
reason="Invalid Response Content-type",
|
||||
message=f"Invalid Response Content-type ({mime_type}), "
|
||||
f"expected {self._operation.produces}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_required_headers(
|
||||
headers: t.List[tuple], response_definition: dict
|
||||
) -> None:
|
||||
required_header_keys = {
|
||||
k.lower()
|
||||
for (k, v) in response_definition.get("headers", {}).items()
|
||||
if v.get("required", False)
|
||||
}
|
||||
header_keys = set(header[0].decode("latin-1").lower() for header in headers)
|
||||
missing_keys = required_header_keys - header_keys
|
||||
if missing_keys:
|
||||
pretty_list = ", ".join(missing_keys)
|
||||
msg = (
|
||||
"Keys in header don't match response specification. Difference: {}"
|
||||
).format(pretty_list)
|
||||
raise NonConformingResponseHeaders(message=msg)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
|
||||
send_fn = send
|
||||
|
||||
async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
|
||||
nonlocal send_fn
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
status = str(message["status"])
|
||||
headers = message["headers"]
|
||||
mime_type, encoding = self.extract_content_type(headers)
|
||||
# TODO: Add produces to all tests and fix response content types
|
||||
# self.validate_mime_type(mime_type)
|
||||
response_definition = self._operation.response_definition(
|
||||
status, mime_type
|
||||
)
|
||||
self.validate_required_headers(headers, response_definition)
|
||||
|
||||
# Validate body
|
||||
try:
|
||||
body_validator = self._validator_map["response"][mime_type] # type: ignore
|
||||
except KeyError:
|
||||
logging.info(
|
||||
f"Skipping validation. No validator registered for content type: "
|
||||
f"{mime_type}."
|
||||
)
|
||||
else:
|
||||
validator = body_validator(
|
||||
scope,
|
||||
send,
|
||||
schema=self._operation.response_schema(status, mime_type),
|
||||
nullable=utils.is_nullable(
|
||||
self._operation.response_definition(status, mime_type)
|
||||
),
|
||||
encoding=encoding,
|
||||
)
|
||||
send_fn = validator.send
|
||||
|
||||
return await send_fn(message)
|
||||
|
||||
await self.next_app(scope, receive, wrapped_send)
|
||||
|
||||
|
||||
class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]):
|
||||
"""Validation API."""
|
||||
|
||||
operation_cls = ResponseValidationOperation
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
validator_map=None,
|
||||
validate_responses=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.validator_map = validator_map
|
||||
self.validate_responses = validate_responses
|
||||
self.add_paths()
|
||||
|
||||
def make_operation(
|
||||
self, operation: AbstractOperation
|
||||
) -> ResponseValidationOperation:
|
||||
if self.validate_responses:
|
||||
return ResponseValidationOperation(
|
||||
self.next_app,
|
||||
operation=operation,
|
||||
validator_map=self.validator_map,
|
||||
)
|
||||
else:
|
||||
return self.next_app # type: ignore
|
||||
|
||||
|
||||
class ResponseValidationMiddleware(RoutedMiddleware[ResponseValidationAPI]):
|
||||
"""Middleware for validating requests according to the API contract."""
|
||||
|
||||
api_cls = ResponseValidationAPI
|
||||
@@ -9,7 +9,6 @@ import logging
|
||||
from ..decorators.decorator import RequestResponseDecorator
|
||||
from ..decorators.parameter import parameter_to_arg
|
||||
from ..decorators.produces import BaseSerializer, Produces
|
||||
from ..decorators.response import ResponseValidator
|
||||
from ..decorators.validation import ParameterValidator, RequestBodyValidator
|
||||
from ..utils import all_json, is_nullable
|
||||
|
||||
@@ -20,7 +19,6 @@ DEFAULT_MIMETYPE = "application/json"
|
||||
VALIDATOR_MAP = {
|
||||
"parameter": ParameterValidator,
|
||||
"body": RequestBodyValidator,
|
||||
"response": ResponseValidator,
|
||||
}
|
||||
|
||||
|
||||
@@ -389,12 +387,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
self.pythonic_params,
|
||||
)
|
||||
|
||||
if self.validate_responses:
|
||||
logger.debug("... Response validation enabled.")
|
||||
response_decorator = self.__response_validation_decorator
|
||||
logger.debug("... Adding response decorator (%r)", response_decorator)
|
||||
function = response_decorator(function)
|
||||
|
||||
produces_decorator = self.__content_type_decorator
|
||||
logger.debug("... Adding produces decorator (%r)", produces_decorator)
|
||||
function = produces_decorator(function)
|
||||
@@ -473,15 +465,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
strict_validation=self.strict_validation,
|
||||
)
|
||||
|
||||
@property
|
||||
def __response_validation_decorator(self):
|
||||
"""
|
||||
Get a decorator for validating the generated Response.
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
ResponseValidator = self.validator_map["response"]
|
||||
return ResponseValidator(self, self.get_mimetype())
|
||||
|
||||
def json_loads(self, data):
|
||||
"""
|
||||
A wrapper for calling the API specific JSON loader.
|
||||
|
||||
@@ -5,6 +5,7 @@ This module provides general utility functions used within Connexion.
|
||||
import asyncio
|
||||
import functools
|
||||
import importlib
|
||||
import typing as t
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -266,3 +267,32 @@ def not_installed_error(exc): # pragma: no cover
|
||||
raise exc
|
||||
|
||||
return functools.partial(_required_lib, exc)
|
||||
|
||||
|
||||
def extract_content_type(
|
||||
headers: t.List[t.Tuple[bytes, bytes]]
|
||||
) -> t.Tuple[t.Optional[str], t.Optional[str]]:
|
||||
"""Extract the mime type and encoding from the content type headers.
|
||||
|
||||
:param headers: Headers from ASGI scope
|
||||
|
||||
:return: A tuple of mime type, encoding
|
||||
"""
|
||||
mime_type, encoding = None, None
|
||||
for key, value in headers:
|
||||
# Headers can always be decoded using latin-1:
|
||||
# https://stackoverflow.com/a/27357138/4098821
|
||||
decoded_key = key.decode("latin-1")
|
||||
if decoded_key.lower() == "content-type":
|
||||
content_type = value.decode("latin-1")
|
||||
if ";" in content_type:
|
||||
mime_type, parameters = content_type.split(";", maxsplit=1)
|
||||
|
||||
prefix = "charset="
|
||||
for parameter in parameters.split(";"):
|
||||
if parameter.startswith(prefix):
|
||||
encoding = parameter[len(prefix) :]
|
||||
else:
|
||||
mime_type = content_type
|
||||
break
|
||||
return mime_type, encoding
|
||||
|
||||
@@ -6,36 +6,38 @@ import logging
|
||||
import typing as t
|
||||
|
||||
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from connexion.exceptions import BadRequestProblem
|
||||
from connexion.json_schema import Draft4RequestValidator
|
||||
from connexion.decorators.validation import ParameterValidator
|
||||
from connexion.exceptions import BadRequestProblem, NonConformingResponseBody
|
||||
from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator
|
||||
from connexion.utils import is_null
|
||||
|
||||
logger = logging.getLogger("connexion.middleware.validators")
|
||||
|
||||
|
||||
class JSONBodyValidator:
|
||||
class JSONRequestBodyValidator:
|
||||
"""Request body validator for json content types."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
next_app: ASGIApp,
|
||||
scope: Scope,
|
||||
receive: Receive,
|
||||
*,
|
||||
schema: dict,
|
||||
validator: t.Type[Draft4Validator] = None,
|
||||
nullable=False,
|
||||
encoding: str,
|
||||
) -> None:
|
||||
self.next_app = next_app
|
||||
self._scope = scope
|
||||
self._receive = receive
|
||||
self.schema = schema
|
||||
self.has_default = schema.get("default", False)
|
||||
self.nullable = nullable
|
||||
self.validator_cls = validator or Draft4RequestValidator
|
||||
self.validator = self.validator_cls(
|
||||
schema, format_checker=draft4_format_checker
|
||||
)
|
||||
validator_cls = validator or Draft4RequestValidator
|
||||
self.validator = validator_cls(schema, format_checker=draft4_format_checker)
|
||||
self.encoding = encoding
|
||||
self._messages: t.List[t.MutableMapping[str, t.Any]] = []
|
||||
|
||||
@classmethod
|
||||
def _error_path_message(cls, exception):
|
||||
@@ -44,7 +46,6 @@ class JSONBodyValidator:
|
||||
return error_path_msg
|
||||
|
||||
def validate(self, body: dict):
|
||||
|
||||
try:
|
||||
self.validator.validate(body)
|
||||
except ValidationError as exception:
|
||||
@@ -55,18 +56,15 @@ class JSONBodyValidator:
|
||||
)
|
||||
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# Based on https://github.com/encode/starlette/pull/1519#issuecomment-1060633787
|
||||
# Ingest all body messages from the ASGI `receive` callable.
|
||||
messages = []
|
||||
async def receive(self) -> t.Optional[t.MutableMapping[str, t.Any]]:
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await receive()
|
||||
messages.append(message)
|
||||
message = await self._receive()
|
||||
self._messages.append(message)
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# TODO: make json library pluggable
|
||||
bytes_body = b"".join([message.get("body", b"") for message in messages])
|
||||
bytes_body = b"".join([message.get("body", b"") for message in self._messages])
|
||||
decoded_body = bytes_body.decode(self.encoding)
|
||||
|
||||
if decoded_body and not (self.nullable and is_null(decoded_body)):
|
||||
@@ -77,11 +75,92 @@ class JSONBodyValidator:
|
||||
|
||||
self.validate(body)
|
||||
|
||||
async def wrapped_receive():
|
||||
# First up we want to return any messages we've stashed.
|
||||
if messages:
|
||||
return messages.pop(0)
|
||||
# Once that's done we can just await any other messages.
|
||||
return await receive()
|
||||
while self._messages:
|
||||
return self._messages.pop(0)
|
||||
return None
|
||||
|
||||
await self.next_app(scope, wrapped_receive, send)
|
||||
|
||||
class JSONResponseBodyValidator:
|
||||
"""Response body validator for json content types."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scope: Scope,
|
||||
send: Send,
|
||||
*,
|
||||
schema: dict,
|
||||
validator: t.Type[Draft4Validator] = None,
|
||||
nullable=False,
|
||||
encoding: str,
|
||||
) -> None:
|
||||
self._scope = scope
|
||||
self._send = send
|
||||
self.schema = schema
|
||||
self.has_default = schema.get("default", False)
|
||||
self.nullable = nullable
|
||||
validator_cls = validator or Draft4ResponseValidator
|
||||
self.validator = validator_cls(schema, format_checker=draft4_format_checker)
|
||||
self.encoding = encoding
|
||||
self._messages: t.List[t.MutableMapping[str, t.Any]] = []
|
||||
|
||||
@classmethod
|
||||
def _error_path_message(cls, exception):
|
||||
error_path = ".".join(str(item) for item in exception.path)
|
||||
error_path_msg = f" - '{error_path}'" if error_path else ""
|
||||
return error_path_msg
|
||||
|
||||
def validate(self, body: dict):
|
||||
try:
|
||||
self.validator.validate(body)
|
||||
except ValidationError as exception:
|
||||
error_path_msg = self._error_path_message(exception=exception)
|
||||
logger.error(
|
||||
f"Validation error: {exception.message}{error_path_msg}",
|
||||
extra={"validator": "body"},
|
||||
)
|
||||
raise NonConformingResponseBody(
|
||||
message=f"{exception.message}{error_path_msg}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse(body: str) -> dict:
|
||||
try:
|
||||
return json.loads(body)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
raise BadRequestProblem(str(e))
|
||||
|
||||
async def send(self, message: t.MutableMapping[str, t.Any]) -> None:
|
||||
self._messages.append(message)
|
||||
|
||||
if message["type"] == "http.response.start" or message.get("more_body", False):
|
||||
return
|
||||
|
||||
# TODO: make json library pluggable
|
||||
bytes_body = b"".join([message.get("body", b"") for message in self._messages])
|
||||
decoded_body = bytes_body.decode(self.encoding)
|
||||
|
||||
if decoded_body and not (self.nullable and is_null(decoded_body)):
|
||||
body = self.parse(decoded_body)
|
||||
self.validate(body)
|
||||
|
||||
while self._messages:
|
||||
await self._send(self._messages.pop(0))
|
||||
|
||||
|
||||
class TextResponseBodyValidator(JSONResponseBodyValidator):
|
||||
@staticmethod
|
||||
def parse(body: str) -> str: # type: ignore
|
||||
try:
|
||||
return json.loads(body)
|
||||
except json.decoder.JSONDecodeError:
|
||||
return body
|
||||
|
||||
|
||||
VALIDATOR_MAP = {
|
||||
"parameter": ParameterValidator,
|
||||
"body": {"application/json": JSONRequestBodyValidator},
|
||||
"response": {
|
||||
"application/json": JSONResponseBodyValidator,
|
||||
"text/plain": TextResponseBodyValidator,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ def test_header_not_returned(simple_openapi_app):
|
||||
assert data["title"] == "Response headers do not conform to specification"
|
||||
assert (
|
||||
data["detail"]
|
||||
== "Keys in header don't match response specification. Difference: Location"
|
||||
== "Keys in header don't match response specification. Difference: location"
|
||||
)
|
||||
assert data["status"] == 500
|
||||
|
||||
|
||||
@@ -80,59 +80,59 @@ def test_schema_response(schema_app):
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/object/valid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 200
|
||||
assert request.status_code == 200, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/object/invalid_type", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 500
|
||||
assert request.status_code == 500, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/object/invalid_requirements", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 500
|
||||
assert request.status_code == 500, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/string/valid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 200
|
||||
assert request.status_code == 200, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/string/invalid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 500
|
||||
assert request.status_code == 500, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/integer/valid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 200
|
||||
assert request.status_code == 200, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/integer/invalid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 500
|
||||
assert request.status_code == 500, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/number/valid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 200
|
||||
assert request.status_code == 200, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/number/invalid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 500
|
||||
assert request.status_code == 500, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/boolean/valid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 200
|
||||
assert request.status_code == 200, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/boolean/invalid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 500
|
||||
assert request.status_code == 500, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/array/valid", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 200
|
||||
assert request.status_code == 200, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/array/invalid_dict", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 500
|
||||
assert request.status_code == 500, request.text
|
||||
request = app_client.get(
|
||||
"/v1.0/test_schema/response/array/invalid_string", headers={}, data=None
|
||||
) # type: flask.Response
|
||||
assert request.status_code == 500
|
||||
assert request.status_code == 500, request.text
|
||||
|
||||
|
||||
def test_schema_in_query(schema_app):
|
||||
|
||||
@@ -6,19 +6,22 @@ class PetsView(MethodView):
|
||||
mycontent = "demonstrate return from MethodView class"
|
||||
|
||||
def get(self, **kwargs):
|
||||
kwargs.update({"method": "get"})
|
||||
if kwargs:
|
||||
kwargs.update({"name": "get"})
|
||||
return kwargs
|
||||
else:
|
||||
return [{"name": "get"}]
|
||||
|
||||
def search(self):
|
||||
return "search"
|
||||
return [{"name": "search"}]
|
||||
|
||||
def post(self, **kwargs):
|
||||
kwargs.update({"method": "post"})
|
||||
return kwargs
|
||||
kwargs.update({"name": "post"})
|
||||
return kwargs, 201
|
||||
|
||||
def put(self, *args, **kwargs):
|
||||
kwargs.update({"method": "put"})
|
||||
return kwargs
|
||||
kwargs.update({"name": "put"})
|
||||
return kwargs, 201
|
||||
|
||||
# Test that operation_id can still override resolver
|
||||
|
||||
|
||||
@@ -3,10 +3,9 @@ import pathlib
|
||||
|
||||
import pytest
|
||||
from connexion import App
|
||||
from connexion.decorators.validation import RequestBodyValidator
|
||||
from connexion.json_schema import Draft4RequestValidator
|
||||
from connexion.spec import Specification
|
||||
from connexion.validators import JSONBodyValidator
|
||||
from connexion.validators import JSONRequestBodyValidator
|
||||
from jsonschema.validators import _utils, extend
|
||||
|
||||
from conftest import build_app_from_fixture
|
||||
@@ -31,7 +30,7 @@ def test_validator_map(json_validation_spec_dir, spec):
|
||||
|
||||
MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type})
|
||||
|
||||
class MyJSONBodyValidator(JSONBodyValidator):
|
||||
class MyJSONBodyValidator(JSONRequestBodyValidator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, validator=MinLengthRequestValidator, **kwargs)
|
||||
|
||||
|
||||
@@ -192,13 +192,13 @@ def test_method_view_resolver_integration(method_view_app):
|
||||
client = method_view_app.app.test_client()
|
||||
|
||||
r = client.get("/v1.0/pets")
|
||||
assert r.json == {"method": "get"}
|
||||
assert r.json == [{"name": "get"}]
|
||||
|
||||
r = client.get("/v1.0/pets/1")
|
||||
assert r.json == {"method": "get", "petId": 1}
|
||||
assert r.json == {"name": "get", "petId": 1}
|
||||
|
||||
r = client.post("/v1.0/pets", json={"name": "Musti"})
|
||||
assert r.json == {"method": "post", "body": {"name": "Musti"}}
|
||||
assert r.json == {"name": "post", "body": {"name": "Musti"}}
|
||||
|
||||
r = client.put("/v1.0/pets/1", json={"name": "Igor"})
|
||||
assert r.json == {"method": "put", "petId": 1, "body": {"name": "Igor"}}
|
||||
assert r.json == {"name": "put", "petId": 1, "body": {"name": "Igor"}}
|
||||
|
||||
Reference in New Issue
Block a user