Skip to content

Commit d5a42a9

Browse files
Adding a test for cancelling upload
1 parent 46f65ee commit d5a42a9

File tree

1 file changed

+121
-5
lines changed

1 file changed

+121
-5
lines changed

tests/v2/test_v2_session_lifecycle.py

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import asyncio
22
import logging
3-
from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict
3+
from typing import (
4+
Any,
5+
AsyncIterator,
6+
Awaitable,
7+
Callable,
8+
Literal,
9+
TypeAlias,
10+
TypedDict,
11+
)
412

513
import msgpack
614
import nanoid
@@ -20,7 +28,12 @@
2028
)
2129
from replit_river.transport_options import TransportOptions, UriAndMetadata
2230
from replit_river.v2.client import Client
23-
from replit_river.v2.session import STREAM_CLOSED_BIT, Session
31+
from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT, Session
32+
33+
34+
class OuterPayload[A](TypedDict):
35+
ok: Literal[True]
36+
payload: A
2437

2538

2639
class _PermissiveRateLimiter(RateLimiter):
@@ -231,7 +244,7 @@ async def handle_server_messages() -> None:
231244
stream_close_msg = msgpack.unpackb(await recv.get())
232245
assert stream_close_msg["controlFlags"] == STREAM_CLOSED_BIT
233246

234-
stream_handler = asyncio.create_task(handle_server_messages())
247+
server_handler = asyncio.create_task(handle_server_messages())
235248

236249
try:
237250
async for datagram in client.send_subscription(
@@ -245,5 +258,108 @@ async def handle_server_messages() -> None:
245258
await connecting
246259

247260
# Ensure we're listening to close messages as well
248-
stream_handler.cancel()
249-
await stream_handler
261+
server_handler.cancel()
262+
await server_handler
263+
264+
265+
async def test_upload_cancel(ws_server: WsServerFixture) -> None:
266+
(urimeta, recv, conn) = ws_server
267+
268+
client = Client(
269+
client_id="CLIENT1",
270+
server_id="SERVER",
271+
transport_options=TransportOptions(),
272+
uri_and_metadata_factory=urimeta,
273+
)
274+
275+
connecting = asyncio.create_task(client.ensure_connected())
276+
request_msg = parse_transport_msg(await recv.get())
277+
278+
assert not isinstance(request_msg, str)
279+
assert (serverconn := conn())
280+
handshake_request: ControlMessageHandshakeRequest[None] = (
281+
ControlMessageHandshakeRequest(**request_msg.payload)
282+
)
283+
284+
handshake_resp = ControlMessageHandshakeResponse(
285+
status=HandShakeStatus(
286+
ok=True,
287+
),
288+
)
289+
handshake_request.sessionId
290+
291+
msg = TransportMessage(
292+
from_=request_msg.from_,
293+
to=request_msg.to,
294+
streamId=request_msg.streamId,
295+
controlFlags=0,
296+
id=nanoid.generate(),
297+
seq=0,
298+
ack=0,
299+
payload=handshake_resp.model_dump(),
300+
)
301+
packed = msgpack.packb(
302+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
303+
)
304+
await serverconn.send(packed)
305+
306+
async def handle_server_messages() -> None:
307+
request_msg = parse_transport_msg(await recv.get())
308+
assert not isinstance(request_msg, str)
309+
310+
logging.debug("request_msg: %r", repr(request_msg))
311+
312+
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
313+
while msg.payload.get("payload", {}).get("hello") == "world":
314+
logging.debug("Found a hello:world %r", repr(msg))
315+
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
316+
317+
assert msg.controlFlags == STREAM_CANCEL_BIT
318+
319+
server_handler = asyncio.create_task(handle_server_messages())
320+
321+
sent_waiter = asyncio.Event()
322+
323+
async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]:
324+
count = 0
325+
while True:
326+
await asyncio.sleep(0.1)
327+
yield {
328+
"ok": True,
329+
"payload": {
330+
"hello": "world",
331+
},
332+
}
333+
count += 1
334+
if count > 5:
335+
# We've sent enough messages, interrupt the stream.
336+
sent_waiter.set()
337+
338+
upload_task = asyncio.create_task(
339+
client.send_upload(
340+
"test",
341+
"bigstream",
342+
{},
343+
upload_chunks(),
344+
lambda x: x,
345+
lambda x: x,
346+
lambda x: x,
347+
lambda x: x,
348+
)
349+
)
350+
351+
# Wait until we've seen at least a few messages from the upload Task
352+
await sent_waiter.wait()
353+
354+
upload_task.cancel()
355+
try:
356+
await upload_task
357+
except asyncio.CancelledError:
358+
pass
359+
360+
await client.close()
361+
await connecting
362+
363+
# Ensure we're listening to close messages as well
364+
server_handler.cancel()
365+
await server_handler

0 commit comments

Comments
 (0)