Skip to content

Commit eb9088a

Browse files
Adding a test for big packets
1 parent 4adb4a0 commit eb9088a

File tree

1 file changed

+104
-4
lines changed

1 file changed

+104
-4
lines changed

tests/v2/test_v2_session_lifecycle.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
11
import asyncio
2+
import logging
23
from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict
34

5+
import msgpack
6+
import nanoid
47
import pytest
5-
from websockets import ConnectionClosedOK
8+
from websockets import ConnectionClosed, ConnectionClosedOK
69
from websockets.asyncio.server import ServerConnection, serve
710
from websockets.typing import Data
811

912
from replit_river.common_session import SessionState
1013
from replit_river.messages import parse_transport_msg
1114
from replit_river.rate_limiter import RateLimiter
12-
from replit_river.rpc import TransportMessage
15+
from replit_river.rpc import (
16+
ControlMessageHandshakeRequest,
17+
ControlMessageHandshakeResponse,
18+
HandShakeStatus,
19+
TransportMessage,
20+
)
1321
from replit_river.transport_options import TransportOptions, UriAndMetadata
14-
from replit_river.v2.session import Session
22+
from replit_river.v2.client import Client
23+
from replit_river.v2.session import STREAM_CLOSED_BIT, Session
1524

1625

1726
class _PermissiveRateLimiter(RateLimiter):
@@ -54,6 +63,8 @@ async def handle(websocket: ServerConnection) -> None:
5463
await recv.put(datagram)
5564
except ConnectionClosedOK:
5665
pass
66+
except ConnectionClosed:
67+
pass
5768

5869
port: int | None = None
5970
if state["ipv4_laddr"]:
@@ -65,7 +76,10 @@ async def handle(websocket: ServerConnection) -> None:
6576
state["ipv4_laddr"] = pair
6677
serve_forever = asyncio.create_task(server.serve_forever())
6778
yield None
68-
serve_forever.cancel()
79+
server.close()
80+
await server.wait_closed()
81+
# "serve_forever" should always be done after wait_closed finishes
82+
assert serve_forever.done()
6983

7084

7185
@pytest.fixture
@@ -145,3 +159,89 @@ def close_session_callback(_session: Session) -> None:
145159
await connecting
146160
assert session._state == SessionState.CLOSED
147161
assert callcount == 1
162+
163+
164+
async def test_big_packet(ws_server: WsServerFixture) -> None:
165+
(urimeta, recv, conn) = ws_server
166+
167+
client = Client(
168+
client_id="CLIENT1",
169+
server_id="SERVER",
170+
transport_options=TransportOptions(),
171+
uri_and_metadata_factory=urimeta,
172+
)
173+
174+
connecting = asyncio.create_task(client.ensure_connected())
175+
request_msg = parse_transport_msg(await recv.get())
176+
177+
assert not isinstance(request_msg, str)
178+
assert (serverconn := conn())
179+
handshake_request: ControlMessageHandshakeRequest[None] = (
180+
ControlMessageHandshakeRequest(**request_msg.payload)
181+
)
182+
183+
handshake_resp = ControlMessageHandshakeResponse(
184+
status=HandShakeStatus(
185+
ok=True,
186+
),
187+
)
188+
handshake_request.sessionId
189+
190+
msg = TransportMessage(
191+
from_=request_msg.from_,
192+
to=request_msg.to,
193+
streamId=request_msg.streamId,
194+
controlFlags=0,
195+
id=nanoid.generate(),
196+
seq=0,
197+
ack=0,
198+
payload=handshake_resp.model_dump(),
199+
)
200+
packed = msgpack.packb(
201+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
202+
)
203+
await serverconn.send(packed)
204+
205+
async def handle_server_messages() -> None:
206+
request_msg = parse_transport_msg(await recv.get())
207+
assert not isinstance(request_msg, str)
208+
msg = TransportMessage(
209+
from_=request_msg.to,
210+
to=request_msg.from_,
211+
streamId=request_msg.streamId,
212+
controlFlags=STREAM_CLOSED_BIT,
213+
id=nanoid.generate(),
214+
seq=0,
215+
ack=0,
216+
payload={
217+
"ok": True,
218+
"payload": {
219+
"big": "a" * (2**20 + 1), # One more than the default max_size
220+
},
221+
},
222+
)
223+
224+
packed = msgpack.packb(
225+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
226+
)
227+
await serverconn.send(packed)
228+
229+
stream_close_msg = msgpack.unpackb(await recv.get())
230+
assert stream_close_msg["controlFlags"] == STREAM_CLOSED_BIT
231+
232+
stream_handler = asyncio.create_task(handle_server_messages())
233+
234+
try:
235+
async for datagram in client.send_subscription(
236+
"test", "bigstream", {}, lambda x: x, lambda x: x, lambda x: x
237+
):
238+
print(datagram)
239+
except Exception:
240+
logging.exception("Interrupted")
241+
242+
await client.close()
243+
await connecting
244+
245+
# Ensure we're listening to close messages as well
246+
stream_handler.cancel()
247+
await stream_handler

0 commit comments

Comments
 (0)