Skip to content

Commit 313467e

Browse files
Adding a raw connection test
1 parent 2617a0b commit 313467e

File tree

1 file changed

+135
-0
lines changed

1 file changed

+135
-0
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import asyncio
2+
from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict
3+
4+
import pytest
5+
from websockets import ConnectionClosedOK
6+
from websockets.asyncio.server import ServerConnection, serve
7+
from websockets.typing import Data
8+
9+
from replit_river.messages import parse_transport_msg
10+
from replit_river.rate_limiter import RateLimiter
11+
from replit_river.rpc import TransportMessage
12+
from replit_river.transport_options import TransportOptions, UriAndMetadata
13+
from replit_river.v2.session import Session
14+
15+
16+
class _PermissiveRateLimiter(RateLimiter):
17+
def start_restoring_budget(self, user: str) -> None:
18+
pass
19+
20+
def get_backoff_ms(self, user: str) -> float:
21+
return 0
22+
23+
def has_budget(self, user: str) -> bool:
24+
return True
25+
26+
def consume_budget(self, user: str) -> None:
27+
pass
28+
29+
30+
WsServerFixture: TypeAlias = tuple[
31+
Callable[[], Awaitable[UriAndMetadata[None]]],
32+
asyncio.Queue[bytes],
33+
Callable[[], ServerConnection | None],
34+
]
35+
36+
37+
class _WsServerState(TypedDict):
38+
ipv4_laddr: tuple[str, int] | None
39+
40+
41+
async def _ws_server_internal(
42+
recv: asyncio.Queue[bytes],
43+
set_conn: Callable[[ServerConnection], None],
44+
state: _WsServerState,
45+
) -> AsyncIterator[None]:
46+
async def handle(websocket: ServerConnection) -> None:
47+
set_conn(websocket)
48+
datagram: Data
49+
try:
50+
while datagram := await websocket.recv(decode=False):
51+
if isinstance(datagram, str):
52+
continue
53+
await recv.put(datagram)
54+
except ConnectionClosedOK:
55+
pass
56+
57+
port: int | None = None
58+
if state["ipv4_laddr"]:
59+
port = state["ipv4_laddr"][1]
60+
async with serve(handle, "localhost", port=port) as server:
61+
for sock in server.sockets:
62+
if (pair := sock.getsockname())[0] == "127.0.0.1":
63+
if state["ipv4_laddr"] is None:
64+
state["ipv4_laddr"] = pair
65+
serve_forever = asyncio.create_task(server.serve_forever())
66+
yield None
67+
serve_forever.cancel()
68+
69+
70+
@pytest.fixture
71+
async def ws_server() -> AsyncIterator[WsServerFixture]:
72+
recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1)
73+
connection: ServerConnection | None = None
74+
state: _WsServerState = {"ipv4_laddr": None}
75+
76+
def set_conn(new_conn: ServerConnection) -> None:
77+
nonlocal connection
78+
connection = new_conn
79+
80+
server_generator = _ws_server_internal(recv, set_conn, state)
81+
await anext(server_generator)
82+
83+
async def urimeta() -> UriAndMetadata[None]:
84+
ipv4_laddr = state["ipv4_laddr"]
85+
assert ipv4_laddr
86+
return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None)
87+
88+
yield (urimeta, recv, lambda: connection)
89+
90+
try:
91+
await anext(server_generator)
92+
except StopAsyncIteration:
93+
pass
94+
95+
96+
async def test_connect(ws_server: WsServerFixture) -> None:
97+
(urimeta, recv, conn) = ws_server
98+
99+
session = Session(
100+
server_id="SERVER",
101+
session_id="SESSION1",
102+
transport_options=TransportOptions(),
103+
close_session_callback=lambda _: None,
104+
client_id="CLIENT1",
105+
rate_limiter=_PermissiveRateLimiter(),
106+
uri_and_metadata_factory=urimeta,
107+
)
108+
109+
connecting = asyncio.create_task(session.ensure_connected())
110+
msg = parse_transport_msg(await recv.get())
111+
assert isinstance(msg, TransportMessage)
112+
assert msg.payload["type"] == "HANDSHAKE_REQ"
113+
await session.close()
114+
await connecting
115+
116+
117+
async def test_reconnect(ws_server: WsServerFixture) -> None:
118+
(urimeta, recv, conn) = ws_server
119+
120+
session = Session(
121+
server_id="SERVER",
122+
session_id="SESSION1",
123+
transport_options=TransportOptions(),
124+
close_session_callback=lambda _: None,
125+
client_id="CLIENT1",
126+
rate_limiter=_PermissiveRateLimiter(),
127+
uri_and_metadata_factory=urimeta,
128+
)
129+
130+
connecting = asyncio.create_task(session.ensure_connected())
131+
msg = parse_transport_msg(await recv.get())
132+
assert isinstance(msg, TransportMessage)
133+
assert msg.payload["type"] == "HANDSHAKE_REQ"
134+
await session.close()
135+
await connecting

0 commit comments

Comments
 (0)