Skip to content

Commit fb8d312

Browse files
[v2 bug] Disable websocket max_size (#156)
1 parent eb70fc0 commit fb8d312

File tree

3 files changed

+109
-6
lines changed

3 files changed

+109
-6
lines changed

src/replit_river/client_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ async def _establish_new_connection(
170170

171171
try:
172172
uri_and_metadata = await self._uri_and_metadata_factory()
173-
ws = await websockets.connect(uri_and_metadata["uri"])
173+
ws = await websockets.connect(uri_and_metadata["uri"], max_size=None)
174174
session_id = (
175175
self.generate_nanoid()
176176
if not old_session

src/replit_river/v2/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1084,7 +1084,10 @@ async def _do_ensure_connected[HandshakeMetadata](
10841084
ws: ClientConnection | None = None
10851085
try:
10861086
uri_and_metadata = await uri_and_metadata_factory()
1087-
ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"])
1087+
ws = await websockets.asyncio.client.connect(
1088+
uri_and_metadata["uri"],
1089+
max_size=None,
1090+
)
10881091
transition_connecting(ws)
10891092

10901093
try:

tests/v2/test_v2_session_lifecycle.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
11
import asyncio
2+
import logging
23
from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict
34

5+
import msgpack
6+
import nanoid
47
import pytest
5-
from websockets import ConnectionClosedOK
8+
from websockets import ConnectionClosed, ConnectionClosedOK
69
from websockets.asyncio.server import ServerConnection, serve
710
from websockets.typing import Data
811

912
from replit_river.common_session import SessionState
1013
from replit_river.messages import parse_transport_msg
1114
from replit_river.rate_limiter import RateLimiter
12-
from replit_river.rpc import TransportMessage
15+
from replit_river.rpc import (
16+
ControlMessageHandshakeRequest,
17+
ControlMessageHandshakeResponse,
18+
HandShakeStatus,
19+
TransportMessage,
20+
)
1321
from replit_river.transport_options import TransportOptions, UriAndMetadata
14-
from replit_river.v2.session import Session
22+
from replit_river.v2.client import Client
23+
from replit_river.v2.session import STREAM_CLOSED_BIT, Session
1524

1625

1726
class _PermissiveRateLimiter(RateLimiter):
@@ -54,6 +63,8 @@ async def handle(websocket: ServerConnection) -> None:
5463
await recv.put(datagram)
5564
except ConnectionClosedOK:
5665
pass
66+
except ConnectionClosed:
67+
pass
5768

5869
port: int | None = None
5970
if state["ipv4_laddr"]:
@@ -65,7 +76,10 @@ async def handle(websocket: ServerConnection) -> None:
6576
state["ipv4_laddr"] = pair
6677
serve_forever = asyncio.create_task(server.serve_forever())
6778
yield None
68-
serve_forever.cancel()
79+
server.close()
80+
await server.wait_closed()
81+
# "serve_forever" should always be done after wait_closed finishes
82+
assert serve_forever.done()
6983

7084

7185
@pytest.fixture
@@ -145,3 +159,89 @@ def close_session_callback(_session: Session) -> None:
145159
await connecting
146160
assert session._state == SessionState.CLOSED
147161
assert callcount == 1
162+
163+
164+
async def test_big_packet(ws_server: WsServerFixture) -> None:
165+
(urimeta, recv, conn) = ws_server
166+
167+
client = Client(
168+
client_id="CLIENT1",
169+
server_id="SERVER",
170+
transport_options=TransportOptions(),
171+
uri_and_metadata_factory=urimeta,
172+
)
173+
174+
connecting = asyncio.create_task(client.ensure_connected())
175+
request_msg = parse_transport_msg(await recv.get())
176+
177+
assert not isinstance(request_msg, str)
178+
assert (serverconn := conn())
179+
handshake_request: ControlMessageHandshakeRequest[None] = (
180+
ControlMessageHandshakeRequest(**request_msg.payload)
181+
)
182+
183+
handshake_resp = ControlMessageHandshakeResponse(
184+
status=HandShakeStatus(
185+
ok=True,
186+
),
187+
)
188+
handshake_request.sessionId
189+
190+
msg = TransportMessage(
191+
from_=request_msg.from_,
192+
to=request_msg.to,
193+
streamId=request_msg.streamId,
194+
controlFlags=0,
195+
id=nanoid.generate(),
196+
seq=0,
197+
ack=0,
198+
payload=handshake_resp.model_dump(),
199+
)
200+
packed = msgpack.packb(
201+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
202+
)
203+
await serverconn.send(packed)
204+
205+
async def handle_server_messages() -> None:
206+
request_msg = parse_transport_msg(await recv.get())
207+
assert not isinstance(request_msg, str)
208+
msg = TransportMessage(
209+
from_=request_msg.to,
210+
to=request_msg.from_,
211+
streamId=request_msg.streamId,
212+
controlFlags=STREAM_CLOSED_BIT,
213+
id=nanoid.generate(),
214+
seq=0,
215+
ack=0,
216+
payload={
217+
"ok": True,
218+
"payload": {
219+
"big": "a" * (2**20 + 1), # One more than the default max_size
220+
},
221+
},
222+
)
223+
224+
packed = msgpack.packb(
225+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
226+
)
227+
await serverconn.send(packed)
228+
229+
stream_close_msg = msgpack.unpackb(await recv.get())
230+
assert stream_close_msg["controlFlags"] == STREAM_CLOSED_BIT
231+
232+
stream_handler = asyncio.create_task(handle_server_messages())
233+
234+
try:
235+
async for datagram in client.send_subscription(
236+
"test", "bigstream", {}, lambda x: x, lambda x: x, lambda x: x
237+
):
238+
print(datagram)
239+
except Exception:
240+
logging.exception("Interrupted")
241+
242+
await client.close()
243+
await connecting
244+
245+
# Ensure we're listening to close messages as well
246+
stream_handler.cancel()
247+
await stream_handler

0 commit comments

Comments
 (0)