Move JSON response body validation to middleware (#1591)

* Extract boilerplate code into Routed base classes

* Use typing_extensions for Python 3.7 Protocol support

* Use Mock instead of AsyncMock

* Extract response validation to middleware

* Refactor Request validation to match Response validation

* Factor out shared functionality

* Fix typo in TextResponseBodyValidator class name

* Fix string formatting

* Use correct schema to check nullability in response validation
This commit is contained in:
Robbe Sneyders
2022-10-03 23:01:21 +02:00
committed by GitHub
parent 181c61bfb6
commit 1ab5400c0b
13 changed files with 358 additions and 282 deletions

View File

@@ -1,135 +0,0 @@
"""
This module defines a view function decorator to validate its responses.
"""
import asyncio
import functools
import logging
from jsonschema import ValidationError
from ..exceptions import NonConformingResponseBody, NonConformingResponseHeaders
from ..utils import all_json, has_coroutine
from .decorator import BaseDecorator
from .validation import ResponseBodyValidator
logger = logging.getLogger("connexion.decorators.response")
class ResponseValidator(BaseDecorator):
def __init__(self, operation, mimetype, validator=None):
"""
:type operation: Operation
:type mimetype: str
:param validator: Validator class that should be used to validate passed data
against API schema.
:type validator: jsonschema.IValidator
"""
self.operation = operation
self.mimetype = mimetype
self.validator = validator
def validate_response(self, data, status_code, headers, url):
"""
Validates the Response object based on what has been declared in the specification.
Ensures the response body matches the declared schema.
:type data: dict
:type status_code: int
:type headers: dict
:rtype bool | None
"""
# check against returned header, fall back to expected mimetype
content_type = headers.get("Content-Type", self.mimetype)
content_type = content_type.rsplit(";", 1)[
0
] # remove things like utf8 metadata
response_definition = self.operation.response_definition(
str(status_code), content_type
)
response_schema = self.operation.response_schema(str(status_code), content_type)
if self.is_json_schema_compatible(response_schema):
v = ResponseBodyValidator(response_schema, validator=self.validator)
try:
data = self.operation.json_loads(data)
v.validate_schema(data, url)
except ValidationError as e:
raise NonConformingResponseBody(message=str(e))
if response_definition and response_definition.get("headers"):
required_header_keys = {
k
for (k, v) in response_definition.get("headers").items()
if v.get("required", False)
}
header_keys = set(headers.keys())
missing_keys = required_header_keys - header_keys
if missing_keys:
pretty_list = ", ".join(missing_keys)
msg = (
"Keys in header don't match response specification. "
"Difference: {}"
).format(pretty_list)
raise NonConformingResponseHeaders(message=msg)
return True
def is_json_schema_compatible(self, response_schema: dict) -> bool:
"""
Verify if the specified operation responses are JSON schema
compatible.
All operations that specify a JSON schema and have content
type "application/json" or "text/plain" can be validated using
json_schema package.
"""
if not response_schema:
return False
return all_json([self.mimetype]) or self.mimetype == "text/plain"
def __call__(self, function):
"""
:type function: types.FunctionType
:rtype: types.FunctionType
"""
def _wrapper(request, response):
connexion_response = self.operation.api.get_connexion_response(
response, self.mimetype
)
if not connexion_response.is_streamed:
self.validate_response(
connexion_response.body,
connexion_response.status_code,
connexion_response.headers,
request.url,
)
else:
logger.warning("Skipping response validation for streamed response.")
return response
if has_coroutine(function):
@functools.wraps(function)
async def wrapper(request):
response = function(request)
while asyncio.iscoroutine(response):
response = await response
return _wrapper(request, response)
else: # pragma: no cover
@functools.wraps(function)
def wrapper(request):
response = function(request)
return _wrapper(request, response)
return wrapper
def __repr__(self):
"""
:rtype: str
"""
return "<ResponseValidator>" # pragma: no cover

View File

@@ -14,7 +14,7 @@ from werkzeug.datastructures import FileStorage
from ..exceptions import BadRequestProblem, ExtraParameterProblem from ..exceptions import BadRequestProblem, ExtraParameterProblem
from ..http_facts import FORM_CONTENT_TYPES from ..http_facts import FORM_CONTENT_TYPES
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator from ..json_schema import Draft4RequestValidator
from ..lifecycle import ConnexionResponse from ..lifecycle import ConnexionResponse
from ..utils import boolean, is_null, is_nullable from ..utils import boolean, is_null, is_nullable
@@ -196,29 +196,6 @@ class RequestBodyValidator:
return None return None
class ResponseBodyValidator:
def __init__(self, schema, validator=None):
"""
:param schema: The schema of the response body
:param validator: Validator class that should be used to validate passed data
against API schema. Default is Draft4ResponseValidator.
:type validator: jsonschema.IValidator
"""
ValidatorClass = validator or Draft4ResponseValidator
self.validator = ValidatorClass(schema, format_checker=draft4_format_checker)
def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]:
try:
self.validator.validate(data)
except ValidationError as exception:
logger.error(
f"{url} validation error: {exception}", extra={"validator": "response"}
)
raise exception
return None
class ParameterValidator: class ParameterValidator:
def __init__(self, parameters, api, strict_validation=False): def __init__(self, parameters, api, strict_validation=False):
""" """

View File

@@ -5,10 +5,11 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.middleware.abstract import AppMiddleware from connexion.middleware.abstract import AppMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.middleware.validation import ValidationMiddleware
class ConnexionMiddleware: class ConnexionMiddleware:
@@ -18,7 +19,8 @@ class ConnexionMiddleware:
SwaggerUIMiddleware, SwaggerUIMiddleware,
RoutingMiddleware, RoutingMiddleware,
SecurityMiddleware, SecurityMiddleware,
ValidationMiddleware, RequestValidationMiddleware,
ResponseValidationMiddleware,
] ]
def __init__( def __init__(

View File

@@ -6,71 +6,51 @@ import typing as t
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion import utils
from connexion.decorators.uri_parsing import AbstractURIParser from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import UnsupportedMediaTypeProblem from connexion.exceptions import UnsupportedMediaTypeProblem
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.operations import AbstractOperation from connexion.operations import AbstractOperation
from connexion.utils import is_nullable from connexion.validators import VALIDATOR_MAP
from connexion.validators import JSONBodyValidator
from ..decorators.response import ResponseValidator
from ..decorators.validation import ParameterValidator
logger = logging.getLogger("connexion.middleware.validation") logger = logging.getLogger("connexion.middleware.validation")
VALIDATOR_MAP = {
"parameter": ParameterValidator,
"body": {"application/json": JSONBodyValidator},
"response": ResponseValidator,
}
class RequestValidationOperation:
class ValidationOperation:
def __init__( def __init__(
self, self,
next_app: ASGIApp, next_app: ASGIApp,
*, *,
operation: AbstractOperation, operation: AbstractOperation,
validate_responses: bool = False,
strict_validation: bool = False, strict_validation: bool = False,
validator_map: t.Optional[dict] = None, validator_map: t.Optional[dict] = None,
uri_parser_class: t.Optional[AbstractURIParser] = None, uri_parser_class: t.Optional[AbstractURIParser] = None,
) -> None: ) -> None:
self.next_app = next_app self.next_app = next_app
self._operation = operation self._operation = operation
self.validate_responses = validate_responses
self.strict_validation = strict_validation self.strict_validation = strict_validation
self._validator_map = VALIDATOR_MAP self._validator_map = VALIDATOR_MAP
self._validator_map.update(validator_map or {}) self._validator_map.update(validator_map or {})
self.uri_parser_class = uri_parser_class self.uri_parser_class = uri_parser_class
def extract_content_type(self, headers: dict) -> t.Tuple[str, str]: def extract_content_type(
self, headers: t.List[t.Tuple[bytes, bytes]]
) -> t.Tuple[str, str]:
"""Extract the mime type and encoding from the content type headers. """Extract the mime type and encoding from the content type headers.
:param headers: Header dict from ASGI scope :param headers: Headers from ASGI scope
:return: A tuple of mime type, encoding :return: A tuple of mime type, encoding
""" """
encoding = "utf-8" mime_type, encoding = utils.extract_content_type(headers)
for key, value in headers: if mime_type is None:
# Headers can always be decoded using latin-1:
# https://stackoverflow.com/a/27357138/4098821
key = key.decode("latin-1")
if key.lower() == "content-type":
content_type = value.decode("latin-1")
if ";" in content_type:
mime_type, parameters = content_type.split(";", maxsplit=1)
prefix = "charset="
for parameter in parameters.split(";"):
if parameter.startswith(prefix):
encoding = parameter[len(prefix) :]
else:
mime_type = content_type
break
else:
# Content-type header is not required. Take a best guess. # Content-type header is not required. Take a best guess.
mime_type = self._operation.consumes[0] try:
mime_type = self._operation.consumes[0]
except IndexError:
mime_type = "application/octet-stream"
if encoding is None:
encoding = "utf-8"
return mime_type, encoding return mime_type, encoding
@@ -86,6 +66,8 @@ class ValidationOperation:
) )
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
receive_fn = receive
headers = scope["headers"] headers = scope["headers"]
mime_type, encoding = self.extract_content_type(headers) mime_type, encoding = self.extract_content_type(headers)
self.validate_mime_type(mime_type) self.validate_mime_type(mime_type)
@@ -102,25 +84,25 @@ class ValidationOperation:
) )
else: else:
validator = body_validator( validator = body_validator(
self.next_app, scope,
receive,
schema=self._operation.body_schema, schema=self._operation.body_schema,
nullable=is_nullable(self._operation.body_definition), nullable=utils.is_nullable(self._operation.body_definition),
encoding=encoding, encoding=encoding,
) )
return await validator(scope, receive, send) receive_fn = validator.receive
await self.next_app(scope, receive, send) await self.next_app(scope, receive_fn, send)
class ValidationAPI(RoutedAPI[ValidationOperation]): class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
"""Validation API.""" """Validation API."""
operation_cls = ValidationOperation operation_cls = RequestValidationOperation
def __init__( def __init__(
self, self,
*args, *args,
validate_responses=False,
strict_validation=False, strict_validation=False,
validator_map=None, validator_map=None,
uri_parser_class=None, uri_parser_class=None,
@@ -129,9 +111,6 @@ class ValidationAPI(RoutedAPI[ValidationOperation]):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.validator_map = validator_map self.validator_map = validator_map
logger.debug("Validate Responses: %s", str(validate_responses))
self.validate_responses = validate_responses
logger.debug("Strict Request Validation: %s", str(strict_validation)) logger.debug("Strict Request Validation: %s", str(strict_validation))
self.strict_validation = strict_validation self.strict_validation = strict_validation
@@ -139,21 +118,22 @@ class ValidationAPI(RoutedAPI[ValidationOperation]):
self.add_paths() self.add_paths()
def make_operation(self, operation: AbstractOperation) -> ValidationOperation: def make_operation(
return ValidationOperation( self, operation: AbstractOperation
) -> RequestValidationOperation:
return RequestValidationOperation(
self.next_app, self.next_app,
operation=operation, operation=operation,
validate_responses=self.validate_responses,
strict_validation=self.strict_validation, strict_validation=self.strict_validation,
validator_map=self.validator_map, validator_map=self.validator_map,
uri_parser_class=self.uri_parser_class, uri_parser_class=self.uri_parser_class,
) )
class ValidationMiddleware(RoutedMiddleware[ValidationAPI]): class RequestValidationMiddleware(RoutedMiddleware[RequestValidationAPI]):
"""Middleware for validating requests according to the API contract.""" """Middleware for validating requests according to the API contract."""
api_cls = ValidationAPI api_cls = RequestValidationAPI
class MissingValidationOperation(Exception): class MissingValidationOperation(Exception):

View File

@@ -0,0 +1,158 @@
"""
Validation Middleware.
"""
import logging
import typing as t
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion import utils
from connexion.exceptions import NonConformingResponseHeaders
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.operations import AbstractOperation
from connexion.validators import VALIDATOR_MAP
logger = logging.getLogger("connexion.middleware.validation")
class ResponseValidationOperation:
def __init__(
self,
next_app: ASGIApp,
*,
operation: AbstractOperation,
validator_map: t.Optional[dict] = None,
) -> None:
self.next_app = next_app
self._operation = operation
self._validator_map = VALIDATOR_MAP
self._validator_map.update(validator_map or {})
def extract_content_type(
self, headers: t.List[t.Tuple[bytes, bytes]]
) -> t.Tuple[str, str]:
"""Extract the mime type and encoding from the content type headers.
:param headers: Headers from ASGI scope
:return: A tuple of mime type, encoding
"""
mime_type, encoding = utils.extract_content_type(headers)
if mime_type is None:
# Content-type header is not required. Take a best guess.
try:
mime_type = self._operation.produces[0]
except IndexError:
mime_type = "application/octet-stream"
if encoding is None:
encoding = "utf-8"
return mime_type, encoding
def validate_mime_type(self, mime_type: str) -> None:
"""Validate the mime type against the spec.
:param mime_type: mime type from content type header
"""
if mime_type.lower() not in [c.lower() for c in self._operation.produces]:
raise NonConformingResponseHeaders(
reason="Invalid Response Content-type",
message=f"Invalid Response Content-type ({mime_type}), "
f"expected {self._operation.produces}",
)
@staticmethod
def validate_required_headers(
headers: t.List[tuple], response_definition: dict
) -> None:
required_header_keys = {
k.lower()
for (k, v) in response_definition.get("headers", {}).items()
if v.get("required", False)
}
header_keys = set(header[0].decode("latin-1").lower() for header in headers)
missing_keys = required_header_keys - header_keys
if missing_keys:
pretty_list = ", ".join(missing_keys)
msg = (
"Keys in header don't match response specification. Difference: {}"
).format(pretty_list)
raise NonConformingResponseHeaders(message=msg)
async def __call__(self, scope: Scope, receive: Receive, send: Send):
send_fn = send
async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
nonlocal send_fn
if message["type"] == "http.response.start":
status = str(message["status"])
headers = message["headers"]
mime_type, encoding = self.extract_content_type(headers)
# TODO: Add produces to all tests and fix response content types
# self.validate_mime_type(mime_type)
response_definition = self._operation.response_definition(
status, mime_type
)
self.validate_required_headers(headers, response_definition)
# Validate body
try:
body_validator = self._validator_map["response"][mime_type] # type: ignore
except KeyError:
logging.info(
f"Skipping validation. No validator registered for content type: "
f"{mime_type}."
)
else:
validator = body_validator(
scope,
send,
schema=self._operation.response_schema(status, mime_type),
nullable=utils.is_nullable(
self._operation.response_definition(status, mime_type)
),
encoding=encoding,
)
send_fn = validator.send
return await send_fn(message)
await self.next_app(scope, receive, wrapped_send)
class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]):
"""Validation API."""
operation_cls = ResponseValidationOperation
def __init__(
self,
*args,
validator_map=None,
validate_responses=False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.validator_map = validator_map
self.validate_responses = validate_responses
self.add_paths()
def make_operation(
self, operation: AbstractOperation
) -> ResponseValidationOperation:
if self.validate_responses:
return ResponseValidationOperation(
self.next_app,
operation=operation,
validator_map=self.validator_map,
)
else:
return self.next_app # type: ignore
class ResponseValidationMiddleware(RoutedMiddleware[ResponseValidationAPI]):
"""Middleware for validating requests according to the API contract."""
api_cls = ResponseValidationAPI

View File

@@ -9,7 +9,6 @@ import logging
from ..decorators.decorator import RequestResponseDecorator from ..decorators.decorator import RequestResponseDecorator
from ..decorators.parameter import parameter_to_arg from ..decorators.parameter import parameter_to_arg
from ..decorators.produces import BaseSerializer, Produces from ..decorators.produces import BaseSerializer, Produces
from ..decorators.response import ResponseValidator
from ..decorators.validation import ParameterValidator, RequestBodyValidator from ..decorators.validation import ParameterValidator, RequestBodyValidator
from ..utils import all_json, is_nullable from ..utils import all_json, is_nullable
@@ -20,7 +19,6 @@ DEFAULT_MIMETYPE = "application/json"
VALIDATOR_MAP = { VALIDATOR_MAP = {
"parameter": ParameterValidator, "parameter": ParameterValidator,
"body": RequestBodyValidator, "body": RequestBodyValidator,
"response": ResponseValidator,
} }
@@ -389,12 +387,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
self.pythonic_params, self.pythonic_params,
) )
if self.validate_responses:
logger.debug("... Response validation enabled.")
response_decorator = self.__response_validation_decorator
logger.debug("... Adding response decorator (%r)", response_decorator)
function = response_decorator(function)
produces_decorator = self.__content_type_decorator produces_decorator = self.__content_type_decorator
logger.debug("... Adding produces decorator (%r)", produces_decorator) logger.debug("... Adding produces decorator (%r)", produces_decorator)
function = produces_decorator(function) function = produces_decorator(function)
@@ -473,15 +465,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
strict_validation=self.strict_validation, strict_validation=self.strict_validation,
) )
@property
def __response_validation_decorator(self):
"""
Get a decorator for validating the generated Response.
:rtype: types.FunctionType
"""
ResponseValidator = self.validator_map["response"]
return ResponseValidator(self, self.get_mimetype())
def json_loads(self, data): def json_loads(self, data):
""" """
A wrapper for calling the API specific JSON loader. A wrapper for calling the API specific JSON loader.

View File

@@ -5,6 +5,7 @@ This module provides general utility functions used within Connexion.
import asyncio import asyncio
import functools import functools
import importlib import importlib
import typing as t
import yaml import yaml
@@ -266,3 +267,32 @@ def not_installed_error(exc): # pragma: no cover
raise exc raise exc
return functools.partial(_required_lib, exc) return functools.partial(_required_lib, exc)
def extract_content_type(
headers: t.List[t.Tuple[bytes, bytes]]
) -> t.Tuple[t.Optional[str], t.Optional[str]]:
"""Extract the mime type and encoding from the content type headers.
:param headers: Headers from ASGI scope
:return: A tuple of mime type, encoding
"""
mime_type, encoding = None, None
for key, value in headers:
# Headers can always be decoded using latin-1:
# https://stackoverflow.com/a/27357138/4098821
decoded_key = key.decode("latin-1")
if decoded_key.lower() == "content-type":
content_type = value.decode("latin-1")
if ";" in content_type:
mime_type, parameters = content_type.split(";", maxsplit=1)
prefix = "charset="
for parameter in parameters.split(";"):
if parameter.startswith(prefix):
encoding = parameter[len(prefix) :]
else:
mime_type = content_type
break
return mime_type, encoding

View File

@@ -6,36 +6,38 @@ import logging
import typing as t import typing as t
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import Receive, Scope, Send
from connexion.exceptions import BadRequestProblem from connexion.decorators.validation import ParameterValidator
from connexion.json_schema import Draft4RequestValidator from connexion.exceptions import BadRequestProblem, NonConformingResponseBody
from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator
from connexion.utils import is_null from connexion.utils import is_null
logger = logging.getLogger("connexion.middleware.validators") logger = logging.getLogger("connexion.middleware.validators")
class JSONBodyValidator: class JSONRequestBodyValidator:
"""Request body validator for json content types.""" """Request body validator for json content types."""
def __init__( def __init__(
self, self,
next_app: ASGIApp, scope: Scope,
receive: Receive,
*, *,
schema: dict, schema: dict,
validator: t.Type[Draft4Validator] = None, validator: t.Type[Draft4Validator] = None,
nullable=False, nullable=False,
encoding: str, encoding: str,
) -> None: ) -> None:
self.next_app = next_app self._scope = scope
self._receive = receive
self.schema = schema self.schema = schema
self.has_default = schema.get("default", False) self.has_default = schema.get("default", False)
self.nullable = nullable self.nullable = nullable
self.validator_cls = validator or Draft4RequestValidator validator_cls = validator or Draft4RequestValidator
self.validator = self.validator_cls( self.validator = validator_cls(schema, format_checker=draft4_format_checker)
schema, format_checker=draft4_format_checker
)
self.encoding = encoding self.encoding = encoding
self._messages: t.List[t.MutableMapping[str, t.Any]] = []
@classmethod @classmethod
def _error_path_message(cls, exception): def _error_path_message(cls, exception):
@@ -44,7 +46,6 @@ class JSONBodyValidator:
return error_path_msg return error_path_msg
def validate(self, body: dict): def validate(self, body: dict):
try: try:
self.validator.validate(body) self.validator.validate(body)
except ValidationError as exception: except ValidationError as exception:
@@ -55,18 +56,15 @@ class JSONBodyValidator:
) )
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def receive(self) -> t.Optional[t.MutableMapping[str, t.Any]]:
# Based on https://github.com/encode/starlette/pull/1519#issuecomment-1060633787
# Ingest all body messages from the ASGI `receive` callable.
messages = []
more_body = True more_body = True
while more_body: while more_body:
message = await receive() message = await self._receive()
messages.append(message) self._messages.append(message)
more_body = message.get("more_body", False) more_body = message.get("more_body", False)
# TODO: make json library pluggable # TODO: make json library pluggable
bytes_body = b"".join([message.get("body", b"") for message in messages]) bytes_body = b"".join([message.get("body", b"") for message in self._messages])
decoded_body = bytes_body.decode(self.encoding) decoded_body = bytes_body.decode(self.encoding)
if decoded_body and not (self.nullable and is_null(decoded_body)): if decoded_body and not (self.nullable and is_null(decoded_body)):
@@ -77,11 +75,92 @@ class JSONBodyValidator:
self.validate(body) self.validate(body)
async def wrapped_receive(): while self._messages:
# First up we want to return any messages we've stashed. return self._messages.pop(0)
if messages: return None
return messages.pop(0)
# Once that's done we can just await any other messages.
return await receive()
await self.next_app(scope, wrapped_receive, send)
class JSONResponseBodyValidator:
"""Response body validator for json content types."""
def __init__(
self,
scope: Scope,
send: Send,
*,
schema: dict,
validator: t.Type[Draft4Validator] = None,
nullable=False,
encoding: str,
) -> None:
self._scope = scope
self._send = send
self.schema = schema
self.has_default = schema.get("default", False)
self.nullable = nullable
validator_cls = validator or Draft4ResponseValidator
self.validator = validator_cls(schema, format_checker=draft4_format_checker)
self.encoding = encoding
self._messages: t.List[t.MutableMapping[str, t.Any]] = []
@classmethod
def _error_path_message(cls, exception):
error_path = ".".join(str(item) for item in exception.path)
error_path_msg = f" - '{error_path}'" if error_path else ""
return error_path_msg
def validate(self, body: dict):
try:
self.validator.validate(body)
except ValidationError as exception:
error_path_msg = self._error_path_message(exception=exception)
logger.error(
f"Validation error: {exception.message}{error_path_msg}",
extra={"validator": "body"},
)
raise NonConformingResponseBody(
message=f"{exception.message}{error_path_msg}"
)
@staticmethod
def parse(body: str) -> dict:
try:
return json.loads(body)
except json.decoder.JSONDecodeError as e:
raise BadRequestProblem(str(e))
async def send(self, message: t.MutableMapping[str, t.Any]) -> None:
self._messages.append(message)
if message["type"] == "http.response.start" or message.get("more_body", False):
return
# TODO: make json library pluggable
bytes_body = b"".join([message.get("body", b"") for message in self._messages])
decoded_body = bytes_body.decode(self.encoding)
if decoded_body and not (self.nullable and is_null(decoded_body)):
body = self.parse(decoded_body)
self.validate(body)
while self._messages:
await self._send(self._messages.pop(0))
class TextResponseBodyValidator(JSONResponseBodyValidator):
@staticmethod
def parse(body: str) -> str: # type: ignore
try:
return json.loads(body)
except json.decoder.JSONDecodeError:
return body
VALIDATOR_MAP = {
"parameter": ParameterValidator,
"body": {"application/json": JSONRequestBodyValidator},
"response": {
"application/json": JSONResponseBodyValidator,
"text/plain": TextResponseBodyValidator,
},
}

View File

@@ -34,7 +34,7 @@ def test_header_not_returned(simple_openapi_app):
assert data["title"] == "Response headers do not conform to specification" assert data["title"] == "Response headers do not conform to specification"
assert ( assert (
data["detail"] data["detail"]
== "Keys in header don't match response specification. Difference: Location" == "Keys in header don't match response specification. Difference: location"
) )
assert data["status"] == 500 assert data["status"] == 500

View File

@@ -80,59 +80,59 @@ def test_schema_response(schema_app):
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/object/valid", headers={}, data=None "/v1.0/test_schema/response/object/valid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 200 assert request.status_code == 200, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/object/invalid_type", headers={}, data=None "/v1.0/test_schema/response/object/invalid_type", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 500 assert request.status_code == 500, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/object/invalid_requirements", headers={}, data=None "/v1.0/test_schema/response/object/invalid_requirements", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 500 assert request.status_code == 500, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/string/valid", headers={}, data=None "/v1.0/test_schema/response/string/valid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 200 assert request.status_code == 200, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/string/invalid", headers={}, data=None "/v1.0/test_schema/response/string/invalid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 500 assert request.status_code == 500, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/integer/valid", headers={}, data=None "/v1.0/test_schema/response/integer/valid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 200 assert request.status_code == 200, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/integer/invalid", headers={}, data=None "/v1.0/test_schema/response/integer/invalid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 500 assert request.status_code == 500, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/number/valid", headers={}, data=None "/v1.0/test_schema/response/number/valid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 200 assert request.status_code == 200, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/number/invalid", headers={}, data=None "/v1.0/test_schema/response/number/invalid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 500 assert request.status_code == 500, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/boolean/valid", headers={}, data=None "/v1.0/test_schema/response/boolean/valid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 200 assert request.status_code == 200, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/boolean/invalid", headers={}, data=None "/v1.0/test_schema/response/boolean/invalid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 500 assert request.status_code == 500, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/array/valid", headers={}, data=None "/v1.0/test_schema/response/array/valid", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 200 assert request.status_code == 200, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/array/invalid_dict", headers={}, data=None "/v1.0/test_schema/response/array/invalid_dict", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 500 assert request.status_code == 500, request.text
request = app_client.get( request = app_client.get(
"/v1.0/test_schema/response/array/invalid_string", headers={}, data=None "/v1.0/test_schema/response/array/invalid_string", headers={}, data=None
) # type: flask.Response ) # type: flask.Response
assert request.status_code == 500 assert request.status_code == 500, request.text
def test_schema_in_query(schema_app): def test_schema_in_query(schema_app):

View File

@@ -6,19 +6,22 @@ class PetsView(MethodView):
mycontent = "demonstrate return from MethodView class" mycontent = "demonstrate return from MethodView class"
def get(self, **kwargs): def get(self, **kwargs):
kwargs.update({"method": "get"}) if kwargs:
return kwargs kwargs.update({"name": "get"})
return kwargs
else:
return [{"name": "get"}]
def search(self): def search(self):
return "search" return [{"name": "search"}]
def post(self, **kwargs): def post(self, **kwargs):
kwargs.update({"method": "post"}) kwargs.update({"name": "post"})
return kwargs return kwargs, 201
def put(self, *args, **kwargs): def put(self, *args, **kwargs):
kwargs.update({"method": "put"}) kwargs.update({"name": "put"})
return kwargs return kwargs, 201
# Test that operation_id can still override resolver # Test that operation_id can still override resolver

View File

@@ -3,10 +3,9 @@ import pathlib
import pytest import pytest
from connexion import App from connexion import App
from connexion.decorators.validation import RequestBodyValidator
from connexion.json_schema import Draft4RequestValidator from connexion.json_schema import Draft4RequestValidator
from connexion.spec import Specification from connexion.spec import Specification
from connexion.validators import JSONBodyValidator from connexion.validators import JSONRequestBodyValidator
from jsonschema.validators import _utils, extend from jsonschema.validators import _utils, extend
from conftest import build_app_from_fixture from conftest import build_app_from_fixture
@@ -31,7 +30,7 @@ def test_validator_map(json_validation_spec_dir, spec):
MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type}) MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type})
class MyJSONBodyValidator(JSONBodyValidator): class MyJSONBodyValidator(JSONRequestBodyValidator):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, validator=MinLengthRequestValidator, **kwargs) super().__init__(*args, validator=MinLengthRequestValidator, **kwargs)

View File

@@ -192,13 +192,13 @@ def test_method_view_resolver_integration(method_view_app):
client = method_view_app.app.test_client() client = method_view_app.app.test_client()
r = client.get("/v1.0/pets") r = client.get("/v1.0/pets")
assert r.json == {"method": "get"} assert r.json == [{"name": "get"}]
r = client.get("/v1.0/pets/1") r = client.get("/v1.0/pets/1")
assert r.json == {"method": "get", "petId": 1} assert r.json == {"name": "get", "petId": 1}
r = client.post("/v1.0/pets", json={"name": "Musti"}) r = client.post("/v1.0/pets", json={"name": "Musti"})
assert r.json == {"method": "post", "body": {"name": "Musti"}} assert r.json == {"name": "post", "body": {"name": "Musti"}}
r = client.put("/v1.0/pets/1", json={"name": "Igor"}) r = client.put("/v1.0/pets/1", json={"name": "Igor"})
assert r.json == {"method": "put", "petId": 1, "body": {"name": "Igor"}} assert r.json == {"name": "put", "petId": 1, "body": {"name": "Igor"}}