11import asyncio
22import logging
3- from typing import AsyncIterator , Awaitable , Callable , TypeAlias , TypedDict
3+ from typing import (
4+ Any ,
5+ AsyncIterator ,
6+ Awaitable ,
7+ Callable ,
8+ Literal ,
9+ TypeAlias ,
10+ TypedDict ,
11+ )
412
513import msgpack
614import nanoid
2028)
2129from replit_river .transport_options import TransportOptions , UriAndMetadata
2230from replit_river .v2 .client import Client
23- from replit_river .v2 .session import STREAM_CLOSED_BIT , Session
31+ from replit_river .v2 .session import STREAM_CANCEL_BIT , STREAM_CLOSED_BIT , Session
32+
33+
34+ class OuterPayload [A ](TypedDict ):
35+ ok : Literal [True ]
36+ payload : A
2437
2538
2639class _PermissiveRateLimiter (RateLimiter ):
@@ -231,7 +244,7 @@ async def handle_server_messages() -> None:
231244 stream_close_msg = msgpack .unpackb (await recv .get ())
232245 assert stream_close_msg ["controlFlags" ] == STREAM_CLOSED_BIT
233246
234- stream_handler = asyncio .create_task (handle_server_messages ())
247+ server_handler = asyncio .create_task (handle_server_messages ())
235248
236249 try :
237250 async for datagram in client .send_subscription (
@@ -245,5 +258,108 @@ async def handle_server_messages() -> None:
245258 await connecting
246259
247260 # Ensure we're listening to close messages as well
248- stream_handler .cancel ()
249- await stream_handler
261+ server_handler .cancel ()
262+ await server_handler
263+
264+
265+ async def test_upload_cancel (ws_server : WsServerFixture ) -> None :
266+ (urimeta , recv , conn ) = ws_server
267+
268+ client = Client (
269+ client_id = "CLIENT1" ,
270+ server_id = "SERVER" ,
271+ transport_options = TransportOptions (),
272+ uri_and_metadata_factory = urimeta ,
273+ )
274+
275+ connecting = asyncio .create_task (client .ensure_connected ())
276+ request_msg = parse_transport_msg (await recv .get ())
277+
278+ assert not isinstance (request_msg , str )
279+ assert (serverconn := conn ())
280+ handshake_request : ControlMessageHandshakeRequest [None ] = (
281+ ControlMessageHandshakeRequest (** request_msg .payload )
282+ )
283+
284+ handshake_resp = ControlMessageHandshakeResponse (
285+ status = HandShakeStatus (
286+ ok = True ,
287+ ),
288+ )
289+ handshake_request .sessionId
290+
291+ msg = TransportMessage (
292+ from_ = request_msg .from_ ,
293+ to = request_msg .to ,
294+ streamId = request_msg .streamId ,
295+ controlFlags = 0 ,
296+ id = nanoid .generate (),
297+ seq = 0 ,
298+ ack = 0 ,
299+ payload = handshake_resp .model_dump (),
300+ )
301+ packed = msgpack .packb (
302+ msg .model_dump (by_alias = True , exclude_none = True ), datetime = True
303+ )
304+ await serverconn .send (packed )
305+
306+ async def handle_server_messages () -> None :
307+ request_msg = parse_transport_msg (await recv .get ())
308+ assert not isinstance (request_msg , str )
309+
310+ logging .debug ("request_msg: %r" , repr (request_msg ))
311+
312+ msg = TransportMessage (** msgpack .unpackb (await recv .get ()))
313+ while msg .payload .get ("payload" , {}).get ("hello" ) == "world" :
314+ logging .debug ("Found a hello:world %r" , repr (msg ))
315+ msg = TransportMessage (** msgpack .unpackb (await recv .get ()))
316+
317+ assert msg .controlFlags == STREAM_CANCEL_BIT
318+
319+ server_handler = asyncio .create_task (handle_server_messages ())
320+
321+ sent_waiter = asyncio .Event ()
322+
323+ async def upload_chunks () -> AsyncIterator [OuterPayload [dict [Any , Any ]]]:
324+ count = 0
325+ while True :
326+ await asyncio .sleep (0.1 )
327+ yield {
328+ "ok" : True ,
329+ "payload" : {
330+ "hello" : "world" ,
331+ },
332+ }
333+ count += 1
334+ if count > 5 :
335+ # We've sent enough messages, interrupt the stream.
336+ sent_waiter .set ()
337+
338+ upload_task = asyncio .create_task (
339+ client .send_upload (
340+ "test" ,
341+ "bigstream" ,
342+ {},
343+ upload_chunks (),
344+ lambda x : x ,
345+ lambda x : x ,
346+ lambda x : x ,
347+ lambda x : x ,
348+ )
349+ )
350+
351+ # Wait until we've seen at least a few messages from the upload Task
352+ await sent_waiter .wait ()
353+
354+ upload_task .cancel ()
355+ try :
356+ await upload_task
357+ except asyncio .CancelledError :
358+ pass
359+
360+ await client .close ()
361+ await connecting
362+
363+ # Ensure we're listening to close messages as well
364+ server_handler .cancel ()
365+ await server_handler
0 commit comments