Skip to content

Commit 3ce93d6

Browse files
Adding stream cancel test
1 parent 31084f4 commit 3ce93d6

File tree

1 file changed

+147
-0
lines changed

1 file changed

+147
-0
lines changed

tests/v2/test_v2_cancellation.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
from typing import (
55
Any,
66
AsyncIterator,
7+
Literal,
78
)
89

910
import msgpack
1011
import nanoid
12+
import pytest
1113

1214
from replit_river.messages import parse_transport_msg
1315
from replit_river.rpc import (
@@ -115,6 +117,151 @@ async def handle_server_messages() -> None:
115117
await server_handler
116118

117119

120+
@pytest.mark.parametrize("direction", ["send", "receive"])
121+
async def test_stream_cancel(
122+
ws_server: WsServerFixture, direction: Literal["send", "receive"]
123+
) -> None:
124+
(urimeta, recv, conn) = ws_server
125+
126+
client = Client(
127+
client_id="CLIENT1",
128+
server_id="SERVER",
129+
transport_options=TransportOptions(),
130+
uri_and_metadata_factory=urimeta,
131+
)
132+
133+
connecting = asyncio.create_task(client.ensure_connected())
134+
request_msg = parse_transport_msg(await recv.get())
135+
136+
assert not isinstance(request_msg, str)
137+
assert (serverconn := conn())
138+
handshake_request: ControlMessageHandshakeRequest[None] = (
139+
ControlMessageHandshakeRequest(**request_msg.payload)
140+
)
141+
142+
handshake_resp = ControlMessageHandshakeResponse(
143+
status=HandShakeStatus(
144+
ok=True,
145+
),
146+
)
147+
handshake_request.sessionId
148+
149+
msg = TransportMessage(
150+
from_=request_msg.from_,
151+
to=request_msg.to,
152+
streamId=request_msg.streamId,
153+
controlFlags=0,
154+
id=nanoid.generate(),
155+
seq=0,
156+
ack=0,
157+
payload=handshake_resp.model_dump(),
158+
)
159+
packed = msgpack.packb(
160+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
161+
)
162+
await serverconn.send(packed)
163+
164+
bidi_waiter = asyncio.Event()
165+
166+
async def send_server_messages(request_msg: TransportMessage) -> None:
167+
seq = 0
168+
169+
while True:
170+
msg = TransportMessage(
171+
from_=request_msg.to,
172+
to=request_msg.from_,
173+
streamId=request_msg.streamId,
174+
controlFlags=0,
175+
id=nanoid.generate(),
176+
seq=seq,
177+
ack=0,
178+
payload={
179+
"ok": True,
180+
"payload": {
181+
"hello": "world",
182+
},
183+
},
184+
)
185+
seq += 1
186+
packed = msgpack.packb(
187+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
188+
)
189+
await serverconn.send(packed)
190+
await asyncio.sleep(0.1)
191+
192+
if seq > 5 and direction == "send":
193+
bidi_waiter.set()
194+
195+
async def handle_server_messages(request_msg: TransportMessage) -> None:
196+
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
197+
while msg.payload.get("payload", {}).get("hello") == "world":
198+
logging.debug("Found a hello:world %r", repr(msg))
199+
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
200+
201+
assert msg.controlFlags == STREAM_CANCEL_BIT
202+
203+
async def receive_chunks() -> None:
204+
async def _upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]:
205+
count = 0
206+
while True:
207+
await asyncio.sleep(0.1)
208+
yield {
209+
"ok": True,
210+
"payload": {
211+
"hello": "world",
212+
},
213+
}
214+
count += 1
215+
if count > 5 and direction == "receive":
216+
# We've sent enough messages, interrupt the stream.
217+
bidi_waiter.set()
218+
219+
async for chunk in client.send_stream(
220+
"test",
221+
"bigstream",
222+
{},
223+
_upload_chunks(),
224+
lambda x: x,
225+
lambda x: x,
226+
lambda x: x,
227+
lambda x: x,
228+
):
229+
print(repr(chunk))
230+
231+
receive_task = asyncio.create_task(receive_chunks())
232+
request_msg = parse_transport_msg(await recv.get())
233+
logging.debug("request_msg: %r", repr(request_msg))
234+
assert not isinstance(request_msg, str)
235+
236+
server_sender = asyncio.create_task(send_server_messages(request_msg))
237+
server_receiver = asyncio.create_task(handle_server_messages(request_msg))
238+
239+
# Wait until we've seen at least a few messages from the requisite Task
240+
await bidi_waiter.wait()
241+
242+
receive_task.cancel()
243+
try:
244+
await receive_task
245+
except asyncio.CancelledError:
246+
pass
247+
248+
await client.close()
249+
await connecting
250+
251+
# Ensure we're listening to close messages as well
252+
assert server_sender
253+
server_sender.cancel()
254+
try:
255+
await server_sender
256+
except asyncio.CancelledError:
257+
pass
258+
server_receiver.cancel()
259+
try:
260+
await server_receiver
261+
except Exception:
262+
pass
263+
264+
118265
async def test_subscription_cancel(ws_server: WsServerFixture) -> None:
119266
(urimeta, recv, conn) = ws_server
120267

0 commit comments

Comments
 (0)