Extract JSON request body validation to middleware (#1588)

* Set up code skeleton for validation middleware

* Add more boilerplate code

* WIP

* Add ASGI JSONBodyValidator

* Revert example changes

* Remove incorrect content type test

Co-authored-by: Ruwan <ruwanlambrichts@gmail.com>
This commit is contained in:
Robbe Sneyders
2022-09-18 10:55:16 +02:00
committed by GitHub
parent e4b7827b6d
commit fb071ea56f
14 changed files with 2539 additions and 77 deletions

View File

@@ -12,15 +12,11 @@ from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
from jsonschema.validators import extend
from werkzeug.datastructures import FileStorage
from ..exceptions import (
BadRequestProblem,
ExtraParameterProblem,
UnsupportedMediaTypeProblem,
)
from ..exceptions import BadRequestProblem, ExtraParameterProblem
from ..http_facts import FORM_CONTENT_TYPES
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator
from ..lifecycle import ConnexionResponse
from ..utils import all_json, boolean, is_json_mimetype, is_null, is_nullable
from ..utils import boolean, is_null, is_nullable
logger = logging.getLogger("connexion.decorators.validation")
@@ -141,33 +137,7 @@ class RequestBodyValidator:
@functools.wraps(function)
def wrapper(request):
if all_json(self.consumes):
data = request.json
empty_body = not (request.body or request.form or request.files)
if data is None and not empty_body and not self.is_null_value_valid:
try:
ctype_is_json = is_json_mimetype(
request.headers.get("Content-Type", "")
)
except ValueError:
ctype_is_json = False
if ctype_is_json:
# Content-Type is json but actual body was not parsed
raise BadRequestProblem(detail="Request body is not valid JSON")
else:
# the body has contents that were not parsed as JSON
raise UnsupportedMediaTypeProblem(
detail="Invalid Content-type ({content_type}), expected JSON data".format(
content_type=request.headers.get("Content-Type", "")
)
)
logger.debug("%s validating schema...", request.url)
if data is not None or not self.has_default:
self.validate_schema(data, request.url)
elif self.consumes[0] in FORM_CONTENT_TYPES:
if self.consumes[0] in FORM_CONTENT_TYPES:
data = dict(request.form.items()) or (
request.body if len(request.body) > 0 else {}
)

View File

@@ -8,6 +8,7 @@ from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.middleware.validation import ValidationMiddleware
class ConnexionMiddleware:
@@ -17,6 +18,7 @@ class ConnexionMiddleware:
SwaggerUIMiddleware,
RoutingMiddleware,
SecurityMiddleware,
ValidationMiddleware,
]
def __init__(

View File

@@ -6,7 +6,6 @@ from starlette.routing import Router
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis import AbstractRoutingAPI
from connexion.exceptions import NotFoundProblem
from connexion.middleware import AppMiddleware
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver
@@ -61,10 +60,7 @@ class RoutingMiddleware(AppMiddleware):
# Needs to be set so starlette router throws exceptions instead of returning error responses
scope["app"] = self
try:
await self.router(scope, receive, send)
except ValueError:
raise NotFoundProblem
class RoutingAPI(AbstractRoutingAPI):

View File

@@ -142,6 +142,7 @@ class SecurityOperation:
operation: t.Union[AbstractOperation, Specification],
security_handler_factory: SecurityHandlerFactory,
):
# TODO: Turn Operation class into OperationSpec and use as init argument instead
return cls(
security_handler_factory,
security=operation.security,

View File

@@ -0,0 +1,232 @@
"""
Validation Middleware.
"""
import logging
import pathlib
import typing as t
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis.abstract import AbstractSpecAPI
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import MissingMiddleware, UnsupportedMediaTypeProblem
from connexion.http_facts import METHODS
from connexion.middleware import AppMiddleware
from connexion.middleware.routing import ROUTING_CONTEXT
from connexion.operations import AbstractOperation
from connexion.resolver import ResolverError
from connexion.utils import is_nullable
from connexion.validators import JSONBodyValidator
from ..decorators.response import ResponseValidator
from ..decorators.validation import ParameterValidator
logger = logging.getLogger("connexion.middleware.validation")
VALIDATOR_MAP = {
"parameter": ParameterValidator,
"body": {"application/json": JSONBodyValidator},
"response": ResponseValidator,
}
class ValidationMiddleware(AppMiddleware):
"""Middleware for validating requests according to the API contract."""
def __init__(self, app: ASGIApp) -> None:
self.app = app
self.apis: t.Dict[str, ValidationAPI] = {}
def add_api(
self, specification: t.Union[pathlib.Path, str, dict], **kwargs
) -> None:
api = ValidationAPI(specification, next_app=self.app, **kwargs)
self.apis[api.base_path] = api
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
try:
connexion_context = scope["extensions"][ROUTING_CONTEXT]
except KeyError:
raise MissingMiddleware(
"Could not find routing information in scope. Please make sure "
"you have a routing middleware registered upstream. "
)
api_base_path = connexion_context.get("api_base_path")
if api_base_path:
api = self.apis[api_base_path]
operation_id = connexion_context.get("operation_id")
try:
operation = api.operations[operation_id]
except KeyError as e:
if operation_id is None:
logger.debug("Skipping validation check for operation without id.")
await self.app(scope, receive, send)
return
else:
raise MissingValidationOperation(
"Encountered unknown operation_id."
) from e
else:
return await operation(scope, receive, send)
await self.app(scope, receive, send)
class ValidationAPI(AbstractSpecAPI):
"""Validation API."""
def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
*args,
next_app: ASGIApp,
validate_responses=False,
strict_validation=False,
validator_map=None,
uri_parser_class=None,
**kwargs,
):
super().__init__(specification, *args, **kwargs)
self.next_app = next_app
self.validator_map = validator_map
logger.debug("Validate Responses: %s", str(validate_responses))
self.validate_responses = validate_responses
logger.debug("Strict Request Validation: %s", str(strict_validation))
self.strict_validation = strict_validation
self.uri_parser_class = uri_parser_class
self.operations: t.Dict[str, ValidationOperation] = {}
self.add_paths()
def add_paths(self):
paths = self.specification.get("paths", {})
for path, methods in paths.items():
for method in methods:
if method not in METHODS:
continue
try:
self.add_operation(path, method)
except ResolverError:
# ResolverErrors are either raised or handled in routing middleware.
pass
def add_operation(self, path: str, method: str) -> None:
operation_cls = self.specification.operation_cls
operation = operation_cls.from_spec(
self.specification, self, path, method, self.resolver
)
validation_operation = self.make_operation(operation)
self._add_operation_internal(operation.operation_id, validation_operation)
def make_operation(self, operation: AbstractOperation):
return ValidationOperation(
operation,
self.next_app,
validate_responses=self.validate_responses,
strict_validation=self.strict_validation,
validator_map=self.validator_map,
uri_parser_class=self.uri_parser_class,
)
def _add_operation_internal(
self, operation_id: str, operation: "ValidationOperation"
):
self.operations[operation_id] = operation
class ValidationOperation:
def __init__(
self,
operation: AbstractOperation,
next_app: ASGIApp,
validate_responses: bool = False,
strict_validation: bool = False,
validator_map: t.Optional[dict] = None,
uri_parser_class: t.Optional[AbstractURIParser] = None,
) -> None:
self._operation = operation
self.next_app = next_app
self.validate_responses = validate_responses
self.strict_validation = strict_validation
self._validator_map = VALIDATOR_MAP
self._validator_map.update(validator_map or {})
self.uri_parser_class = uri_parser_class
def extract_content_type(self, headers: dict) -> t.Tuple[str, str]:
"""Extract the mime type and encoding from the content type headers.
:param headers: Header dict from ASGI scope
:return: A tuple of mime type, encoding
"""
encoding = "utf-8"
for key, value in headers:
# Headers can always be decoded using latin-1:
# https://stackoverflow.com/a/27357138/4098821
key = key.decode("latin-1")
if key.lower() == "content-type":
content_type = value.decode("latin-1")
if ";" in content_type:
mime_type, parameters = content_type.split(";", maxsplit=1)
prefix = "charset="
for parameter in parameters.split(";"):
if parameter.startswith(prefix):
encoding = parameter[len(prefix) :]
else:
mime_type = content_type
break
else:
# Content-type header is not required. Take a best guess.
mime_type = self._operation.consumes[0]
return mime_type, encoding
def validate_mime_type(self, mime_type: str) -> None:
"""Validate the mime type against the spec.
:param mime_type: mime type from content type header
"""
if mime_type.lower() not in [c.lower() for c in self._operation.consumes]:
raise UnsupportedMediaTypeProblem(
detail=f"Invalid Content-type ({mime_type}), "
f"expected {self._operation.consumes}"
)
async def __call__(self, scope: Scope, receive: Receive, send: Send):
headers = scope["headers"]
mime_type, encoding = self.extract_content_type(headers)
self.validate_mime_type(mime_type)
# TODO: Validate parameters
# Validate body
try:
body_validator = self._validator_map["body"][mime_type] # type: ignore
except KeyError:
logging.info(
f"Skipping validation. No validator registered for content type: "
f"{mime_type}."
)
else:
validator = body_validator(
self.next_app,
schema=self._operation.body_schema,
nullable=is_nullable(self._operation.body_definition),
encoding=encoding,
)
return await validator(scope, receive, send)
await self.next_app(scope, receive, send)
class MissingValidationOperation(Exception):
"""Missing validation operation"""

View File

@@ -465,12 +465,12 @@ class AbstractOperation(metaclass=abc.ABCMeta):
:rtype: types.FunctionType
"""
ParameterValidator = self.validator_map["parameter"]
RequestBodyValidator = self.validator_map["body"]
if self.parameters:
yield ParameterValidator(
self.parameters, self.api, strict_validation=self.strict_validation
)
if self.body_schema:
# TODO: temporarily hardcoded, remove RequestBodyValidator completely
yield RequestBodyValidator(
self.body_schema,
self.consumes,

87
connexion/validators.py Normal file
View File

@@ -0,0 +1,87 @@
"""
Contains validator classes used by the validation middleware.
"""
import json
import logging
import typing as t
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.exceptions import BadRequestProblem
from connexion.json_schema import Draft4RequestValidator
from connexion.utils import is_null
logger = logging.getLogger("connexion.middleware.validators")
class JSONBodyValidator:
"""Request body validator for json content types."""
def __init__(
self,
next_app: ASGIApp,
*,
schema: dict,
validator: t.Type[Draft4Validator] = None,
nullable=False,
encoding: str,
) -> None:
self.next_app = next_app
self.schema = schema
self.has_default = schema.get("default", False)
self.nullable = nullable
self.validator_cls = validator or Draft4RequestValidator
self.validator = self.validator_cls(
schema, format_checker=draft4_format_checker
)
self.encoding = encoding
@classmethod
def _error_path_message(cls, exception):
error_path = ".".join(str(item) for item in exception.path)
error_path_msg = f" - '{error_path}'" if error_path else ""
return error_path_msg
def validate(self, body: dict):
try:
self.validator.validate(body)
except ValidationError as exception:
error_path_msg = self._error_path_message(exception=exception)
logger.error(
f"Validation error: {exception.message}{error_path_msg}",
extra={"validator": "body"},
)
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# Based on https://github.com/encode/starlette/pull/1519#issuecomment-1060633787
# Ingest all body messages from the ASGI `receive` callable.
messages = []
more_body = True
while more_body:
message = await receive()
messages.append(message)
more_body = message.get("more_body", False)
# TODO: make json library pluggable
bytes_body = b"".join([message.get("body", b"") for message in messages])
decoded_body = bytes_body.decode(self.encoding)
if decoded_body and not (self.nullable and is_null(decoded_body)):
try:
body = json.loads(decoded_body)
except json.decoder.JSONDecodeError as e:
raise BadRequestProblem(str(e))
self.validate(body)
async def wrapped_receive():
# First up we want to return any messages we've stashed.
if messages:
return messages.pop(0)
# Once that's done we can just await any other messages.
return await receive()
await self.next_app(scope, wrapped_receive, send)

File diff suppressed because it is too large Load Diff

View File

@@ -95,8 +95,7 @@ def test_errors(problem_app):
)
assert unsupported_media_type_body["type"] == "about:blank"
assert unsupported_media_type_body["title"] == "Unsupported Media Type"
assert (
unsupported_media_type_body["detail"]
== "Invalid Content-type (text/html), expected JSON data"
assert unsupported_media_type_body["detail"].startswith(
"Invalid Content-type (text/html)"
)
assert unsupported_media_type_body["status"] == 415

View File

@@ -356,11 +356,6 @@ def test_post_wrong_content_type(simple_app):
)
assert resp.status_code == 415
resp = app_client.post(
"/v1.0/post_wrong_content_type", data=json.dumps({"some": "data"})
)
assert resp.status_code == 415
resp = app_client.post(
"/v1.0/post_wrong_content_type",
content_type="application/x-www-form-urlencoded",
@@ -368,31 +363,6 @@ def test_post_wrong_content_type(simple_app):
)
assert resp.status_code == 415
# this test checks exactly what the test directly above is supposed to check,
# i.e. no content-type is provided in the header
# unfortunately there is an issue with the werkzeug test environment
# (https://github.com/pallets/werkzeug/issues/1159)
# so that content-type is added to every request, we remove it here manually for our test
# this test can be removed once the werkzeug issue is addressed
builder = EnvironBuilder(
path="/v1.0/post_wrong_content_type",
method="POST",
data=json.dumps({"some": "data"}),
)
try:
environ = builder.get_environ()
finally:
builder.close()
content_type = "CONTENT_TYPE"
if content_type in environ:
environ.pop("CONTENT_TYPE")
# we cannot just call app_client.open() since app_client is a flask.testing.FlaskClient
# which overrides werkzeug.test.Client.open() but does not allow passing an environment
# directly
resp = Client.open(app_client, environ)
assert resp.status_code == 415
resp = app_client.post(
"/v1.0/post_wrong_content_type",
content_type="application/json",

View File

@@ -8,7 +8,7 @@ from connexion.resolver import MethodResolver, MethodViewResolver
from connexion.security import SecurityHandlerFactory
from werkzeug.test import Client, EnvironBuilder
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)
TEST_FOLDER = pathlib.Path(__file__).parent
FIXTURES_FOLDER = TEST_FOLDER / "fixtures"

View File

@@ -980,7 +980,7 @@ paths:
type: object
requestBody:
content:
multipart/form-data:
application/x-www-form-urlencoded:
schema:
type: object
properties:

View File

@@ -437,6 +437,8 @@ paths:
/test-formData-missing-param:
post:
consumes:
- application/x-www-form-urlencoded
summary: Test formData missing parameter in handler
operationId: fakeapi.hello.test_formdata_missing_param
parameters:
@@ -804,7 +806,7 @@ paths:
post:
operationId: fakeapi.hello.test_param_sanitization
consumes:
- multipart/form-data
- application/x-www-form-urlencoded
produces:
- application/json
parameters:

View File

@@ -6,6 +6,7 @@ from connexion import App
from connexion.decorators.validation import RequestBodyValidator
from connexion.json_schema import Draft4RequestValidator
from connexion.spec import Specification
from connexion.validators import JSONBodyValidator
from jsonschema.validators import _utils, extend
from conftest import build_app_from_fixture
@@ -30,11 +31,11 @@ def test_validator_map(json_validation_spec_dir, spec):
MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type})
class MyRequestBodyValidator(RequestBodyValidator):
class MyJSONBodyValidator(JSONBodyValidator):
def __init__(self, *args, **kwargs):
super().__init__(*args, validator=MinLengthRequestValidator, **kwargs)
validator_map = {"body": MyRequestBodyValidator}
validator_map = {"body": {"application/json": MyJSONBodyValidator}}
app = App(__name__, specification_dir=json_validation_spec_dir)
app.add_api(spec, validate_responses=True, validator_map=validator_map)