mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-10 04:19:37 +00:00
Create AbstractRequestBodyValidator class
This commit is contained in:
@@ -107,6 +107,13 @@ def resolve_refs(spec, store=None, base_uri=""):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def format_error_with_path(exception: ValidationError) -> str:
|
||||||
|
"""Format a `ValidationError` with path to error."""
|
||||||
|
error_path = ".".join(str(item) for item in exception.path)
|
||||||
|
error_path_msg = f" - '{error_path}'" if error_path else ""
|
||||||
|
return error_path_msg
|
||||||
|
|
||||||
|
|
||||||
def allow_nullable(validation_fn: t.Callable) -> t.Callable:
|
def allow_nullable(validation_fn: t.Callable) -> t.Callable:
|
||||||
"""Extend an existing validation function, so it allows nullable values to be null."""
|
"""Extend an existing validation function, so it allows nullable values to be null."""
|
||||||
|
|
||||||
|
|||||||
@@ -65,5 +65,5 @@ class ExceptionMiddleware(StarletteExceptionMiddleware):
|
|||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
# Needs to be set so starlette router throws exceptions instead of returning error responses
|
# Needs to be set so starlette router throws exceptions instead of returning error responses
|
||||||
scope["app"] = self
|
scope["app"] = "connexion"
|
||||||
await super().__call__(scope, receive, send)
|
await super().__call__(scope, receive, send)
|
||||||
|
|||||||
@@ -68,8 +68,6 @@ class RequestValidationOperation:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||||
receive_fn = receive
|
|
||||||
|
|
||||||
# Validate parameters & headers
|
# Validate parameters & headers
|
||||||
uri_parser_class = self._operation._uri_parser_class
|
uri_parser_class = self._operation._uri_parser_class
|
||||||
uri_parser = uri_parser_class(
|
uri_parser = uri_parser_class(
|
||||||
@@ -100,8 +98,6 @@ class RequestValidationOperation:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validator = body_validator(
|
validator = body_validator(
|
||||||
scope,
|
|
||||||
receive,
|
|
||||||
schema=schema,
|
schema=schema,
|
||||||
required=self._operation.request_body.get("required", False),
|
required=self._operation.request_body.get("required", False),
|
||||||
nullable=utils.is_nullable(
|
nullable=utils.is_nullable(
|
||||||
@@ -113,9 +109,9 @@ class RequestValidationOperation:
|
|||||||
self._operation.parameters, self._operation.body_definition()
|
self._operation.parameters, self._operation.body_definition()
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
receive_fn = await validator.wrapped_receive()
|
receive = await validator.wrap_receive(receive, scope=scope)
|
||||||
|
|
||||||
await self.next_app(scope, receive_fn, send)
|
await self.next_app(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
|
class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
|
||||||
|
|||||||
@@ -291,6 +291,7 @@ class Swagger2Operation(AbstractOperation):
|
|||||||
|
|
||||||
default = param.get("default")
|
default = param.get("default")
|
||||||
if default is not None:
|
if default is not None:
|
||||||
|
prop["default"] = default
|
||||||
defaults[param["name"]] = default
|
defaults[param["name"]] = default
|
||||||
|
|
||||||
nullable = param.get("x-nullable")
|
nullable = param.get("x-nullable")
|
||||||
@@ -320,11 +321,11 @@ class Swagger2Operation(AbstractOperation):
|
|||||||
"schema": {
|
"schema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": properties,
|
"properties": properties,
|
||||||
"default": defaults,
|
|
||||||
"required": required,
|
"required": required,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if defaults:
|
||||||
|
definition["schema"]["default"] = defaults
|
||||||
if encoding:
|
if encoding:
|
||||||
definition["encoding"] = encoding
|
definition["encoding"] = encoding
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from connexion.datastructures import MediaTypeDict
|
from connexion.datastructures import MediaTypeDict
|
||||||
|
|
||||||
|
from .abstract import AbstractRequestBodyValidator # NOQA
|
||||||
from .form_data import FormDataValidator, MultiPartFormDataValidator
|
from .form_data import FormDataValidator, MultiPartFormDataValidator
|
||||||
from .json import DefaultsJSONRequestBodyValidator # NOQA
|
from .json import DefaultsJSONRequestBodyValidator # NOQA
|
||||||
from .json import (
|
from .json import (
|
||||||
|
|||||||
151
connexion/validators/abstract.py
Normal file
151
connexion/validators/abstract.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""
|
||||||
|
This module defines a Validator interface with base functionality that can be subclassed
|
||||||
|
for custom validators provided to the RequestValidationMiddleware.
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from starlette.datastructures import Headers, MutableHeaders
|
||||||
|
from starlette.types import Receive, Scope
|
||||||
|
|
||||||
|
from connexion.exceptions import BadRequestProblem
|
||||||
|
from connexion.utils import is_null
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractRequestBodyValidator:
|
||||||
|
"""
|
||||||
|
Validator interface with base functionality that can be subclassed for custom validators.
|
||||||
|
|
||||||
|
.. note: Validators load the whole body into memory, which can be a problem for large payloads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
MUTABLE_VALIDATION = False
|
||||||
|
"""
|
||||||
|
Whether mutations to the body during validation should be transmitted via the receive channel.
|
||||||
|
Note that this does not apply to the substitution of a missing body with the default body, which always
|
||||||
|
updates the receive channel.
|
||||||
|
"""
|
||||||
|
MAX_MESSAGE_LENGTH = 256000
|
||||||
|
"""Maximum message length that will be sent via the receive channel for mutated bodies."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
schema: dict,
|
||||||
|
required: bool = False,
|
||||||
|
nullable: bool = False,
|
||||||
|
encoding: str,
|
||||||
|
strict_validation: bool,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param schema: Schema of operation to validate
|
||||||
|
:param required: Whether RequestBody is required
|
||||||
|
:param nullable: Whether RequestBody is nullable
|
||||||
|
:param encoding: Encoding of body (passed via Content-Type header)
|
||||||
|
:param kwargs: Additional arguments for subclasses
|
||||||
|
:param strict_validation: Whether to allow parameters not defined in the spec
|
||||||
|
"""
|
||||||
|
self._schema = schema
|
||||||
|
self._nullable = nullable
|
||||||
|
self._required = required
|
||||||
|
self._encoding = encoding
|
||||||
|
self._strict_validation = strict_validation
|
||||||
|
|
||||||
|
async def _parse(
|
||||||
|
self, stream: t.AsyncGenerator[bytes, None], scope: Scope
|
||||||
|
) -> t.Any:
|
||||||
|
"""Parse the incoming stream."""
|
||||||
|
|
||||||
|
def _validate(self, body: t.Any) -> t.Optional[dict]:
|
||||||
|
"""
|
||||||
|
Validate the parsed body.
|
||||||
|
|
||||||
|
:raises: :class:`connexion.exceptions.BadRequestProblem`
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _insert_body(self, receive: Receive, *, body: t.Any, scope: Scope) -> Receive:
|
||||||
|
"""
|
||||||
|
Insert messages transmitting the body at the start of the `receive` channel.
|
||||||
|
|
||||||
|
This method updates the provided `scope` in place with the right `Content-Length` header.
|
||||||
|
"""
|
||||||
|
if body is None:
|
||||||
|
return receive
|
||||||
|
|
||||||
|
bytes_body = json.dumps(body).encode(self._encoding)
|
||||||
|
|
||||||
|
# Update the content-length header
|
||||||
|
new_scope = copy.deepcopy(scope)
|
||||||
|
headers = MutableHeaders(scope=new_scope)
|
||||||
|
headers["content-length"] = str(len(bytes_body))
|
||||||
|
|
||||||
|
# Wrap in new receive channel
|
||||||
|
messages = (
|
||||||
|
{
|
||||||
|
"type": "http.request",
|
||||||
|
"body": bytes_body[i : i + self.MAX_MESSAGE_LENGTH],
|
||||||
|
"more_body": i + self.MAX_MESSAGE_LENGTH < len(bytes_body),
|
||||||
|
}
|
||||||
|
for i in range(0, len(bytes_body), self.MAX_MESSAGE_LENGTH)
|
||||||
|
)
|
||||||
|
|
||||||
|
receive = self._insert_messages(receive, messages=messages)
|
||||||
|
|
||||||
|
return receive
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _insert_messages(
|
||||||
|
receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
|
||||||
|
) -> Receive:
|
||||||
|
"""Insert messages at the start of the `receive` channel."""
|
||||||
|
|
||||||
|
async def receive_() -> t.MutableMapping[str, t.Any]:
|
||||||
|
for message in messages:
|
||||||
|
return message
|
||||||
|
return await receive()
|
||||||
|
|
||||||
|
return receive_
|
||||||
|
|
||||||
|
async def wrap_receive(self, receive: Receive, *, scope: Scope) -> Receive:
|
||||||
|
"""
|
||||||
|
Wrap the provided `receive` channel with request body validation.
|
||||||
|
|
||||||
|
This method updates the provided `scope` in place with the right `Content-Length` header.
|
||||||
|
"""
|
||||||
|
# Handle missing bodies
|
||||||
|
headers = Headers(scope=scope)
|
||||||
|
if not int(headers.get("content-length", 0)):
|
||||||
|
body = self._schema.get("default")
|
||||||
|
if body is None and self._required:
|
||||||
|
raise BadRequestProblem("RequestBody is required")
|
||||||
|
# The default body is encoded as a `receive` channel to mimic an incoming body
|
||||||
|
receive = self._insert_body(receive, body=body, scope=scope)
|
||||||
|
|
||||||
|
# The receive channel is converted to a stream for convenient access
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
async def stream() -> t.AsyncGenerator[bytes, None]:
|
||||||
|
more_body = True
|
||||||
|
while more_body:
|
||||||
|
message = await receive()
|
||||||
|
messages.append(message)
|
||||||
|
more_body = message.get("more_body", False)
|
||||||
|
yield message.get("body", b"")
|
||||||
|
yield b""
|
||||||
|
|
||||||
|
# The body is parsed and validated
|
||||||
|
body = await self._parse(stream(), scope=scope)
|
||||||
|
if not (body is None and self._nullable):
|
||||||
|
self._validate(body)
|
||||||
|
|
||||||
|
# If MUTABLE_VALIDATION is enabled, include any changes made during validation in the messages to send
|
||||||
|
if self.MUTABLE_VALIDATION:
|
||||||
|
# Include changes made during validation
|
||||||
|
receive = self._insert_body(receive, body=body, scope=scope)
|
||||||
|
else:
|
||||||
|
# Serialize original messages
|
||||||
|
receive = self._insert_messages(receive, messages=messages)
|
||||||
|
|
||||||
|
return receive
|
||||||
@@ -1,85 +1,57 @@
|
|||||||
import logging
|
import logging
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
|
from jsonschema import ValidationError, draft4_format_checker
|
||||||
from starlette.datastructures import FormData, Headers, UploadFile
|
from starlette.datastructures import Headers, UploadFile
|
||||||
from starlette.formparsers import FormParser, MultiPartParser
|
from starlette.formparsers import FormParser, MultiPartParser
|
||||||
from starlette.types import Receive, Scope
|
from starlette.types import Scope
|
||||||
|
|
||||||
from connexion.exceptions import BadRequestProblem, ExtraParameterProblem
|
from connexion.exceptions import BadRequestProblem, ExtraParameterProblem
|
||||||
from connexion.json_schema import Draft4RequestValidator
|
from connexion.json_schema import Draft4RequestValidator, format_error_with_path
|
||||||
from connexion.uri_parsing import AbstractURIParser
|
from connexion.uri_parsing import AbstractURIParser
|
||||||
from connexion.utils import is_null
|
from connexion.validators import AbstractRequestBodyValidator
|
||||||
|
|
||||||
logger = logging.getLogger("connexion.validators.form_data")
|
logger = logging.getLogger("connexion.validators.form_data")
|
||||||
|
|
||||||
|
|
||||||
class FormDataValidator:
|
class FormDataValidator(AbstractRequestBodyValidator):
|
||||||
"""Request body validator for form content types."""
|
"""Request body validator for form content types."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scope: Scope,
|
|
||||||
receive: Receive,
|
|
||||||
*,
|
*,
|
||||||
schema: dict,
|
schema: dict,
|
||||||
validator: t.Type[Draft4Validator] = None,
|
|
||||||
required=False,
|
required=False,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
encoding: str,
|
encoding: str,
|
||||||
uri_parser: t.Optional[AbstractURIParser] = None,
|
|
||||||
strict_validation: bool,
|
strict_validation: bool,
|
||||||
|
uri_parser: t.Optional[AbstractURIParser] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._scope = scope
|
super().__init__(
|
||||||
self._receive = receive
|
schema=schema,
|
||||||
self.schema = schema
|
required=required,
|
||||||
self.has_default = schema.get("default", False)
|
nullable=nullable,
|
||||||
self.nullable = nullable
|
encoding=encoding,
|
||||||
self.required = required
|
strict_validation=strict_validation,
|
||||||
validator_cls = validator or Draft4RequestValidator
|
)
|
||||||
self.validator = validator_cls(schema, format_checker=draft4_format_checker)
|
self._uri_parser = uri_parser
|
||||||
self.uri_parser = uri_parser
|
|
||||||
self.encoding = encoding
|
|
||||||
self._messages: t.List[t.MutableMapping[str, t.Any]] = []
|
|
||||||
self.headers = Headers(scope=scope)
|
|
||||||
self.strict_validation = strict_validation
|
|
||||||
self.check_empty()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def form_parser_cls(self):
|
def _validator(self):
|
||||||
|
return Draft4RequestValidator(
|
||||||
|
self._schema, format_checker=draft4_format_checker
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _form_parser_cls(self):
|
||||||
return FormParser
|
return FormParser
|
||||||
|
|
||||||
def check_empty(self):
|
async def _parse(self, stream: t.AsyncGenerator[bytes, None], scope: Scope) -> dict:
|
||||||
"""`receive` is never called if body is empty, so we need to check this case at
|
headers = Headers(scope=scope)
|
||||||
initialization."""
|
form_parser = self._form_parser_cls(headers, stream)
|
||||||
if not int(self.headers.get("content-length", 0)):
|
data = await form_parser.parse()
|
||||||
# TODO: default should be passed along and content-length updated
|
|
||||||
if self.schema.get("default"):
|
|
||||||
self.validate(self.schema.get("default"))
|
|
||||||
elif self.required: # RequestBody itself is required
|
|
||||||
raise BadRequestProblem("RequestBody is required")
|
|
||||||
elif self.schema.get("required", []): # Required top level properties
|
|
||||||
self._validate({})
|
|
||||||
|
|
||||||
@classmethod
|
if self._uri_parser is not None:
|
||||||
def _error_path_message(cls, exception):
|
|
||||||
error_path = ".".join(str(item) for item in exception.path)
|
|
||||||
error_path_msg = f" - '{error_path}'" if error_path else ""
|
|
||||||
return error_path_msg
|
|
||||||
|
|
||||||
def _validate(self, data: dict) -> None:
|
|
||||||
try:
|
|
||||||
self.validator.validate(data)
|
|
||||||
except ValidationError as exception:
|
|
||||||
error_path_msg = self._error_path_message(exception=exception)
|
|
||||||
logger.error(
|
|
||||||
f"Validation error: {exception.message}{error_path_msg}",
|
|
||||||
extra={"validator": "body"},
|
|
||||||
)
|
|
||||||
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
|
|
||||||
|
|
||||||
def _parse(self, data: FormData) -> dict:
|
|
||||||
if self.uri_parser is not None:
|
|
||||||
# Don't parse file_data
|
# Don't parse file_data
|
||||||
form_data = {}
|
form_data = {}
|
||||||
file_data = {}
|
file_data = {}
|
||||||
@@ -90,7 +62,7 @@ class FormDataValidator:
|
|||||||
# Replace files with empty strings for validation
|
# Replace files with empty strings for validation
|
||||||
file_data[k] = ""
|
file_data[k] = ""
|
||||||
|
|
||||||
data = self.uri_parser.resolve_form(form_data)
|
data = self._uri_parser.resolve_form(form_data)
|
||||||
# Add the files again
|
# Add the files again
|
||||||
data.update(file_data)
|
data.update(file_data)
|
||||||
else:
|
else:
|
||||||
@@ -98,45 +70,29 @@ class FormDataValidator:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _validate_strictly(self, data: FormData) -> None:
|
def _validate(self, data: dict) -> None:
|
||||||
|
if self._strict_validation:
|
||||||
|
self._validate_params_strictly(data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._validator.validate(data)
|
||||||
|
except ValidationError as exception:
|
||||||
|
error_path_msg = format_error_with_path(exception=exception)
|
||||||
|
logger.error(
|
||||||
|
f"Validation error: {exception.message}{error_path_msg}",
|
||||||
|
extra={"validator": "body"},
|
||||||
|
)
|
||||||
|
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
|
||||||
|
|
||||||
|
def _validate_params_strictly(self, data: dict) -> None:
|
||||||
form_params = data.keys()
|
form_params = data.keys()
|
||||||
spec_params = self.schema.get("properties", {}).keys()
|
spec_params = self._schema.get("properties", {}).keys()
|
||||||
errors = set(form_params).difference(set(spec_params))
|
errors = set(form_params).difference(set(spec_params))
|
||||||
if errors:
|
if errors:
|
||||||
raise ExtraParameterProblem(param_type="formData", extra_params=errors)
|
raise ExtraParameterProblem(param_type="formData", extra_params=errors)
|
||||||
|
|
||||||
def validate(self, data: FormData) -> None:
|
|
||||||
if self.strict_validation:
|
|
||||||
self._validate_strictly(data)
|
|
||||||
|
|
||||||
data = self._parse(data)
|
|
||||||
self._validate(data)
|
|
||||||
|
|
||||||
async def wrapped_receive(self) -> Receive:
|
|
||||||
async def stream() -> t.AsyncGenerator[bytes, None]:
|
|
||||||
more_body = True
|
|
||||||
while more_body:
|
|
||||||
message = await self._receive()
|
|
||||||
self._messages.append(message)
|
|
||||||
more_body = message.get("more_body", False)
|
|
||||||
yield message.get("body", b"")
|
|
||||||
yield b""
|
|
||||||
|
|
||||||
form_parser = self.form_parser_cls(self.headers, stream())
|
|
||||||
form = await form_parser.parse()
|
|
||||||
|
|
||||||
if form and not (self.nullable and is_null(form)):
|
|
||||||
self.validate(form)
|
|
||||||
|
|
||||||
async def receive() -> t.MutableMapping[str, t.Any]:
|
|
||||||
while self._messages:
|
|
||||||
return self._messages.pop(0)
|
|
||||||
return await self._receive()
|
|
||||||
|
|
||||||
return receive
|
|
||||||
|
|
||||||
|
|
||||||
class MultiPartFormDataValidator(FormDataValidator):
|
class MultiPartFormDataValidator(FormDataValidator):
|
||||||
@property
|
@property
|
||||||
def form_parser_cls(self):
|
def _form_parser_cls(self):
|
||||||
return MultiPartParser
|
return MultiPartParser
|
||||||
|
|||||||
@@ -4,107 +4,83 @@ import typing as t
|
|||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
|
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
|
||||||
from starlette.datastructures import Headers
|
from starlette.types import Scope, Send
|
||||||
from starlette.types import Receive, Scope, Send
|
|
||||||
|
|
||||||
from connexion.exceptions import BadRequestProblem, NonConformingResponseBody
|
from connexion.exceptions import BadRequestProblem, NonConformingResponseBody
|
||||||
from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator
|
from connexion.json_schema import (
|
||||||
from connexion.utils import is_null
|
Draft4RequestValidator,
|
||||||
|
Draft4ResponseValidator,
|
||||||
|
format_error_with_path,
|
||||||
|
)
|
||||||
|
from connexion.validators import AbstractRequestBodyValidator
|
||||||
|
|
||||||
logger = logging.getLogger("connexion.validators.json")
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class JSONRequestBodyValidator:
|
class JSONRequestBodyValidator(AbstractRequestBodyValidator):
|
||||||
"""Request body validator for json content types."""
|
"""Request body validator for json content types."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scope: Scope,
|
|
||||||
receive: Receive,
|
|
||||||
*,
|
*,
|
||||||
schema: dict,
|
schema: dict,
|
||||||
validator: t.Type[Draft4Validator] = Draft4RequestValidator,
|
|
||||||
required=False,
|
required=False,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
encoding: str,
|
encoding: str,
|
||||||
|
strict_validation: bool,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._scope = scope
|
super().__init__(
|
||||||
self._receive = receive
|
schema=schema,
|
||||||
self.schema = schema
|
required=required,
|
||||||
self.has_default = schema.get("default", False)
|
nullable=nullable,
|
||||||
self.nullable = nullable
|
encoding=encoding,
|
||||||
self.required = required
|
strict_validation=strict_validation,
|
||||||
self.validator = validator(schema, format_checker=draft4_format_checker)
|
)
|
||||||
self.encoding = encoding
|
|
||||||
self.headers = Headers(scope=scope)
|
|
||||||
self.check_empty()
|
|
||||||
|
|
||||||
def check_empty(self):
|
@property
|
||||||
"""receive` is never called if body is empty, so we need to check this case at
|
def _validator(self):
|
||||||
initialization."""
|
return Draft4RequestValidator(
|
||||||
if not int(self.headers.get("content-length", 0)):
|
self._schema, format_checker=draft4_format_checker
|
||||||
# TODO: default should be passed along and content-length updated
|
)
|
||||||
if self.schema.get("default"):
|
|
||||||
self.validate(self.schema.get("default"))
|
|
||||||
elif self.required: # RequestBody itself is required
|
|
||||||
raise BadRequestProblem("RequestBody is required")
|
|
||||||
elif self.schema.get("required", []): # Required top level properties
|
|
||||||
self.validate({})
|
|
||||||
|
|
||||||
@classmethod
|
async def _parse(
|
||||||
def _error_path_message(cls, exception):
|
self, stream: t.AsyncGenerator[bytes, None], scope: Scope
|
||||||
error_path = ".".join(str(item) for item in exception.path)
|
) -> t.Any:
|
||||||
error_path_msg = f" - '{error_path}'" if error_path else ""
|
bytes_body = b"".join([message async for message in stream])
|
||||||
return error_path_msg
|
body = bytes_body.decode(self._encoding)
|
||||||
|
|
||||||
|
if not body:
|
||||||
|
return None
|
||||||
|
|
||||||
def validate(self, body: dict):
|
|
||||||
try:
|
try:
|
||||||
self.validator.validate(body)
|
return json.loads(body)
|
||||||
|
except json.decoder.JSONDecodeError as e:
|
||||||
|
raise BadRequestProblem(detail=str(e))
|
||||||
|
|
||||||
|
def _validate(self, body: dict) -> None:
|
||||||
|
try:
|
||||||
|
return self._validator.validate(body)
|
||||||
except ValidationError as exception:
|
except ValidationError as exception:
|
||||||
error_path_msg = self._error_path_message(exception=exception)
|
error_path_msg = format_error_with_path(exception=exception)
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Validation error: {exception.message}{error_path_msg}",
|
f"Validation error: {exception.message}{error_path_msg}",
|
||||||
extra={"validator": "body"},
|
extra={"validator": "body"},
|
||||||
)
|
)
|
||||||
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
|
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
|
||||||
|
|
||||||
def parse(self, body: str) -> dict:
|
|
||||||
try:
|
|
||||||
return json.loads(body)
|
|
||||||
except json.decoder.JSONDecodeError as e:
|
|
||||||
raise BadRequestProblem(str(e))
|
|
||||||
|
|
||||||
async def wrapped_receive(self) -> Receive:
|
|
||||||
more_body = True
|
|
||||||
messages = []
|
|
||||||
while more_body:
|
|
||||||
message = await self._receive()
|
|
||||||
messages.append(message)
|
|
||||||
more_body = message.get("more_body", False)
|
|
||||||
|
|
||||||
bytes_body = b"".join([message.get("body", b"") for message in messages])
|
|
||||||
decoded_body = bytes_body.decode(self.encoding)
|
|
||||||
|
|
||||||
if decoded_body and not (self.nullable and is_null(decoded_body)):
|
|
||||||
body = self.parse(decoded_body)
|
|
||||||
self.validate(body)
|
|
||||||
|
|
||||||
async def receive() -> t.MutableMapping[str, t.Any]:
|
|
||||||
while messages:
|
|
||||||
return messages.pop(0)
|
|
||||||
return await self._receive()
|
|
||||||
|
|
||||||
return receive
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultsJSONRequestBodyValidator(JSONRequestBodyValidator):
|
class DefaultsJSONRequestBodyValidator(JSONRequestBodyValidator):
|
||||||
"""Request body validator for json content types which fills in default values. This Validator
|
"""Request body validator for json content types which fills in default values. This Validator
|
||||||
intercepts the body, makes changes to it, and replays it for the next ASGI application."""
|
intercepts the body, makes changes to it, and replays it for the next ASGI application."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
MUTABLE_VALIDATION = True
|
||||||
defaults_validator = self.extend_with_set_default(Draft4RequestValidator)
|
"""This validator might mutate to the body."""
|
||||||
super().__init__(*args, validator=defaults_validator, **kwargs)
|
|
||||||
|
@property
|
||||||
|
def _validator(self):
|
||||||
|
validator_cls = self.extend_with_set_default(Draft4RequestValidator)
|
||||||
|
return validator_cls(self._schema, format_checker=draft4_format_checker)
|
||||||
|
|
||||||
# via https://python-jsonschema.readthedocs.io/
|
# via https://python-jsonschema.readthedocs.io/
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -122,58 +98,6 @@ class DefaultsJSONRequestBodyValidator(JSONRequestBodyValidator):
|
|||||||
validator_class, {"properties": set_defaults}
|
validator_class, {"properties": set_defaults}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def read_body(self) -> t.Tuple[str, int]:
|
|
||||||
"""Read the body from the receive channel.
|
|
||||||
|
|
||||||
:return: A tuple (body, max_length) where max_length is the length of the largest message.
|
|
||||||
"""
|
|
||||||
more_body = True
|
|
||||||
max_length = 256000
|
|
||||||
messages = []
|
|
||||||
while more_body:
|
|
||||||
message = await self._receive()
|
|
||||||
max_length = max(max_length, len(message.get("body", b"")))
|
|
||||||
messages.append(message)
|
|
||||||
more_body = message.get("more_body", False)
|
|
||||||
|
|
||||||
bytes_body = b"".join([message.get("body", b"") for message in messages])
|
|
||||||
|
|
||||||
return bytes_body.decode(self.encoding), max_length
|
|
||||||
|
|
||||||
async def wrapped_receive(self) -> Receive:
|
|
||||||
"""Receive channel to pass on to next ASGI application."""
|
|
||||||
decoded_body, max_length = await self.read_body()
|
|
||||||
|
|
||||||
# Validate the body if not null
|
|
||||||
if decoded_body and not (self.nullable and is_null(decoded_body)):
|
|
||||||
body = self.parse(decoded_body)
|
|
||||||
del decoded_body
|
|
||||||
self.validate(body)
|
|
||||||
str_body = json.dumps(body)
|
|
||||||
else:
|
|
||||||
str_body = decoded_body
|
|
||||||
|
|
||||||
bytes_body = str_body.encode(self.encoding)
|
|
||||||
del str_body
|
|
||||||
|
|
||||||
# Recreate ASGI messages from validated body so changes made by the validator are propagated
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"type": "http.request",
|
|
||||||
"body": bytes_body[i : i + max_length],
|
|
||||||
"more_body": i + max_length < len(bytes_body),
|
|
||||||
}
|
|
||||||
for i in range(0, len(bytes_body), max_length)
|
|
||||||
]
|
|
||||||
del bytes_body
|
|
||||||
|
|
||||||
async def receive() -> t.MutableMapping[str, t.Any]:
|
|
||||||
while messages:
|
|
||||||
return messages.pop(0)
|
|
||||||
return await self._receive()
|
|
||||||
|
|
||||||
return receive
|
|
||||||
|
|
||||||
|
|
||||||
class JSONResponseBodyValidator:
|
class JSONResponseBodyValidator:
|
||||||
"""Response body validator for json content types."""
|
"""Response body validator for json content types."""
|
||||||
|
|||||||
@@ -314,7 +314,10 @@ def test_mixed_formdata(simple_app):
|
|||||||
|
|
||||||
def test_formdata_file_upload_bad_request(simple_app):
|
def test_formdata_file_upload_bad_request(simple_app):
|
||||||
app_client = simple_app.test_client()
|
app_client = simple_app.test_client()
|
||||||
resp = app_client.post("/v1.0/test-formData-file-upload")
|
resp = app_client.post(
|
||||||
|
"/v1.0/test-formData-file-upload",
|
||||||
|
headers={"Content-Type": b"multipart/form-data; boundary=-"},
|
||||||
|
)
|
||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
assert resp.json()["detail"] in [
|
assert resp.json()["detail"] in [
|
||||||
"Missing formdata parameter 'fileData'",
|
"Missing formdata parameter 'fileData'",
|
||||||
@@ -443,11 +446,8 @@ def test_nullable_parameter(simple_app):
|
|||||||
resp = app_client.put("/v1.0/nullable-parameters", content="null", headers=headers)
|
resp = app_client.put("/v1.0/nullable-parameters", content="null", headers=headers)
|
||||||
assert resp.json() == "it was None"
|
assert resp.json() == "it was None"
|
||||||
|
|
||||||
resp = app_client.put("/v1.0/nullable-parameters", content="None", headers=headers)
|
|
||||||
assert resp.json() == "it was None"
|
|
||||||
|
|
||||||
resp = app_client.put(
|
resp = app_client.put(
|
||||||
"/v1.0/nullable-parameters-noargs", content="None", headers=headers
|
"/v1.0/nullable-parameters-noargs", content="null", headers=headers
|
||||||
)
|
)
|
||||||
assert resp.json() == "hello"
|
assert resp.json() == "hello"
|
||||||
|
|
||||||
|
|||||||
@@ -28,8 +28,9 @@ def test_validator_map(json_validation_spec_dir, spec):
|
|||||||
MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type})
|
MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type})
|
||||||
|
|
||||||
class MyJSONBodyValidator(JSONRequestBodyValidator):
|
class MyJSONBodyValidator(JSONRequestBodyValidator):
|
||||||
def __init__(self, *args, **kwargs):
|
@property
|
||||||
super().__init__(*args, validator=MinLengthRequestValidator, **kwargs)
|
def _validator(self):
|
||||||
|
return MinLengthRequestValidator(self._schema)
|
||||||
|
|
||||||
validator_map = {"body": {"application/json": MyJSONBodyValidator}}
|
validator_map = {"body": {"application/json": MyJSONBodyValidator}}
|
||||||
|
|
||||||
|
|||||||
@@ -749,6 +749,7 @@ def test_form_transformation(api):
|
|||||||
"param": {
|
"param": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"format": "email",
|
"format": "email",
|
||||||
|
"default": "foo@bar.com",
|
||||||
},
|
},
|
||||||
"array_param": {
|
"array_param": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
|||||||
Reference in New Issue
Block a user