Skip to content

Commit 6b2b84b

Browse files
Moving fixtures out
1 parent b7074c6 commit 6b2b84b

File tree

3 files changed

+88
-73
lines changed

3 files changed

+88
-73
lines changed

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"tests.v1.river_fixtures.logging",
1919
"tests.v1.river_fixtures.clientserver",
2020
"tests.v2.fixtures.bound_client",
21+
"tests.v2.fixtures.raw_ws_server",
2122
]
2223

2324
HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]

tests/v2/fixtures/raw_ws_server.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import asyncio
2+
from typing import (
3+
AsyncIterator,
4+
Awaitable,
5+
Callable,
6+
TypeAlias,
7+
TypedDict,
8+
)
9+
10+
import pytest
11+
from websockets import ConnectionClosed, ConnectionClosedOK, Data
12+
from websockets.asyncio.server import ServerConnection, serve
13+
14+
from replit_river.transport_options import UriAndMetadata
15+
16+
WsServerFixture: TypeAlias = tuple[
17+
Callable[[], Awaitable[UriAndMetadata[None]]],
18+
asyncio.Queue[bytes],
19+
Callable[[], ServerConnection | None],
20+
]
21+
22+
23+
class _WsServerState(TypedDict):
24+
ipv4_laddr: tuple[str, int] | None
25+
26+
27+
async def _ws_server_internal(
28+
recv: asyncio.Queue[bytes],
29+
set_conn: Callable[[ServerConnection], None],
30+
state: _WsServerState,
31+
) -> AsyncIterator[None]:
32+
async def handle(websocket: ServerConnection) -> None:
33+
set_conn(websocket)
34+
datagram: Data
35+
try:
36+
while datagram := await websocket.recv(decode=False):
37+
if isinstance(datagram, str):
38+
continue
39+
await recv.put(datagram)
40+
except ConnectionClosedOK:
41+
pass
42+
except ConnectionClosed:
43+
pass
44+
45+
port: int | None = None
46+
if state["ipv4_laddr"]:
47+
port = state["ipv4_laddr"][1]
48+
async with serve(handle, "localhost", port=port) as server:
49+
for sock in server.sockets:
50+
if (pair := sock.getsockname())[0] == "127.0.0.1":
51+
if state["ipv4_laddr"] is None:
52+
state["ipv4_laddr"] = pair
53+
serve_forever = asyncio.create_task(server.serve_forever())
54+
yield None
55+
server.close()
56+
await server.wait_closed()
57+
# "serve_forever" should always be done after wait_closed finishes
58+
assert serve_forever.done()
59+
60+
61+
@pytest.fixture
62+
async def ws_server() -> AsyncIterator[WsServerFixture]:
63+
recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1)
64+
connection: ServerConnection | None = None
65+
state: _WsServerState = {"ipv4_laddr": None}
66+
67+
def set_conn(new_conn: ServerConnection) -> None:
68+
nonlocal connection
69+
connection = new_conn
70+
71+
server_generator = _ws_server_internal(recv, set_conn, state)
72+
await anext(server_generator)
73+
74+
async def urimeta() -> UriAndMetadata[None]:
75+
ipv4_laddr = state["ipv4_laddr"]
76+
assert ipv4_laddr
77+
return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None)
78+
79+
yield (urimeta, recv, lambda: connection)
80+
81+
connection = None
82+
83+
try:
84+
await anext(server_generator)
85+
except StopAsyncIteration:
86+
pass

tests/v2/test_v2_session_lifecycle.py

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from replit_river.transport_options import TransportOptions, UriAndMetadata
3030
from replit_river.v2.client import Client
3131
from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT, Session
32+
from tests.v2.fixtures.raw_ws_server import WsServerFixture
3233

3334

3435
class OuterPayload[A](TypedDict):
@@ -50,79 +51,6 @@ def consume_budget(self, user: str) -> None:
5051
pass
5152

5253

53-
WsServerFixture: TypeAlias = tuple[
54-
Callable[[], Awaitable[UriAndMetadata[None]]],
55-
asyncio.Queue[bytes],
56-
Callable[[], ServerConnection | None],
57-
]
58-
59-
60-
class _WsServerState(TypedDict):
61-
ipv4_laddr: tuple[str, int] | None
62-
63-
64-
async def _ws_server_internal(
65-
recv: asyncio.Queue[bytes],
66-
set_conn: Callable[[ServerConnection], None],
67-
state: _WsServerState,
68-
) -> AsyncIterator[None]:
69-
async def handle(websocket: ServerConnection) -> None:
70-
set_conn(websocket)
71-
datagram: Data
72-
try:
73-
while datagram := await websocket.recv(decode=False):
74-
if isinstance(datagram, str):
75-
continue
76-
await recv.put(datagram)
77-
except ConnectionClosedOK:
78-
pass
79-
except ConnectionClosed:
80-
pass
81-
82-
port: int | None = None
83-
if state["ipv4_laddr"]:
84-
port = state["ipv4_laddr"][1]
85-
async with serve(handle, "localhost", port=port) as server:
86-
for sock in server.sockets:
87-
if (pair := sock.getsockname())[0] == "127.0.0.1":
88-
if state["ipv4_laddr"] is None:
89-
state["ipv4_laddr"] = pair
90-
serve_forever = asyncio.create_task(server.serve_forever())
91-
yield None
92-
server.close()
93-
await server.wait_closed()
94-
# "serve_forever" should always be done after wait_closed finishes
95-
assert serve_forever.done()
96-
97-
98-
@pytest.fixture
99-
async def ws_server() -> AsyncIterator[WsServerFixture]:
100-
recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1)
101-
connection: ServerConnection | None = None
102-
state: _WsServerState = {"ipv4_laddr": None}
103-
104-
def set_conn(new_conn: ServerConnection) -> None:
105-
nonlocal connection
106-
connection = new_conn
107-
108-
server_generator = _ws_server_internal(recv, set_conn, state)
109-
await anext(server_generator)
110-
111-
async def urimeta() -> UriAndMetadata[None]:
112-
ipv4_laddr = state["ipv4_laddr"]
113-
assert ipv4_laddr
114-
return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None)
115-
116-
yield (urimeta, recv, lambda: connection)
117-
118-
connection = None
119-
120-
try:
121-
await anext(server_generator)
122-
except StopAsyncIteration:
123-
pass
124-
125-
12654
async def test_connect(ws_server: WsServerFixture) -> None:
12755
(urimeta, recv, conn) = ws_server
12856

0 commit comments

Comments
 (0)