Fix stream replay in validators (#1678)

The current implementation of replaying the stream will always replay
the first message. This PR fixes this by progressing through the
messages with each call.
This commit is contained in:
Robbe Sneyders
2023-03-30 22:12:56 +02:00
committed by GitHub
parent 8cebebc3a4
commit 55e376f816
2 changed files with 34 additions and 4 deletions

View File

@@ -99,10 +99,13 @@ class AbstractRequestBodyValidator:
receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]] receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
) -> Receive: ) -> Receive:
"""Insert messages at the start of the `receive` channel.""" """Insert messages at the start of the `receive` channel."""
# Ensure that messages is an iterator so each message is replayed once.
message_iterator = iter(messages)
async def receive_() -> t.MutableMapping[str, t.Any]: async def receive_() -> t.MutableMapping[str, t.Any]:
for message in messages: try:
return message return next(message_iterator)
except StopIteration:
return await receive() return await receive()
return receive_ return receive_

View File

@@ -4,7 +4,7 @@ from urllib.parse import quote_plus
import pytest import pytest
from connexion.exceptions import BadRequestProblem from connexion.exceptions import BadRequestProblem
from connexion.uri_parsing import Swagger2URIParser from connexion.uri_parsing import Swagger2URIParser
from connexion.validators.parameter import ParameterValidator from connexion.validators import AbstractRequestBodyValidator, ParameterValidator
from starlette.datastructures import QueryParams from starlette.datastructures import QueryParams
@@ -140,3 +140,30 @@ def test_parameter_validator(monkeypatch):
with pytest.raises(BadRequestProblem) as exc: with pytest.raises(BadRequestProblem) as exc:
validator.validate_request(request) validator.validate_request(request)
assert exc.value.detail.startswith("'x' is not one of ['a', 'b']") assert exc.value.detail.startswith("'x' is not one of ['a', 'b']")
async def test_stream_replay():
messages = [
{"body": b"message 1", "more_body": True},
{"body": b"message 2", "more_body": False},
]
async def receive():
return b""
wrapped_receive = AbstractRequestBodyValidator._insert_messages(
receive, messages=messages
)
replay = []
more_body = True
while more_body:
message = await wrapped_receive()
replay.append(message)
more_body = message.get("more_body", False)
assert len(replay) <= len(
messages
), "Replayed more messages than received, break out of while loop"
assert messages == replay