Create AbstractRequestBodyValidator class

This commit is contained in:
Robbe Sneyders
2023-02-25 00:59:07 +01:00
parent 969c1460e6
commit 3e733df181
11 changed files with 264 additions and 226 deletions

View File

@@ -107,6 +107,13 @@ def resolve_refs(spec, store=None, base_uri=""):
return res return res
def format_error_with_path(exception: ValidationError) -> str:
    """Build the location suffix appended to a validation error message.

    The elements of ``exception.path`` are converted to strings and joined
    with dots.

    :param exception: Validation error whose ``path`` is rendered.
    :return: `` - '<dotted path>'`` when the joined path is non-empty,
        otherwise an empty string.
    """
    dotted = ".".join(map(str, exception.path))
    if not dotted:
        # Nothing to point at: keep the message free of a dangling suffix.
        return ""
    return f" - '{dotted}'"
def allow_nullable(validation_fn: t.Callable) -> t.Callable: def allow_nullable(validation_fn: t.Callable) -> t.Callable:
"""Extend an existing validation function, so it allows nullable values to be null.""" """Extend an existing validation function, so it allows nullable values to be null."""

View File

@@ -65,5 +65,5 @@ class ExceptionMiddleware(StarletteExceptionMiddleware):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# Needs to be set so starlette router throws exceptions instead of returning error responses # Needs to be set so starlette router throws exceptions instead of returning error responses
scope["app"] = self scope["app"] = "connexion"
await super().__call__(scope, receive, send) await super().__call__(scope, receive, send)

View File

@@ -68,8 +68,6 @@ class RequestValidationOperation:
) )
async def __call__(self, scope: Scope, receive: Receive, send: Send): async def __call__(self, scope: Scope, receive: Receive, send: Send):
receive_fn = receive
# Validate parameters & headers # Validate parameters & headers
uri_parser_class = self._operation._uri_parser_class uri_parser_class = self._operation._uri_parser_class
uri_parser = uri_parser_class( uri_parser = uri_parser_class(
@@ -100,8 +98,6 @@ class RequestValidationOperation:
) )
else: else:
validator = body_validator( validator = body_validator(
scope,
receive,
schema=schema, schema=schema,
required=self._operation.request_body.get("required", False), required=self._operation.request_body.get("required", False),
nullable=utils.is_nullable( nullable=utils.is_nullable(
@@ -113,9 +109,9 @@ class RequestValidationOperation:
self._operation.parameters, self._operation.body_definition() self._operation.parameters, self._operation.body_definition()
), ),
) )
receive_fn = await validator.wrapped_receive() receive = await validator.wrap_receive(receive, scope=scope)
await self.next_app(scope, receive_fn, send) await self.next_app(scope, receive, send)
class RequestValidationAPI(RoutedAPI[RequestValidationOperation]): class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):

View File

@@ -291,6 +291,7 @@ class Swagger2Operation(AbstractOperation):
default = param.get("default") default = param.get("default")
if default is not None: if default is not None:
prop["default"] = default
defaults[param["name"]] = default defaults[param["name"]] = default
nullable = param.get("x-nullable") nullable = param.get("x-nullable")
@@ -320,11 +321,11 @@ class Swagger2Operation(AbstractOperation):
"schema": { "schema": {
"type": "object", "type": "object",
"properties": properties, "properties": properties,
"default": defaults,
"required": required, "required": required,
} }
} }
if defaults:
definition["schema"]["default"] = defaults
if encoding: if encoding:
definition["encoding"] = encoding definition["encoding"] = encoding

View File

@@ -1,5 +1,6 @@
from connexion.datastructures import MediaTypeDict from connexion.datastructures import MediaTypeDict
from .abstract import AbstractRequestBodyValidator # NOQA
from .form_data import FormDataValidator, MultiPartFormDataValidator from .form_data import FormDataValidator, MultiPartFormDataValidator
from .json import DefaultsJSONRequestBodyValidator # NOQA from .json import DefaultsJSONRequestBodyValidator # NOQA
from .json import ( from .json import (

View File

@@ -0,0 +1,151 @@
"""
This module defines a Validator interface with base functionality that can be subclassed
for custom validators provided to the RequestValidationMiddleware.
"""
import copy
import json
import typing as t
from starlette.datastructures import Headers, MutableHeaders
from starlette.types import Receive, Scope
from connexion.exceptions import BadRequestProblem
from connexion.utils import is_null
class AbstractRequestBodyValidator:
    """
    Validator interface with base functionality that can be subclassed for custom validators.

    Subclasses implement `_parse` (turn the incoming byte stream into a body) and
    `_validate` (check the parsed body); `wrap_receive` ties both together and
    returns a `receive` channel that replays the body for the next ASGI application.

    .. note:: Validators load the whole body into memory, which can be a problem for large payloads.
    """

    MUTABLE_VALIDATION = False
    """
    Whether mutations to the body during validation should be transmitted via the receive channel.
    Note that this does not apply to the substitution of a missing body with the default body, which always
    updates the receive channel.
    """
    MAX_MESSAGE_LENGTH = 256000
    """Maximum message length that will be sent via the receive channel for mutated bodies."""

    def __init__(
        self,
        *,
        schema: dict,
        required: bool = False,
        nullable: bool = False,
        encoding: str,
        strict_validation: bool,
        **kwargs,
    ):
        """
        :param schema: Schema of operation to validate
        :param required: Whether RequestBody is required
        :param nullable: Whether RequestBody is nullable
        :param encoding: Encoding of body (passed via Content-Type header)
        :param strict_validation: Whether to allow parameters not defined in the spec
        :param kwargs: Additional arguments for subclasses
        """
        self._schema = schema
        self._nullable = nullable
        self._required = required
        self._encoding = encoding
        self._strict_validation = strict_validation

    async def _parse(
        self, stream: t.AsyncGenerator[bytes, None], scope: Scope
    ) -> t.Any:
        """Parse the incoming stream.

        To be implemented by subclasses; the base implementation returns None.
        """

    def _validate(self, body: t.Any) -> t.Optional[dict]:
        """
        Validate the parsed body.

        To be implemented by subclasses; the base implementation accepts any body.

        :raises: :class:`connexion.exceptions.BadRequestProblem`
        """

    def _insert_body(self, receive: Receive, *, body: t.Any, scope: Scope) -> Receive:
        """
        Insert messages transmitting the body at the start of the `receive` channel.

        This method updates the provided `scope` in place with the right `Content-Length` header.

        :param receive: Channel to fall back to once the inserted messages are consumed
        :param body: Body to serialize as JSON; when None, nothing is inserted
        :param scope: ASGI scope whose headers are updated in place
        :return: The wrapped `receive` channel (or `receive` itself when body is None)
        """
        if body is None:
            return receive

        bytes_body = json.dumps(body).encode(self._encoding)

        # Update the content-length header on the scope that was passed in. The caller
        # forwards this same scope downstream, so writing to a copy would be lost.
        headers = MutableHeaders(scope=scope)
        headers["content-length"] = str(len(bytes_body))

        # Split the serialized body into ASGI messages of at most MAX_MESSAGE_LENGTH bytes
        chunk = self.MAX_MESSAGE_LENGTH
        total = len(bytes_body)
        messages = [
            {
                "type": "http.request",
                "body": bytes_body[start : start + chunk],
                "more_body": start + chunk < total,
            }
            for start in range(0, total, chunk)
        ]

        return self._insert_messages(receive, messages=messages)

    @staticmethod
    def _insert_messages(
        receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
    ) -> Receive:
        """Insert messages at the start of the `receive` channel.

        Each message is handed out exactly once, in order; afterwards calls are
        delegated to the original `receive`.
        """
        # Bind a single iterator up front. Looping over `messages` inside `receive_`
        # would restart from the first element on every call when it is a list.
        pending = iter(messages)

        async def receive_() -> t.MutableMapping[str, t.Any]:
            try:
                return next(pending)
            except StopIteration:
                return await receive()

        return receive_

    async def wrap_receive(self, receive: Receive, *, scope: Scope) -> Receive:
        """
        Wrap the provided `receive` channel with request body validation.

        This method updates the provided `scope` in place with the right `Content-Length` header.

        :param receive: Incoming ASGI `receive` channel
        :param scope: ASGI scope of the request
        :return: A `receive` channel that replays the (possibly mutated) body
        :raises: :class:`connexion.exceptions.BadRequestProblem` when a required body is
            missing and the schema defines no default
        """
        # Handle missing bodies
        headers = Headers(scope=scope)
        if not int(headers.get("content-length", 0)):
            body = self._schema.get("default")
            if body is None and self._required:
                raise BadRequestProblem("RequestBody is required")
            # The default body is encoded as a `receive` channel to mimic an incoming body
            receive = self._insert_body(receive, body=body, scope=scope)

        # The receive channel is converted to a stream for convenient access.
        # Consumed messages are recorded so they can be replayed afterwards.
        messages: t.List[t.MutableMapping[str, t.Any]] = []

        async def stream() -> t.AsyncGenerator[bytes, None]:
            more_body = True
            while more_body:
                message = await receive()
                messages.append(message)
                more_body = message.get("more_body", False)
                yield message.get("body", b"")
            yield b""

        # The body is parsed and validated
        body = await self._parse(stream(), scope=scope)
        if not (body is None and self._nullable):
            self._validate(body)

        if self.MUTABLE_VALIDATION and body is not None:
            # Include changes made during validation
            receive = self._insert_body(receive, body=body, scope=scope)
        else:
            # Replay the original messages. This also covers a None body under
            # MUTABLE_VALIDATION, where there is nothing to re-serialize but the
            # consumed messages must still reach the next application.
            receive = self._insert_messages(receive, messages=messages)

        return receive

View File

@@ -1,85 +1,57 @@
import logging import logging
import typing as t import typing as t
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker from jsonschema import ValidationError, draft4_format_checker
from starlette.datastructures import FormData, Headers, UploadFile from starlette.datastructures import Headers, UploadFile
from starlette.formparsers import FormParser, MultiPartParser from starlette.formparsers import FormParser, MultiPartParser
from starlette.types import Receive, Scope from starlette.types import Scope
from connexion.exceptions import BadRequestProblem, ExtraParameterProblem from connexion.exceptions import BadRequestProblem, ExtraParameterProblem
from connexion.json_schema import Draft4RequestValidator from connexion.json_schema import Draft4RequestValidator, format_error_with_path
from connexion.uri_parsing import AbstractURIParser from connexion.uri_parsing import AbstractURIParser
from connexion.utils import is_null from connexion.validators import AbstractRequestBodyValidator
logger = logging.getLogger("connexion.validators.form_data") logger = logging.getLogger("connexion.validators.form_data")
class FormDataValidator: class FormDataValidator(AbstractRequestBodyValidator):
"""Request body validator for form content types.""" """Request body validator for form content types."""
def __init__( def __init__(
self, self,
scope: Scope,
receive: Receive,
*, *,
schema: dict, schema: dict,
validator: t.Type[Draft4Validator] = None,
required=False, required=False,
nullable=False, nullable=False,
encoding: str, encoding: str,
uri_parser: t.Optional[AbstractURIParser] = None,
strict_validation: bool, strict_validation: bool,
uri_parser: t.Optional[AbstractURIParser] = None,
) -> None: ) -> None:
self._scope = scope super().__init__(
self._receive = receive schema=schema,
self.schema = schema required=required,
self.has_default = schema.get("default", False) nullable=nullable,
self.nullable = nullable encoding=encoding,
self.required = required strict_validation=strict_validation,
validator_cls = validator or Draft4RequestValidator )
self.validator = validator_cls(schema, format_checker=draft4_format_checker) self._uri_parser = uri_parser
self.uri_parser = uri_parser
self.encoding = encoding
self._messages: t.List[t.MutableMapping[str, t.Any]] = []
self.headers = Headers(scope=scope)
self.strict_validation = strict_validation
self.check_empty()
@property @property
def form_parser_cls(self): def _validator(self):
return Draft4RequestValidator(
self._schema, format_checker=draft4_format_checker
)
@property
def _form_parser_cls(self):
return FormParser return FormParser
def check_empty(self): async def _parse(self, stream: t.AsyncGenerator[bytes, None], scope: Scope) -> dict:
"""`receive` is never called if body is empty, so we need to check this case at headers = Headers(scope=scope)
initialization.""" form_parser = self._form_parser_cls(headers, stream)
if not int(self.headers.get("content-length", 0)): data = await form_parser.parse()
# TODO: default should be passed along and content-length updated
if self.schema.get("default"):
self.validate(self.schema.get("default"))
elif self.required: # RequestBody itself is required
raise BadRequestProblem("RequestBody is required")
elif self.schema.get("required", []): # Required top level properties
self._validate({})
@classmethod if self._uri_parser is not None:
def _error_path_message(cls, exception):
error_path = ".".join(str(item) for item in exception.path)
error_path_msg = f" - '{error_path}'" if error_path else ""
return error_path_msg
def _validate(self, data: dict) -> None:
try:
self.validator.validate(data)
except ValidationError as exception:
error_path_msg = self._error_path_message(exception=exception)
logger.error(
f"Validation error: {exception.message}{error_path_msg}",
extra={"validator": "body"},
)
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
def _parse(self, data: FormData) -> dict:
if self.uri_parser is not None:
# Don't parse file_data # Don't parse file_data
form_data = {} form_data = {}
file_data = {} file_data = {}
@@ -90,7 +62,7 @@ class FormDataValidator:
# Replace files with empty strings for validation # Replace files with empty strings for validation
file_data[k] = "" file_data[k] = ""
data = self.uri_parser.resolve_form(form_data) data = self._uri_parser.resolve_form(form_data)
# Add the files again # Add the files again
data.update(file_data) data.update(file_data)
else: else:
@@ -98,45 +70,29 @@ class FormDataValidator:
return data return data
def _validate_strictly(self, data: FormData) -> None: def _validate(self, data: dict) -> None:
if self._strict_validation:
self._validate_params_strictly(data)
try:
self._validator.validate(data)
except ValidationError as exception:
error_path_msg = format_error_with_path(exception=exception)
logger.error(
f"Validation error: {exception.message}{error_path_msg}",
extra={"validator": "body"},
)
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
def _validate_params_strictly(self, data: dict) -> None:
form_params = data.keys() form_params = data.keys()
spec_params = self.schema.get("properties", {}).keys() spec_params = self._schema.get("properties", {}).keys()
errors = set(form_params).difference(set(spec_params)) errors = set(form_params).difference(set(spec_params))
if errors: if errors:
raise ExtraParameterProblem(param_type="formData", extra_params=errors) raise ExtraParameterProblem(param_type="formData", extra_params=errors)
def validate(self, data: FormData) -> None:
if self.strict_validation:
self._validate_strictly(data)
data = self._parse(data)
self._validate(data)
async def wrapped_receive(self) -> Receive:
async def stream() -> t.AsyncGenerator[bytes, None]:
more_body = True
while more_body:
message = await self._receive()
self._messages.append(message)
more_body = message.get("more_body", False)
yield message.get("body", b"")
yield b""
form_parser = self.form_parser_cls(self.headers, stream())
form = await form_parser.parse()
if form and not (self.nullable and is_null(form)):
self.validate(form)
async def receive() -> t.MutableMapping[str, t.Any]:
while self._messages:
return self._messages.pop(0)
return await self._receive()
return receive
class MultiPartFormDataValidator(FormDataValidator): class MultiPartFormDataValidator(FormDataValidator):
@property @property
def form_parser_cls(self): def _form_parser_cls(self):
return MultiPartParser return MultiPartParser

View File

@@ -4,107 +4,83 @@ import typing as t
import jsonschema import jsonschema
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
from starlette.datastructures import Headers from starlette.types import Scope, Send
from starlette.types import Receive, Scope, Send
from connexion.exceptions import BadRequestProblem, NonConformingResponseBody from connexion.exceptions import BadRequestProblem, NonConformingResponseBody
from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator from connexion.json_schema import (
from connexion.utils import is_null Draft4RequestValidator,
Draft4ResponseValidator,
format_error_with_path,
)
from connexion.validators import AbstractRequestBodyValidator
logger = logging.getLogger("connexion.validators.json") logger = logging.getLogger(__name__)
class JSONRequestBodyValidator: class JSONRequestBodyValidator(AbstractRequestBodyValidator):
"""Request body validator for json content types.""" """Request body validator for json content types."""
def __init__( def __init__(
self, self,
scope: Scope,
receive: Receive,
*, *,
schema: dict, schema: dict,
validator: t.Type[Draft4Validator] = Draft4RequestValidator,
required=False, required=False,
nullable=False, nullable=False,
encoding: str, encoding: str,
strict_validation: bool,
**kwargs, **kwargs,
) -> None: ) -> None:
self._scope = scope super().__init__(
self._receive = receive schema=schema,
self.schema = schema required=required,
self.has_default = schema.get("default", False) nullable=nullable,
self.nullable = nullable encoding=encoding,
self.required = required strict_validation=strict_validation,
self.validator = validator(schema, format_checker=draft4_format_checker) )
self.encoding = encoding
self.headers = Headers(scope=scope)
self.check_empty()
def check_empty(self): @property
"""receive` is never called if body is empty, so we need to check this case at def _validator(self):
initialization.""" return Draft4RequestValidator(
if not int(self.headers.get("content-length", 0)): self._schema, format_checker=draft4_format_checker
# TODO: default should be passed along and content-length updated )
if self.schema.get("default"):
self.validate(self.schema.get("default"))
elif self.required: # RequestBody itself is required
raise BadRequestProblem("RequestBody is required")
elif self.schema.get("required", []): # Required top level properties
self.validate({})
@classmethod async def _parse(
def _error_path_message(cls, exception): self, stream: t.AsyncGenerator[bytes, None], scope: Scope
error_path = ".".join(str(item) for item in exception.path) ) -> t.Any:
error_path_msg = f" - '{error_path}'" if error_path else "" bytes_body = b"".join([message async for message in stream])
return error_path_msg body = bytes_body.decode(self._encoding)
if not body:
return None
def validate(self, body: dict):
try: try:
self.validator.validate(body) return json.loads(body)
except json.decoder.JSONDecodeError as e:
raise BadRequestProblem(detail=str(e))
def _validate(self, body: dict) -> None:
try:
return self._validator.validate(body)
except ValidationError as exception: except ValidationError as exception:
error_path_msg = self._error_path_message(exception=exception) error_path_msg = format_error_with_path(exception=exception)
logger.error( logger.error(
f"Validation error: {exception.message}{error_path_msg}", f"Validation error: {exception.message}{error_path_msg}",
extra={"validator": "body"}, extra={"validator": "body"},
) )
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
def parse(self, body: str) -> dict:
try:
return json.loads(body)
except json.decoder.JSONDecodeError as e:
raise BadRequestProblem(str(e))
async def wrapped_receive(self) -> Receive:
more_body = True
messages = []
while more_body:
message = await self._receive()
messages.append(message)
more_body = message.get("more_body", False)
bytes_body = b"".join([message.get("body", b"") for message in messages])
decoded_body = bytes_body.decode(self.encoding)
if decoded_body and not (self.nullable and is_null(decoded_body)):
body = self.parse(decoded_body)
self.validate(body)
async def receive() -> t.MutableMapping[str, t.Any]:
while messages:
return messages.pop(0)
return await self._receive()
return receive
class DefaultsJSONRequestBodyValidator(JSONRequestBodyValidator): class DefaultsJSONRequestBodyValidator(JSONRequestBodyValidator):
"""Request body validator for json content types which fills in default values. This Validator """Request body validator for json content types which fills in default values. This Validator
intercepts the body, makes changes to it, and replays it for the next ASGI application.""" intercepts the body, makes changes to it, and replays it for the next ASGI application."""
def __init__(self, *args, **kwargs): MUTABLE_VALIDATION = True
defaults_validator = self.extend_with_set_default(Draft4RequestValidator) """This validator might mutate to the body."""
super().__init__(*args, validator=defaults_validator, **kwargs)
@property
def _validator(self):
validator_cls = self.extend_with_set_default(Draft4RequestValidator)
return validator_cls(self._schema, format_checker=draft4_format_checker)
# via https://python-jsonschema.readthedocs.io/ # via https://python-jsonschema.readthedocs.io/
@staticmethod @staticmethod
@@ -122,58 +98,6 @@ class DefaultsJSONRequestBodyValidator(JSONRequestBodyValidator):
validator_class, {"properties": set_defaults} validator_class, {"properties": set_defaults}
) )
async def read_body(self) -> t.Tuple[str, int]:
"""Read the body from the receive channel.
:return: A tuple (body, max_length) where max_length is the length of the largest message.
"""
more_body = True
max_length = 256000
messages = []
while more_body:
message = await self._receive()
max_length = max(max_length, len(message.get("body", b"")))
messages.append(message)
more_body = message.get("more_body", False)
bytes_body = b"".join([message.get("body", b"") for message in messages])
return bytes_body.decode(self.encoding), max_length
async def wrapped_receive(self) -> Receive:
"""Receive channel to pass on to next ASGI application."""
decoded_body, max_length = await self.read_body()
# Validate the body if not null
if decoded_body and not (self.nullable and is_null(decoded_body)):
body = self.parse(decoded_body)
del decoded_body
self.validate(body)
str_body = json.dumps(body)
else:
str_body = decoded_body
bytes_body = str_body.encode(self.encoding)
del str_body
# Recreate ASGI messages from validated body so changes made by the validator are propagated
messages = [
{
"type": "http.request",
"body": bytes_body[i : i + max_length],
"more_body": i + max_length < len(bytes_body),
}
for i in range(0, len(bytes_body), max_length)
]
del bytes_body
async def receive() -> t.MutableMapping[str, t.Any]:
while messages:
return messages.pop(0)
return await self._receive()
return receive
class JSONResponseBodyValidator: class JSONResponseBodyValidator:
"""Response body validator for json content types.""" """Response body validator for json content types."""

View File

@@ -314,7 +314,10 @@ def test_mixed_formdata(simple_app):
def test_formdata_file_upload_bad_request(simple_app): def test_formdata_file_upload_bad_request(simple_app):
app_client = simple_app.test_client() app_client = simple_app.test_client()
resp = app_client.post("/v1.0/test-formData-file-upload") resp = app_client.post(
"/v1.0/test-formData-file-upload",
headers={"Content-Type": b"multipart/form-data; boundary=-"},
)
assert resp.status_code == 400 assert resp.status_code == 400
assert resp.json()["detail"] in [ assert resp.json()["detail"] in [
"Missing formdata parameter 'fileData'", "Missing formdata parameter 'fileData'",
@@ -443,11 +446,8 @@ def test_nullable_parameter(simple_app):
resp = app_client.put("/v1.0/nullable-parameters", content="null", headers=headers) resp = app_client.put("/v1.0/nullable-parameters", content="null", headers=headers)
assert resp.json() == "it was None" assert resp.json() == "it was None"
resp = app_client.put("/v1.0/nullable-parameters", content="None", headers=headers)
assert resp.json() == "it was None"
resp = app_client.put( resp = app_client.put(
"/v1.0/nullable-parameters-noargs", content="None", headers=headers "/v1.0/nullable-parameters-noargs", content="null", headers=headers
) )
assert resp.json() == "hello" assert resp.json() == "hello"

View File

@@ -28,8 +28,9 @@ def test_validator_map(json_validation_spec_dir, spec):
MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type}) MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type})
class MyJSONBodyValidator(JSONRequestBodyValidator): class MyJSONBodyValidator(JSONRequestBodyValidator):
def __init__(self, *args, **kwargs): @property
super().__init__(*args, validator=MinLengthRequestValidator, **kwargs) def _validator(self):
return MinLengthRequestValidator(self._schema)
validator_map = {"body": {"application/json": MyJSONBodyValidator}} validator_map = {"body": {"application/json": MyJSONBodyValidator}}

View File

@@ -749,6 +749,7 @@ def test_form_transformation(api):
"param": { "param": {
"type": "string", "type": "string",
"format": "email", "format": "email",
"default": "foo@bar.com",
}, },
"array_param": { "array_param": {
"type": "array", "type": "array",