11import asyncio
2+ import logging
23from typing import AsyncIterator , Awaitable , Callable , TypeAlias , TypedDict
34
5+ import msgpack
6+ import nanoid
47import pytest
5- from websockets import ConnectionClosedOK
8+ from websockets import ConnectionClosed , ConnectionClosedOK
69from websockets .asyncio .server import ServerConnection , serve
710from websockets .typing import Data
811
912from replit_river .common_session import SessionState
1013from replit_river .messages import parse_transport_msg
1114from replit_river .rate_limiter import RateLimiter
12- from replit_river .rpc import TransportMessage
15+ from replit_river .rpc import (
16+ ControlMessageHandshakeRequest ,
17+ ControlMessageHandshakeResponse ,
18+ HandShakeStatus ,
19+ TransportMessage ,
20+ )
1321from replit_river .transport_options import TransportOptions , UriAndMetadata
14- from replit_river .v2 .session import Session
22+ from replit_river .v2 .client import Client
23+ from replit_river .v2 .session import STREAM_CLOSED_BIT , Session
1524
1625
1726class _PermissiveRateLimiter (RateLimiter ):
@@ -54,6 +63,8 @@ async def handle(websocket: ServerConnection) -> None:
5463 await recv .put (datagram )
5564 except ConnectionClosedOK :
5665 pass
66+ except ConnectionClosed :
67+ pass
5768
5869 port : int | None = None
5970 if state ["ipv4_laddr" ]:
@@ -65,7 +76,10 @@ async def handle(websocket: ServerConnection) -> None:
6576 state ["ipv4_laddr" ] = pair
6677 serve_forever = asyncio .create_task (server .serve_forever ())
6778 yield None
68- serve_forever .cancel ()
79+ server .close ()
80+ await server .wait_closed ()
81+ # "serve_forever" should always be done after wait_closed finishes
82+ assert serve_forever .done ()
6983
7084
7185@pytest .fixture
@@ -145,3 +159,89 @@ def close_session_callback(_session: Session) -> None:
145159 await connecting
146160 assert session ._state == SessionState .CLOSED
147161 assert callcount == 1
162+
163+
164+ async def test_big_packet (ws_server : WsServerFixture ) -> None :
165+ (urimeta , recv , conn ) = ws_server
166+
167+ client = Client (
168+ client_id = "CLIENT1" ,
169+ server_id = "SERVER" ,
170+ transport_options = TransportOptions (),
171+ uri_and_metadata_factory = urimeta ,
172+ )
173+
174+ connecting = asyncio .create_task (client .ensure_connected ())
175+ request_msg = parse_transport_msg (await recv .get ())
176+
177+ assert not isinstance (request_msg , str )
178+ assert (serverconn := conn ())
179+ handshake_request : ControlMessageHandshakeRequest [None ] = (
180+ ControlMessageHandshakeRequest (** request_msg .payload )
181+ )
182+
183+ handshake_resp = ControlMessageHandshakeResponse (
184+ status = HandShakeStatus (
185+ ok = True ,
186+ ),
187+ )
188+ handshake_request .sessionId
189+
190+ msg = TransportMessage (
191+ from_ = request_msg .from_ ,
192+ to = request_msg .to ,
193+ streamId = request_msg .streamId ,
194+ controlFlags = 0 ,
195+ id = nanoid .generate (),
196+ seq = 0 ,
197+ ack = 0 ,
198+ payload = handshake_resp .model_dump (),
199+ )
200+ packed = msgpack .packb (
201+ msg .model_dump (by_alias = True , exclude_none = True ), datetime = True
202+ )
203+ await serverconn .send (packed )
204+
205+ async def handle_server_messages () -> None :
206+ request_msg = parse_transport_msg (await recv .get ())
207+ assert not isinstance (request_msg , str )
208+ msg = TransportMessage (
209+ from_ = request_msg .to ,
210+ to = request_msg .from_ ,
211+ streamId = request_msg .streamId ,
212+ controlFlags = STREAM_CLOSED_BIT ,
213+ id = nanoid .generate (),
214+ seq = 0 ,
215+ ack = 0 ,
216+ payload = {
217+ "ok" : True ,
218+ "payload" : {
219+ "big" : "a" * (2 ** 20 + 1 ), # One more than the default max_size
220+ },
221+ },
222+ )
223+
224+ packed = msgpack .packb (
225+ msg .model_dump (by_alias = True , exclude_none = True ), datetime = True
226+ )
227+ await serverconn .send (packed )
228+
229+ stream_close_msg = msgpack .unpackb (await recv .get ())
230+ assert stream_close_msg ["controlFlags" ] == STREAM_CLOSED_BIT
231+
232+ stream_handler = asyncio .create_task (handle_server_messages ())
233+
234+ try :
235+ async for datagram in client .send_subscription (
236+ "test" , "bigstream" , {}, lambda x : x , lambda x : x , lambda x : x
237+ ):
238+ print (datagram )
239+ except Exception :
240+ logging .exception ("Interrupted" )
241+
242+ await client .close ()
243+ await connecting
244+
245+ # Ensure we're listening to close messages as well
246+ stream_handler .cancel ()
247+ await stream_handler
0 commit comments