Create MediaTypeDict class for range matching (#1603)

This commit is contained in:
Robbe Sneyders
2022-11-04 00:28:23 +01:00
committed by Robbe Sneyders
parent b8bdcc999d
commit 9d7258c25d
8 changed files with 95 additions and 11 deletions

View File

@@ -0,0 +1,31 @@
from fnmatch import fnmatch
class MediaTypeDict(dict):
"""
A dictionary where keys can be either media types or media type ranges. When fetching a
value from the dictionary, the provided key is checked against the ranges. The most specific
key is chosen as prescribed by the OpenAPI spec, with `type/*` being preferred above
`*/subtype`.
"""
def __getitem__(self, item):
# Sort keys in order of specificity
for key in sorted(self, key=lambda k: ("*" not in k, k), reverse=True):
if fnmatch(item, key):
return super().__getitem__(key)
raise super().__getitem__(item)
def get(self, item, default=None):
try:
return self[item]
except KeyError:
return default
def __contains__(self, item):
try:
self[item]
except KeyError:
return False
else:
return True

View File

@@ -7,6 +7,7 @@ import typing as t
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion import utils
from connexion.datastructures import MediaTypeDict
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import UnsupportedMediaTypeProblem
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
@@ -59,7 +60,11 @@ class RequestValidationOperation:
:param mime_type: mime type from content type header
"""
if mime_type.lower() not in [c.lower() for c in self._operation.consumes]:
# Convert to MediaTypeDict to handle media-ranges
media_type_dict = MediaTypeDict(
[(c.lower(), None) for c in self._operation.consumes]
)
if mime_type.lower() not in media_type_dict:
raise UnsupportedMediaTypeProblem(
detail=f"Invalid Content-type ({mime_type}), "
f"expected {self._operation.consumes}"

View File

@@ -5,6 +5,7 @@ This module defines an OpenAPIOperation class, a Connexion operation specific fo
import logging
from copy import copy, deepcopy
from connexion.datastructures import MediaTypeDict
from connexion.operations.abstract import AbstractOperation
from ..decorators.uri_parsing import OpenAPIURIParser
@@ -274,7 +275,8 @@ class OpenAPIOperation(AbstractOperation):
"this operation accepts multiple content types, using %s",
content_type,
)
res = self._request_body.get("content", {}).get(content_type, {})
content_type_dict = MediaTypeDict(self._request_body.get("content", {}))
res = content_type_dict.get(content_type, {})
return self.with_definitions(res)
return {}

View File

@@ -10,6 +10,7 @@ from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import FormParser, MultiPartParser
from starlette.types import Receive, Scope, Send
from connexion.datastructures import MediaTypeDict
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.decorators.validation import (
ParameterValidator,
@@ -306,13 +307,17 @@ class MultiPartFormDataValidator(FormDataValidator):
VALIDATOR_MAP = {
"parameter": ParameterValidator,
"body": {
"application/json": JSONRequestBodyValidator,
"application/x-www-form-urlencoded": FormDataValidator,
"multipart/form-data": MultiPartFormDataValidator,
},
"response": {
"application/json": JSONResponseBodyValidator,
"text/plain": TextResponseBodyValidator,
},
"body": MediaTypeDict(
{
"*/*json": JSONRequestBodyValidator,
"application/x-www-form-urlencoded": FormDataValidator,
"multipart/form-data": MultiPartFormDataValidator,
}
),
"response": MediaTypeDict(
{
"*/*json": JSONResponseBodyValidator,
"text/plain": TextResponseBodyValidator,
}
),
}

View File

@@ -309,3 +309,13 @@ def test_global_response_definitions(schema_app):
app_client = schema_app.app.test_client()
resp = app_client.get("/v1.0/define_global_response")
assert json.loads(resp.data.decode("utf-8", "replace")) == ["general", "list"]
def test_media_range(schema_app):
app_client = schema_app.app.test_client()
headers = {"Content-type": "application/json"}
array_request = app_client.post(
"/v1.0/media_range", headers=headers, data=json.dumps({})
)
assert array_request.status_code == 200, array_request.text

View File

@@ -389,6 +389,10 @@ def test_global_response_definition():
return ["general", "list"], 200
def test_media_range():
return "OK"
def test_nullable_parameters(time_start):
if time_start is None:
return "it was None"

View File

@@ -264,6 +264,18 @@ paths:
responses:
'200':
$ref: '#/components/responses/GeneralList'
/media_range:
post:
description: Test media range
operationId: fakeapi.hello.test_media_range
requestBody:
content:
'*/*':
schema:
type: object
responses:
'200':
description: OK
components:
responses:
GeneralList:

View File

@@ -291,6 +291,21 @@ paths:
200:
$ref: '#/responses/GeneralList'
/media_range:
post:
description: Test media range
operationId: fakeapi.hello.test_media_range
consumes:
- '*/*'
parameters:
- name: body
in: body
schema:
type: object
responses:
'200':
description: OK
definitions:
new_stack:
type: object