mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 12:27:45 +00:00
Coerce types only in uri parser (#1627)
This PR moves all type coercing into the URI parsers and makes sure it's only done once for each code path.
This commit is contained in:
@@ -9,7 +9,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from connexion.exceptions import TypeValidationError
|
from connexion.exceptions import TypeValidationError
|
||||||
from connexion.utils import all_json, coerce_type, deep_merge, is_null, is_nullable
|
from connexion.utils import all_json, coerce_type, deep_merge
|
||||||
|
|
||||||
logger = logging.getLogger("connexion.decorators.uri_parsing")
|
logger = logging.getLogger("connexion.decorators.uri_parsing")
|
||||||
|
|
||||||
@@ -119,14 +119,12 @@ class AbstractURIParser(metaclass=abc.ABCMeta):
|
|||||||
else:
|
else:
|
||||||
resolved_param[k] = values[-1]
|
resolved_param[k] = values[-1]
|
||||||
|
|
||||||
if not (is_nullable(param_defn) and is_null(resolved_param[k])):
|
try:
|
||||||
try:
|
resolved_param[k] = coerce_type(
|
||||||
# TODO: coerce types in a single place
|
param_defn, resolved_param[k], "parameter", k
|
||||||
resolved_param[k] = coerce_type(
|
)
|
||||||
param_defn, resolved_param[k], "parameter", k
|
except TypeValidationError:
|
||||||
)
|
pass
|
||||||
except TypeValidationError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return resolved_param
|
return resolved_param
|
||||||
|
|
||||||
@@ -166,6 +164,7 @@ class OpenAPIURIParser(AbstractURIParser):
|
|||||||
form_data[k] = self._split(form_data[k], encoding, "form")
|
form_data[k] = self._split(form_data[k], encoding, "form")
|
||||||
elif "contentType" in encoding and all_json([encoding.get("contentType")]):
|
elif "contentType" in encoding and all_json([encoding.get("contentType")]):
|
||||||
form_data[k] = json.loads(form_data[k])
|
form_data[k] = json.loads(form_data[k])
|
||||||
|
form_data[k] = coerce_type(defn, form_data[k], "requestBody", k)
|
||||||
return form_data
|
return form_data
|
||||||
|
|
||||||
def _make_deep_object(self, k, v):
|
def _make_deep_object(self, k, v):
|
||||||
|
|||||||
@@ -6,14 +6,10 @@ from starlette.datastructures import FormData, Headers, UploadFile
|
|||||||
from starlette.formparsers import FormParser, MultiPartParser
|
from starlette.formparsers import FormParser, MultiPartParser
|
||||||
from starlette.types import Receive, Scope
|
from starlette.types import Receive, Scope
|
||||||
|
|
||||||
from connexion.exceptions import (
|
from connexion.exceptions import BadRequestProblem, ExtraParameterProblem
|
||||||
BadRequestProblem,
|
|
||||||
ExtraParameterProblem,
|
|
||||||
TypeValidationError,
|
|
||||||
)
|
|
||||||
from connexion.json_schema import Draft4RequestValidator
|
from connexion.json_schema import Draft4RequestValidator
|
||||||
from connexion.uri_parsing import AbstractURIParser
|
from connexion.uri_parsing import AbstractURIParser
|
||||||
from connexion.utils import coerce_type, is_null
|
from connexion.utils import is_null
|
||||||
|
|
||||||
logger = logging.getLogger("connexion.validators.form_data")
|
logger = logging.getLogger("connexion.validators.form_data")
|
||||||
|
|
||||||
@@ -76,16 +72,7 @@ class FormDataValidator:
|
|||||||
)
|
)
|
||||||
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
|
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")
|
||||||
|
|
||||||
def validate(self, data: FormData) -> None:
|
def _parse(self, data: FormData) -> dict:
|
||||||
if self.strict_validation:
|
|
||||||
form_params = data.keys()
|
|
||||||
spec_params = self.schema.get("properties", {}).keys()
|
|
||||||
errors = set(form_params).difference(set(spec_params))
|
|
||||||
if errors:
|
|
||||||
raise ExtraParameterProblem(errors, [])
|
|
||||||
|
|
||||||
props = self.schema.get("properties", {})
|
|
||||||
errs = []
|
|
||||||
if self.uri_parser is not None:
|
if self.uri_parser is not None:
|
||||||
# Don't parse file_data
|
# Don't parse file_data
|
||||||
form_data = {}
|
form_data = {}
|
||||||
@@ -94,7 +81,8 @@ class FormDataValidator:
|
|||||||
if isinstance(v, str):
|
if isinstance(v, str):
|
||||||
form_data[k] = data.getlist(k)
|
form_data[k] = data.getlist(k)
|
||||||
elif isinstance(v, UploadFile):
|
elif isinstance(v, UploadFile):
|
||||||
file_data[k] = data.getlist(k)
|
# Replace files with empty strings for validation
|
||||||
|
file_data[k] = ""
|
||||||
|
|
||||||
data = self.uri_parser.resolve_form(form_data)
|
data = self.uri_parser.resolve_form(form_data)
|
||||||
# Add the files again
|
# Add the files again
|
||||||
@@ -102,22 +90,20 @@ class FormDataValidator:
|
|||||||
else:
|
else:
|
||||||
data = {k: data.getlist(k) for k in data}
|
data = {k: data.getlist(k) for k in data}
|
||||||
|
|
||||||
for k, param_defn in props.items():
|
return data
|
||||||
if k in data:
|
|
||||||
if param_defn.get("format", "") == "binary":
|
|
||||||
# Replace files with empty strings for validation
|
|
||||||
data[k] = ""
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
def _validate_strictly(self, data: FormData) -> None:
|
||||||
data[k] = coerce_type(param_defn, data[k], "requestBody", k)
|
form_params = data.keys()
|
||||||
except TypeValidationError as e:
|
spec_params = self.schema.get("properties", {}).keys()
|
||||||
logger.exception(e)
|
errors = set(form_params).difference(set(spec_params))
|
||||||
errs += [str(e)]
|
if errors:
|
||||||
|
raise ExtraParameterProblem(errors, [])
|
||||||
|
|
||||||
if errs:
|
def validate(self, data: FormData) -> None:
|
||||||
raise BadRequestProblem(detail=errs)
|
if self.strict_validation:
|
||||||
|
self._validate_strictly(data)
|
||||||
|
|
||||||
|
data = self._parse(data)
|
||||||
self._validate(data)
|
self._validate(data)
|
||||||
|
|
||||||
async def wrapped_receive(self) -> Receive:
|
async def wrapped_receive(self) -> Receive:
|
||||||
|
|||||||
@@ -5,12 +5,8 @@ import logging
|
|||||||
from jsonschema import Draft4Validator, ValidationError
|
from jsonschema import Draft4Validator, ValidationError
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from connexion.exceptions import (
|
from connexion.exceptions import BadRequestProblem, ExtraParameterProblem
|
||||||
BadRequestProblem,
|
from connexion.utils import boolean, is_null, is_nullable
|
||||||
ExtraParameterProblem,
|
|
||||||
TypeValidationError,
|
|
||||||
)
|
|
||||||
from connexion.utils import boolean, coerce_type, is_null, is_nullable
|
|
||||||
|
|
||||||
logger = logging.getLogger("connexion.validators.parameter")
|
logger = logging.getLogger("connexion.validators.parameter")
|
||||||
|
|
||||||
@@ -38,35 +34,17 @@ class ParameterValidator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_parameter(parameter_type, value, param, param_name=None):
|
def validate_parameter(parameter_type, value, param, param_name=None):
|
||||||
if value is not None:
|
if is_nullable(param) and is_null(value):
|
||||||
if is_nullable(param) and is_null(value):
|
return
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
converted_value = coerce_type(param, value, parameter_type, param_name)
|
|
||||||
except TypeValidationError as e:
|
|
||||||
return str(e)
|
|
||||||
|
|
||||||
|
elif value is not None:
|
||||||
param = copy.deepcopy(param)
|
param = copy.deepcopy(param)
|
||||||
param = param.get("schema", param)
|
param = param.get("schema", param)
|
||||||
if "required" in param:
|
|
||||||
del param["required"]
|
|
||||||
try:
|
try:
|
||||||
Draft4Validator(param, format_checker=draft4_format_checker).validate(
|
Draft4Validator(param, format_checker=draft4_format_checker).validate(
|
||||||
converted_value
|
value
|
||||||
)
|
)
|
||||||
except ValidationError as exception:
|
except ValidationError as exception:
|
||||||
debug_msg = (
|
|
||||||
"Error while converting value {converted_value} from param "
|
|
||||||
"{type_converted_value} of type real type {param_type} to the declared type {param}"
|
|
||||||
)
|
|
||||||
fmt_params = dict(
|
|
||||||
converted_value=str(converted_value),
|
|
||||||
type_converted_value=type(converted_value),
|
|
||||||
param_type=param.get("type"),
|
|
||||||
param=param,
|
|
||||||
)
|
|
||||||
logger.info(debug_msg.format(**fmt_params))
|
|
||||||
return str(exception)
|
return str(exception)
|
||||||
|
|
||||||
elif param.get("required"):
|
elif param.get("required"):
|
||||||
@@ -102,10 +80,8 @@ class ParameterValidator:
|
|||||||
return self.validate_parameter("query", val, param)
|
return self.validate_parameter("query", val, param)
|
||||||
|
|
||||||
def validate_path_parameter(self, param, request):
|
def validate_path_parameter(self, param, request):
|
||||||
# TODO: activate
|
path_params = self.uri_parser.resolve_path(request.path_params)
|
||||||
# path_params = self.uri_parser.resolve_path(request.path_params)
|
val = path_params.get(param["name"].replace("-", "_"))
|
||||||
# val = path_params.get(param["name"].replace("-", "_"))
|
|
||||||
val = request.path_params.get(param["name"].replace("-", "_"))
|
|
||||||
return self.validate_parameter("path", val, param)
|
return self.validate_parameter("path", val, param)
|
||||||
|
|
||||||
def validate_header_parameter(self, param, request):
|
def validate_header_parameter(self, param, request):
|
||||||
|
|||||||
@@ -561,10 +561,7 @@ def test_parameters_snake_case(snake_case_app):
|
|||||||
assert resp.get_json() == {"truthiness": True, "order_by": "asc"}
|
assert resp.get_json() == {"truthiness": True, "order_by": "asc"}
|
||||||
resp = app_client.get("/v1.0/test-get-camel-case-version?truthiness=5")
|
resp = app_client.get("/v1.0/test-get-camel-case-version?truthiness=5")
|
||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
assert (
|
assert resp.get_json()["detail"].startswith("'5' is not of type 'boolean'")
|
||||||
resp.get_json()["detail"]
|
|
||||||
== "Wrong type, expected 'boolean' for query parameter 'truthiness'"
|
|
||||||
)
|
|
||||||
# Incorrectly cased params should be ignored
|
# Incorrectly cased params should be ignored
|
||||||
resp = app_client.get(
|
resp = app_client.get(
|
||||||
"/v1.0/test-get-camel-case-version?Truthiness=true&order_by=asc"
|
"/v1.0/test-get-camel-case-version?Truthiness=true&order_by=asc"
|
||||||
|
|||||||
@@ -8,7 +8,4 @@ def test_app(unordered_definition_app):
|
|||||||
) # type: flask.Response
|
) # type: flask.Response
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
response_data = json.loads(response.data.decode("utf-8", "replace"))
|
response_data = json.loads(response.data.decode("utf-8", "replace"))
|
||||||
assert (
|
assert response_data["detail"].startswith("'first' is not of type 'integer'")
|
||||||
response_data["detail"]
|
|
||||||
== "Wrong type, expected 'integer' for query parameter 'first'"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator
|
from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator
|
||||||
|
from connexion.utils import coerce_type
|
||||||
from connexion.validators.parameter import ParameterValidator
|
from connexion.validators.parameter import ParameterValidator
|
||||||
from jsonschema import ValidationError
|
from jsonschema import ValidationError
|
||||||
|
|
||||||
@@ -68,6 +69,7 @@ def test_get_valid_parameter_with_enum_array_header():
|
|||||||
},
|
},
|
||||||
"name": "test_header_param",
|
"name": "test_header_param",
|
||||||
}
|
}
|
||||||
|
value = coerce_type(param, value, "header", "test_header_param")
|
||||||
result = ParameterValidator.validate_parameter("header", value, param)
|
result = ParameterValidator.validate_parameter("header", value, param)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
@@ -86,7 +88,6 @@ Failed validating 'type' in schema:
|
|||||||
On instance:
|
On instance:
|
||||||
20"""
|
20"""
|
||||||
assert result == expected_result
|
assert result == expected_result
|
||||||
logger.info.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_type_value_error(monkeypatch):
|
def test_invalid_type_value_error(monkeypatch):
|
||||||
@@ -94,7 +95,7 @@ def test_invalid_type_value_error(monkeypatch):
|
|||||||
result = ParameterValidator.validate_parameter(
|
result = ParameterValidator.validate_parameter(
|
||||||
"formdata", value, {"type": "boolean", "name": "foo"}
|
"formdata", value, {"type": "boolean", "name": "foo"}
|
||||||
)
|
)
|
||||||
assert result == "Wrong type, expected 'boolean' for formdata parameter 'foo'"
|
assert result.startswith("{'test': 1, 'second': 2} is not of type 'boolean'")
|
||||||
|
|
||||||
|
|
||||||
def test_enum_error(monkeypatch):
|
def test_enum_error(monkeypatch):
|
||||||
|
|||||||
@@ -42,17 +42,17 @@ def test_parameter_validator(monkeypatch):
|
|||||||
request = MagicMock(path_params={"p1": ""}, **kwargs)
|
request = MagicMock(path_params={"p1": ""}, **kwargs)
|
||||||
with pytest.raises(BadRequestProblem) as exc:
|
with pytest.raises(BadRequestProblem) as exc:
|
||||||
validator.validate_request(request)
|
validator.validate_request(request)
|
||||||
assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'"
|
assert exc.value.detail.startswith("'' is not of type 'integer'")
|
||||||
|
|
||||||
request = MagicMock(path_params={"p1": "foo"}, **kwargs)
|
request = MagicMock(path_params={"p1": "foo"}, **kwargs)
|
||||||
with pytest.raises(BadRequestProblem) as exc:
|
with pytest.raises(BadRequestProblem) as exc:
|
||||||
validator.validate_request(request)
|
validator.validate_request(request)
|
||||||
assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'"
|
assert exc.value.detail.startswith("'foo' is not of type 'integer'")
|
||||||
|
|
||||||
request = MagicMock(path_params={"p1": "1.2"}, **kwargs)
|
request = MagicMock(path_params={"p1": "1.2"}, **kwargs)
|
||||||
with pytest.raises(BadRequestProblem) as exc:
|
with pytest.raises(BadRequestProblem) as exc:
|
||||||
validator.validate_request(request)
|
validator.validate_request(request)
|
||||||
assert exc.value.detail == "Wrong type, expected 'integer' for path parameter 'p1'"
|
assert exc.value.detail.startswith("'1.2' is not of type 'integer'")
|
||||||
|
|
||||||
request = MagicMock(
|
request = MagicMock(
|
||||||
path_params={"p1": 1}, query_params=QueryParams("q1=4"), headers={}, cookies={}
|
path_params={"p1": 1}, query_params=QueryParams("q1=4"), headers={}, cookies={}
|
||||||
|
|||||||
Reference in New Issue
Block a user