Create MediaTypeDict class for range matching (#1603)

This commit is contained in:
Robbe Sneyders
2022-11-04 00:28:23 +01:00
committed by Robbe Sneyders
parent b8bdcc999d
commit 9d7258c25d
8 changed files with 95 additions and 11 deletions

View File

@@ -0,0 +1,31 @@
from fnmatch import fnmatch
class MediaTypeDict(dict):
"""
A dictionary where keys can be either media types or media type ranges. When fetching a
value from the dictionary, the provided key is checked against the ranges. The most specific
key is chosen as prescribed by the OpenAPI spec, with `type/*` being preferred above
`*/subtype`.
"""
def __getitem__(self, item):
# Sort keys in order of specificity
for key in sorted(self, key=lambda k: ("*" not in k, k), reverse=True):
if fnmatch(item, key):
return super().__getitem__(key)
raise super().__getitem__(item)
def get(self, item, default=None):
try:
return self[item]
except KeyError:
return default
def __contains__(self, item):
try:
self[item]
except KeyError:
return False
else:
return True

View File

@@ -7,6 +7,7 @@ import typing as t
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion import utils from connexion import utils
from connexion.datastructures import MediaTypeDict
from connexion.decorators.uri_parsing import AbstractURIParser from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import UnsupportedMediaTypeProblem from connexion.exceptions import UnsupportedMediaTypeProblem
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
@@ -59,7 +60,11 @@ class RequestValidationOperation:
:param mime_type: mime type from content type header :param mime_type: mime type from content type header
""" """
if mime_type.lower() not in [c.lower() for c in self._operation.consumes]: # Convert to MediaTypeDict to handle media-ranges
media_type_dict = MediaTypeDict(
[(c.lower(), None) for c in self._operation.consumes]
)
if mime_type.lower() not in media_type_dict:
raise UnsupportedMediaTypeProblem( raise UnsupportedMediaTypeProblem(
detail=f"Invalid Content-type ({mime_type}), " detail=f"Invalid Content-type ({mime_type}), "
f"expected {self._operation.consumes}" f"expected {self._operation.consumes}"

View File

@@ -5,6 +5,7 @@ This module defines an OpenAPIOperation class, a Connexion operation specific fo
import logging import logging
from copy import copy, deepcopy from copy import copy, deepcopy
from connexion.datastructures import MediaTypeDict
from connexion.operations.abstract import AbstractOperation from connexion.operations.abstract import AbstractOperation
from ..decorators.uri_parsing import OpenAPIURIParser from ..decorators.uri_parsing import OpenAPIURIParser
@@ -274,7 +275,8 @@ class OpenAPIOperation(AbstractOperation):
"this operation accepts multiple content types, using %s", "this operation accepts multiple content types, using %s",
content_type, content_type,
) )
res = self._request_body.get("content", {}).get(content_type, {}) content_type_dict = MediaTypeDict(self._request_body.get("content", {}))
res = content_type_dict.get(content_type, {})
return self.with_definitions(res) return self.with_definitions(res)
return {} return {}

View File

@@ -10,6 +10,7 @@ from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import FormParser, MultiPartParser from starlette.formparsers import FormParser, MultiPartParser
from starlette.types import Receive, Scope, Send from starlette.types import Receive, Scope, Send
from connexion.datastructures import MediaTypeDict
from connexion.decorators.uri_parsing import AbstractURIParser from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.decorators.validation import ( from connexion.decorators.validation import (
ParameterValidator, ParameterValidator,
@@ -306,13 +307,17 @@ class MultiPartFormDataValidator(FormDataValidator):
VALIDATOR_MAP = { VALIDATOR_MAP = {
"parameter": ParameterValidator, "parameter": ParameterValidator,
"body": { "body": MediaTypeDict(
"application/json": JSONRequestBodyValidator, {
"*/*json": JSONRequestBodyValidator,
"application/x-www-form-urlencoded": FormDataValidator, "application/x-www-form-urlencoded": FormDataValidator,
"multipart/form-data": MultiPartFormDataValidator, "multipart/form-data": MultiPartFormDataValidator,
}, }
"response": { ),
"application/json": JSONResponseBodyValidator, "response": MediaTypeDict(
"text/plain": TextResponseBodyValidator, {
}, "*/*json": JSONResponseBodyValidator,
"text/plain": TextResponseBodyValidator,
}
),
} }

View File

@@ -309,3 +309,13 @@ def test_global_response_definitions(schema_app):
app_client = schema_app.app.test_client() app_client = schema_app.app.test_client()
resp = app_client.get("/v1.0/define_global_response") resp = app_client.get("/v1.0/define_global_response")
assert json.loads(resp.data.decode("utf-8", "replace")) == ["general", "list"] assert json.loads(resp.data.decode("utf-8", "replace")) == ["general", "list"]
def test_media_range(schema_app):
app_client = schema_app.app.test_client()
headers = {"Content-type": "application/json"}
array_request = app_client.post(
"/v1.0/media_range", headers=headers, data=json.dumps({})
)
assert array_request.status_code == 200, array_request.text

View File

@@ -389,6 +389,10 @@ def test_global_response_definition():
return ["general", "list"], 200 return ["general", "list"], 200
def test_media_range():
return "OK"
def test_nullable_parameters(time_start): def test_nullable_parameters(time_start):
if time_start is None: if time_start is None:
return "it was None" return "it was None"

View File

@@ -264,6 +264,18 @@ paths:
responses: responses:
'200': '200':
$ref: '#/components/responses/GeneralList' $ref: '#/components/responses/GeneralList'
/media_range:
post:
description: Test media range
operationId: fakeapi.hello.test_media_range
requestBody:
content:
'*/*':
schema:
type: object
responses:
'200':
description: OK
components: components:
responses: responses:
GeneralList: GeneralList:

View File

@@ -291,6 +291,21 @@ paths:
200: 200:
$ref: '#/responses/GeneralList' $ref: '#/responses/GeneralList'
/media_range:
post:
description: Test media range
operationId: fakeapi.hello.test_media_range
consumes:
- '*/*'
parameters:
- name: body
in: body
schema:
type: object
responses:
'200':
description: OK
definitions: definitions:
new_stack: new_stack:
type: object type: object