mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 04:19:26 +00:00
Fix stream replay in validators (#1678)
The current implementation of replaying the stream will always replay the first message. This PR fixes this by progressing through the messages with each call.
This commit is contained in:
@@ -99,10 +99,13 @@ class AbstractRequestBodyValidator:
|
||||
receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
|
||||
) -> Receive:
|
||||
"""Insert messages at the start of the `receive` channel."""
|
||||
# Ensure that messages is an iterator so each message is replayed once.
|
||||
message_iterator = iter(messages)
|
||||
|
||||
async def receive_() -> t.MutableMapping[str, t.Any]:
|
||||
for message in messages:
|
||||
return message
|
||||
try:
|
||||
return next(message_iterator)
|
||||
except StopIteration:
|
||||
return await receive()
|
||||
|
||||
return receive_
|
||||
|
||||
@@ -4,7 +4,7 @@ from urllib.parse import quote_plus
|
||||
import pytest
|
||||
from connexion.exceptions import BadRequestProblem
|
||||
from connexion.uri_parsing import Swagger2URIParser
|
||||
from connexion.validators.parameter import ParameterValidator
|
||||
from connexion.validators import AbstractRequestBodyValidator, ParameterValidator
|
||||
from starlette.datastructures import QueryParams
|
||||
|
||||
|
||||
@@ -140,3 +140,30 @@ def test_parameter_validator(monkeypatch):
|
||||
with pytest.raises(BadRequestProblem) as exc:
|
||||
validator.validate_request(request)
|
||||
assert exc.value.detail.startswith("'x' is not one of ['a', 'b']")
|
||||
|
||||
|
||||
async def test_stream_replay():
|
||||
messages = [
|
||||
{"body": b"message 1", "more_body": True},
|
||||
{"body": b"message 2", "more_body": False},
|
||||
]
|
||||
|
||||
async def receive():
|
||||
return b""
|
||||
|
||||
wrapped_receive = AbstractRequestBodyValidator._insert_messages(
|
||||
receive, messages=messages
|
||||
)
|
||||
|
||||
replay = []
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await wrapped_receive()
|
||||
replay.append(message)
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
assert len(replay) <= len(
|
||||
messages
|
||||
), "Replayed more messages than received, break out of while loop"
|
||||
|
||||
assert messages == replay
|
||||
|
||||
Reference in New Issue
Block a user