mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-09 20:37:46 +00:00
Extract JSON request body validation to middleware (#1588)
* Set up code skeleton for validation middleware * Add more boilerplate code * WIP * Add ASGI JSONBodyValidator * Revert example changes * Remove incorrect content type test Co-authored-by: Ruwan <ruwanlambrichts@gmail.com>
This commit is contained in:
@@ -12,15 +12,11 @@ from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
|
||||
from jsonschema.validators import extend
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from ..exceptions import (
|
||||
BadRequestProblem,
|
||||
ExtraParameterProblem,
|
||||
UnsupportedMediaTypeProblem,
|
||||
)
|
||||
from ..exceptions import BadRequestProblem, ExtraParameterProblem
|
||||
from ..http_facts import FORM_CONTENT_TYPES
|
||||
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator
|
||||
from ..lifecycle import ConnexionResponse
|
||||
from ..utils import all_json, boolean, is_json_mimetype, is_null, is_nullable
|
||||
from ..utils import boolean, is_null, is_nullable
|
||||
|
||||
logger = logging.getLogger("connexion.decorators.validation")
|
||||
|
||||
@@ -141,33 +137,7 @@ class RequestBodyValidator:
|
||||
|
||||
@functools.wraps(function)
|
||||
def wrapper(request):
|
||||
if all_json(self.consumes):
|
||||
data = request.json
|
||||
|
||||
empty_body = not (request.body or request.form or request.files)
|
||||
if data is None and not empty_body and not self.is_null_value_valid:
|
||||
try:
|
||||
ctype_is_json = is_json_mimetype(
|
||||
request.headers.get("Content-Type", "")
|
||||
)
|
||||
except ValueError:
|
||||
ctype_is_json = False
|
||||
|
||||
if ctype_is_json:
|
||||
# Content-Type is json but actual body was not parsed
|
||||
raise BadRequestProblem(detail="Request body is not valid JSON")
|
||||
else:
|
||||
# the body has contents that were not parsed as JSON
|
||||
raise UnsupportedMediaTypeProblem(
|
||||
detail="Invalid Content-type ({content_type}), expected JSON data".format(
|
||||
content_type=request.headers.get("Content-Type", "")
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("%s validating schema...", request.url)
|
||||
if data is not None or not self.has_default:
|
||||
self.validate_schema(data, request.url)
|
||||
elif self.consumes[0] in FORM_CONTENT_TYPES:
|
||||
if self.consumes[0] in FORM_CONTENT_TYPES:
|
||||
data = dict(request.form.items()) or (
|
||||
request.body if len(request.body) > 0 else {}
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from connexion.middleware.exceptions import ExceptionMiddleware
|
||||
from connexion.middleware.routing import RoutingMiddleware
|
||||
from connexion.middleware.security import SecurityMiddleware
|
||||
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
|
||||
from connexion.middleware.validation import ValidationMiddleware
|
||||
|
||||
|
||||
class ConnexionMiddleware:
|
||||
@@ -17,6 +18,7 @@ class ConnexionMiddleware:
|
||||
SwaggerUIMiddleware,
|
||||
RoutingMiddleware,
|
||||
SecurityMiddleware,
|
||||
ValidationMiddleware,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -6,7 +6,6 @@ from starlette.routing import Router
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.apis import AbstractRoutingAPI
|
||||
from connexion.exceptions import NotFoundProblem
|
||||
from connexion.middleware import AppMiddleware
|
||||
from connexion.operations import AbstractOperation
|
||||
from connexion.resolver import Resolver
|
||||
@@ -61,10 +60,7 @@ class RoutingMiddleware(AppMiddleware):
|
||||
|
||||
# Needs to be set so starlette router throws exceptions instead of returning error responses
|
||||
scope["app"] = self
|
||||
try:
|
||||
await self.router(scope, receive, send)
|
||||
except ValueError:
|
||||
raise NotFoundProblem
|
||||
|
||||
|
||||
class RoutingAPI(AbstractRoutingAPI):
|
||||
|
||||
@@ -142,6 +142,7 @@ class SecurityOperation:
|
||||
operation: t.Union[AbstractOperation, Specification],
|
||||
security_handler_factory: SecurityHandlerFactory,
|
||||
):
|
||||
# TODO: Turn Operation class into OperationSpec and use as init argument instead
|
||||
return cls(
|
||||
security_handler_factory,
|
||||
security=operation.security,
|
||||
|
||||
232
connexion/middleware/validation.py
Normal file
232
connexion/middleware/validation.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Validation Middleware.
|
||||
"""
|
||||
import logging
|
||||
import pathlib
|
||||
import typing as t
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.apis.abstract import AbstractSpecAPI
|
||||
from connexion.decorators.uri_parsing import AbstractURIParser
|
||||
from connexion.exceptions import MissingMiddleware, UnsupportedMediaTypeProblem
|
||||
from connexion.http_facts import METHODS
|
||||
from connexion.middleware import AppMiddleware
|
||||
from connexion.middleware.routing import ROUTING_CONTEXT
|
||||
from connexion.operations import AbstractOperation
|
||||
from connexion.resolver import ResolverError
|
||||
from connexion.utils import is_nullable
|
||||
from connexion.validators import JSONBodyValidator
|
||||
|
||||
from ..decorators.response import ResponseValidator
|
||||
from ..decorators.validation import ParameterValidator
|
||||
|
||||
logger = logging.getLogger("connexion.middleware.validation")
|
||||
|
||||
VALIDATOR_MAP = {
|
||||
"parameter": ParameterValidator,
|
||||
"body": {"application/json": JSONBodyValidator},
|
||||
"response": ResponseValidator,
|
||||
}
|
||||
|
||||
|
||||
class ValidationMiddleware(AppMiddleware):
|
||||
"""Middleware for validating requests according to the API contract."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
self.apis: t.Dict[str, ValidationAPI] = {}
|
||||
|
||||
def add_api(
|
||||
self, specification: t.Union[pathlib.Path, str, dict], **kwargs
|
||||
) -> None:
|
||||
api = ValidationAPI(specification, next_app=self.app, **kwargs)
|
||||
self.apis[api.base_path] = api
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
try:
|
||||
connexion_context = scope["extensions"][ROUTING_CONTEXT]
|
||||
except KeyError:
|
||||
raise MissingMiddleware(
|
||||
"Could not find routing information in scope. Please make sure "
|
||||
"you have a routing middleware registered upstream. "
|
||||
)
|
||||
api_base_path = connexion_context.get("api_base_path")
|
||||
if api_base_path:
|
||||
api = self.apis[api_base_path]
|
||||
operation_id = connexion_context.get("operation_id")
|
||||
try:
|
||||
operation = api.operations[operation_id]
|
||||
except KeyError as e:
|
||||
if operation_id is None:
|
||||
logger.debug("Skipping validation check for operation without id.")
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
else:
|
||||
raise MissingValidationOperation(
|
||||
"Encountered unknown operation_id."
|
||||
) from e
|
||||
else:
|
||||
return await operation(scope, receive, send)
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class ValidationAPI(AbstractSpecAPI):
|
||||
"""Validation API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specification: t.Union[pathlib.Path, str, dict],
|
||||
*args,
|
||||
next_app: ASGIApp,
|
||||
validate_responses=False,
|
||||
strict_validation=False,
|
||||
validator_map=None,
|
||||
uri_parser_class=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(specification, *args, **kwargs)
|
||||
self.next_app = next_app
|
||||
|
||||
self.validator_map = validator_map
|
||||
|
||||
logger.debug("Validate Responses: %s", str(validate_responses))
|
||||
self.validate_responses = validate_responses
|
||||
|
||||
logger.debug("Strict Request Validation: %s", str(strict_validation))
|
||||
self.strict_validation = strict_validation
|
||||
|
||||
self.uri_parser_class = uri_parser_class
|
||||
|
||||
self.operations: t.Dict[str, ValidationOperation] = {}
|
||||
self.add_paths()
|
||||
|
||||
def add_paths(self):
|
||||
paths = self.specification.get("paths", {})
|
||||
for path, methods in paths.items():
|
||||
for method in methods:
|
||||
if method not in METHODS:
|
||||
continue
|
||||
try:
|
||||
self.add_operation(path, method)
|
||||
except ResolverError:
|
||||
# ResolverErrors are either raised or handled in routing middleware.
|
||||
pass
|
||||
|
||||
def add_operation(self, path: str, method: str) -> None:
|
||||
operation_cls = self.specification.operation_cls
|
||||
operation = operation_cls.from_spec(
|
||||
self.specification, self, path, method, self.resolver
|
||||
)
|
||||
validation_operation = self.make_operation(operation)
|
||||
self._add_operation_internal(operation.operation_id, validation_operation)
|
||||
|
||||
def make_operation(self, operation: AbstractOperation):
|
||||
return ValidationOperation(
|
||||
operation,
|
||||
self.next_app,
|
||||
validate_responses=self.validate_responses,
|
||||
strict_validation=self.strict_validation,
|
||||
validator_map=self.validator_map,
|
||||
uri_parser_class=self.uri_parser_class,
|
||||
)
|
||||
|
||||
def _add_operation_internal(
|
||||
self, operation_id: str, operation: "ValidationOperation"
|
||||
):
|
||||
self.operations[operation_id] = operation
|
||||
|
||||
|
||||
class ValidationOperation:
|
||||
def __init__(
|
||||
self,
|
||||
operation: AbstractOperation,
|
||||
next_app: ASGIApp,
|
||||
validate_responses: bool = False,
|
||||
strict_validation: bool = False,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None,
|
||||
) -> None:
|
||||
self._operation = operation
|
||||
self.next_app = next_app
|
||||
self.validate_responses = validate_responses
|
||||
self.strict_validation = strict_validation
|
||||
self._validator_map = VALIDATOR_MAP
|
||||
self._validator_map.update(validator_map or {})
|
||||
self.uri_parser_class = uri_parser_class
|
||||
|
||||
def extract_content_type(self, headers: dict) -> t.Tuple[str, str]:
|
||||
"""Extract the mime type and encoding from the content type headers.
|
||||
|
||||
:param headers: Header dict from ASGI scope
|
||||
|
||||
:return: A tuple of mime type, encoding
|
||||
"""
|
||||
encoding = "utf-8"
|
||||
for key, value in headers:
|
||||
# Headers can always be decoded using latin-1:
|
||||
# https://stackoverflow.com/a/27357138/4098821
|
||||
key = key.decode("latin-1")
|
||||
if key.lower() == "content-type":
|
||||
content_type = value.decode("latin-1")
|
||||
if ";" in content_type:
|
||||
mime_type, parameters = content_type.split(";", maxsplit=1)
|
||||
|
||||
prefix = "charset="
|
||||
for parameter in parameters.split(";"):
|
||||
if parameter.startswith(prefix):
|
||||
encoding = parameter[len(prefix) :]
|
||||
else:
|
||||
mime_type = content_type
|
||||
break
|
||||
else:
|
||||
# Content-type header is not required. Take a best guess.
|
||||
mime_type = self._operation.consumes[0]
|
||||
|
||||
return mime_type, encoding
|
||||
|
||||
def validate_mime_type(self, mime_type: str) -> None:
|
||||
"""Validate the mime type against the spec.
|
||||
|
||||
:param mime_type: mime type from content type header
|
||||
"""
|
||||
if mime_type.lower() not in [c.lower() for c in self._operation.consumes]:
|
||||
raise UnsupportedMediaTypeProblem(
|
||||
detail=f"Invalid Content-type ({mime_type}), "
|
||||
f"expected {self._operation.consumes}"
|
||||
)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
headers = scope["headers"]
|
||||
mime_type, encoding = self.extract_content_type(headers)
|
||||
self.validate_mime_type(mime_type)
|
||||
|
||||
# TODO: Validate parameters
|
||||
|
||||
# Validate body
|
||||
try:
|
||||
body_validator = self._validator_map["body"][mime_type] # type: ignore
|
||||
except KeyError:
|
||||
logging.info(
|
||||
f"Skipping validation. No validator registered for content type: "
|
||||
f"{mime_type}."
|
||||
)
|
||||
else:
|
||||
validator = body_validator(
|
||||
self.next_app,
|
||||
schema=self._operation.body_schema,
|
||||
nullable=is_nullable(self._operation.body_definition),
|
||||
encoding=encoding,
|
||||
)
|
||||
return await validator(scope, receive, send)
|
||||
|
||||
await self.next_app(scope, receive, send)
|
||||
|
||||
|
||||
class MissingValidationOperation(Exception):
|
||||
"""Missing validation operation"""
|
||||
@@ -465,12 +465,12 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
ParameterValidator = self.validator_map["parameter"]
|
||||
RequestBodyValidator = self.validator_map["body"]
|
||||
if self.parameters:
|
||||
yield ParameterValidator(
|
||||
self.parameters, self.api, strict_validation=self.strict_validation
|
||||
)
|
||||
if self.body_schema:
|
||||
# TODO: temporarily hardcoded, remove RequestBodyValidator completely
|
||||
yield RequestBodyValidator(
|
||||
self.body_schema,
|
||||
self.consumes,
|
||||
|
||||
87
connexion/validators.py
Normal file
87
connexion/validators.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Contains validator classes used by the validation middleware.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from connexion.exceptions import BadRequestProblem
|
||||
from connexion.json_schema import Draft4RequestValidator
|
||||
from connexion.utils import is_null
|
||||
|
||||
logger = logging.getLogger("connexion.middleware.validators")
|
||||
|
||||
|
||||
class JSONBodyValidator:
|
||||
"""Request body validator for json content types."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
next_app: ASGIApp,
|
||||
*,
|
||||
schema: dict,
|
||||
validator: t.Type[Draft4Validator] = None,
|
||||
nullable=False,
|
||||
encoding: str,
|
||||
) -> None:
|
||||
self.next_app = next_app
|
||||
self.schema = schema
|
||||
self.has_default = schema.get("default", False)
|
||||
self.nullable = nullable
|
||||
self.validator_cls = validator or Draft4RequestValidator
|
||||
self.validator = self.validator_cls(
|
||||
schema, format_checker=draft4_format_checker
|
||||
)
|
||||
self.encoding = encoding
|
||||
|
||||
@classmethod
|
||||
def _error_path_message(cls, exception):
|
||||
error_path = ".".join(str(item) for item in exception.path)
|
||||
error_path_msg = f" - '{error_path}'" if error_path else ""
|
||||
return error_path_msg
|
||||
|
||||
def validate(self, body: dict):
|
||||
|
||||
try:
|
||||
self.validator.validate(body)
|
||||
except ValidationError as exception:
|
||||
error_path_msg = self._error_path_message(exception=exception)
|
||||
logger.error(
|
||||
f"Validation error: {exception.message}{error_path_msg}",
|
||||
extra={"validator": "body"},
|
||||
)
|
||||
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# Based on https://github.com/encode/starlette/pull/1519#issuecomment-1060633787
|
||||
# Ingest all body messages from the ASGI `receive` callable.
|
||||
messages = []
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await receive()
|
||||
messages.append(message)
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# TODO: make json library pluggable
|
||||
bytes_body = b"".join([message.get("body", b"") for message in messages])
|
||||
decoded_body = bytes_body.decode(self.encoding)
|
||||
|
||||
if decoded_body and not (self.nullable and is_null(decoded_body)):
|
||||
try:
|
||||
body = json.loads(decoded_body)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
raise BadRequestProblem(str(e))
|
||||
|
||||
self.validate(body)
|
||||
|
||||
async def wrapped_receive():
|
||||
# First up we want to return any messages we've stashed.
|
||||
if messages:
|
||||
return messages.pop(0)
|
||||
# Once that's done we can just await any other messages.
|
||||
return await receive()
|
||||
|
||||
await self.next_app(scope, wrapped_receive, send)
|
||||
2202
docs/images/validation.excalidraw
Normal file
2202
docs/images/validation.excalidraw
Normal file
File diff suppressed because it is too large
Load Diff
@@ -95,8 +95,7 @@ def test_errors(problem_app):
|
||||
)
|
||||
assert unsupported_media_type_body["type"] == "about:blank"
|
||||
assert unsupported_media_type_body["title"] == "Unsupported Media Type"
|
||||
assert (
|
||||
unsupported_media_type_body["detail"]
|
||||
== "Invalid Content-type (text/html), expected JSON data"
|
||||
assert unsupported_media_type_body["detail"].startswith(
|
||||
"Invalid Content-type (text/html)"
|
||||
)
|
||||
assert unsupported_media_type_body["status"] == 415
|
||||
|
||||
@@ -356,11 +356,6 @@ def test_post_wrong_content_type(simple_app):
|
||||
)
|
||||
assert resp.status_code == 415
|
||||
|
||||
resp = app_client.post(
|
||||
"/v1.0/post_wrong_content_type", data=json.dumps({"some": "data"})
|
||||
)
|
||||
assert resp.status_code == 415
|
||||
|
||||
resp = app_client.post(
|
||||
"/v1.0/post_wrong_content_type",
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
@@ -368,31 +363,6 @@ def test_post_wrong_content_type(simple_app):
|
||||
)
|
||||
assert resp.status_code == 415
|
||||
|
||||
# this test checks exactly what the test directly above is supposed to check,
|
||||
# i.e. no content-type is provided in the header
|
||||
# unfortunately there is an issue with the werkzeug test environment
|
||||
# (https://github.com/pallets/werkzeug/issues/1159)
|
||||
# so that content-type is added to every request, we remove it here manually for our test
|
||||
# this test can be removed once the werkzeug issue is addressed
|
||||
builder = EnvironBuilder(
|
||||
path="/v1.0/post_wrong_content_type",
|
||||
method="POST",
|
||||
data=json.dumps({"some": "data"}),
|
||||
)
|
||||
try:
|
||||
environ = builder.get_environ()
|
||||
finally:
|
||||
builder.close()
|
||||
|
||||
content_type = "CONTENT_TYPE"
|
||||
if content_type in environ:
|
||||
environ.pop("CONTENT_TYPE")
|
||||
# we cannot just call app_client.open() since app_client is a flask.testing.FlaskClient
|
||||
# which overrides werkzeug.test.Client.open() but does not allow passing an environment
|
||||
# directly
|
||||
resp = Client.open(app_client, environ)
|
||||
assert resp.status_code == 415
|
||||
|
||||
resp = app_client.post(
|
||||
"/v1.0/post_wrong_content_type",
|
||||
content_type="application/json",
|
||||
|
||||
@@ -8,7 +8,7 @@ from connexion.resolver import MethodResolver, MethodViewResolver
|
||||
from connexion.security import SecurityHandlerFactory
|
||||
from werkzeug.test import Client, EnvironBuilder
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
TEST_FOLDER = pathlib.Path(__file__).parent
|
||||
FIXTURES_FOLDER = TEST_FOLDER / "fixtures"
|
||||
|
||||
2
tests/fixtures/simple/openapi.yaml
vendored
2
tests/fixtures/simple/openapi.yaml
vendored
@@ -980,7 +980,7 @@ paths:
|
||||
type: object
|
||||
requestBody:
|
||||
content:
|
||||
multipart/form-data:
|
||||
application/x-www-form-urlencoded:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
|
||||
4
tests/fixtures/simple/swagger.yaml
vendored
4
tests/fixtures/simple/swagger.yaml
vendored
@@ -437,6 +437,8 @@ paths:
|
||||
|
||||
/test-formData-missing-param:
|
||||
post:
|
||||
consumes:
|
||||
- application/x-www-form-urlencoded
|
||||
summary: Test formData missing parameter in handler
|
||||
operationId: fakeapi.hello.test_formdata_missing_param
|
||||
parameters:
|
||||
@@ -804,7 +806,7 @@ paths:
|
||||
post:
|
||||
operationId: fakeapi.hello.test_param_sanitization
|
||||
consumes:
|
||||
- multipart/form-data
|
||||
- application/x-www-form-urlencoded
|
||||
produces:
|
||||
- application/json
|
||||
parameters:
|
||||
|
||||
@@ -6,6 +6,7 @@ from connexion import App
|
||||
from connexion.decorators.validation import RequestBodyValidator
|
||||
from connexion.json_schema import Draft4RequestValidator
|
||||
from connexion.spec import Specification
|
||||
from connexion.validators import JSONBodyValidator
|
||||
from jsonschema.validators import _utils, extend
|
||||
|
||||
from conftest import build_app_from_fixture
|
||||
@@ -30,11 +31,11 @@ def test_validator_map(json_validation_spec_dir, spec):
|
||||
|
||||
MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type})
|
||||
|
||||
class MyRequestBodyValidator(RequestBodyValidator):
|
||||
class MyJSONBodyValidator(JSONBodyValidator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, validator=MinLengthRequestValidator, **kwargs)
|
||||
|
||||
validator_map = {"body": MyRequestBodyValidator}
|
||||
validator_map = {"body": {"application/json": MyJSONBodyValidator}}
|
||||
|
||||
app = App(__name__, specification_dir=json_validation_spec_dir)
|
||||
app.add_api(spec, validate_responses=True, validator_map=validator_map)
|
||||
|
||||
Reference in New Issue
Block a user