mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 12:27:45 +00:00
Fix stream replay in validators (#1678)
The current implementation of replaying the stream will always replay the first message. This PR fixes this by progressing through the messages with each call.
This commit is contained in:
@@ -99,10 +99,13 @@ class AbstractRequestBodyValidator:
|
|||||||
receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
|
receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
|
||||||
) -> Receive:
|
) -> Receive:
|
||||||
"""Insert messages at the start of the `receive` channel."""
|
"""Insert messages at the start of the `receive` channel."""
|
||||||
|
# Ensure that messages is an iterator so each message is replayed once.
|
||||||
|
message_iterator = iter(messages)
|
||||||
|
|
||||||
async def receive_() -> t.MutableMapping[str, t.Any]:
|
async def receive_() -> t.MutableMapping[str, t.Any]:
|
||||||
for message in messages:
|
try:
|
||||||
return message
|
return next(message_iterator)
|
||||||
|
except StopIteration:
|
||||||
return await receive()
|
return await receive()
|
||||||
|
|
||||||
return receive_
|
return receive_
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from urllib.parse import quote_plus
|
|||||||
import pytest
|
import pytest
|
||||||
from connexion.exceptions import BadRequestProblem
|
from connexion.exceptions import BadRequestProblem
|
||||||
from connexion.uri_parsing import Swagger2URIParser
|
from connexion.uri_parsing import Swagger2URIParser
|
||||||
from connexion.validators.parameter import ParameterValidator
|
from connexion.validators import AbstractRequestBodyValidator, ParameterValidator
|
||||||
from starlette.datastructures import QueryParams
|
from starlette.datastructures import QueryParams
|
||||||
|
|
||||||
|
|
||||||
@@ -140,3 +140,30 @@ def test_parameter_validator(monkeypatch):
|
|||||||
with pytest.raises(BadRequestProblem) as exc:
|
with pytest.raises(BadRequestProblem) as exc:
|
||||||
validator.validate_request(request)
|
validator.validate_request(request)
|
||||||
assert exc.value.detail.startswith("'x' is not one of ['a', 'b']")
|
assert exc.value.detail.startswith("'x' is not one of ['a', 'b']")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stream_replay():
|
||||||
|
messages = [
|
||||||
|
{"body": b"message 1", "more_body": True},
|
||||||
|
{"body": b"message 2", "more_body": False},
|
||||||
|
]
|
||||||
|
|
||||||
|
async def receive():
|
||||||
|
return b""
|
||||||
|
|
||||||
|
wrapped_receive = AbstractRequestBodyValidator._insert_messages(
|
||||||
|
receive, messages=messages
|
||||||
|
)
|
||||||
|
|
||||||
|
replay = []
|
||||||
|
more_body = True
|
||||||
|
while more_body:
|
||||||
|
message = await wrapped_receive()
|
||||||
|
replay.append(message)
|
||||||
|
more_body = message.get("more_body", False)
|
||||||
|
|
||||||
|
assert len(replay) <= len(
|
||||||
|
messages
|
||||||
|
), "Replayed more messages than received, break out of while loop"
|
||||||
|
|
||||||
|
assert messages == replay
|
||||||
|
|||||||
Reference in New Issue
Block a user