Skip to content

Commit b518386

Browse files
Moving test_upload_cancel out
1 parent 6b2b84b commit b518386

File tree

2 files changed

+133
-114
lines changed

2 files changed

+133
-114
lines changed

tests/v2/test_v2_cancellation.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import asyncio
2+
import logging
3+
from typing import (
4+
Any,
5+
AsyncIterator,
6+
Literal,
7+
TypedDict,
8+
)
9+
10+
import msgpack
11+
import nanoid
12+
13+
from replit_river.messages import parse_transport_msg
14+
from replit_river.rpc import (
15+
ControlMessageHandshakeRequest,
16+
ControlMessageHandshakeResponse,
17+
HandShakeStatus,
18+
TransportMessage,
19+
)
20+
from replit_river.transport_options import TransportOptions
21+
from replit_river.v2.client import Client
22+
from replit_river.v2.session import STREAM_CANCEL_BIT
23+
from tests.v2.fixtures.raw_ws_server import WsServerFixture
24+
25+
26+
class OuterPayload[A](TypedDict):
27+
ok: Literal[True]
28+
payload: A
29+
30+
31+
async def test_upload_cancel(ws_server: WsServerFixture) -> None:
32+
(urimeta, recv, conn) = ws_server
33+
34+
client = Client(
35+
client_id="CLIENT1",
36+
server_id="SERVER",
37+
transport_options=TransportOptions(),
38+
uri_and_metadata_factory=urimeta,
39+
)
40+
41+
connecting = asyncio.create_task(client.ensure_connected())
42+
request_msg = parse_transport_msg(await recv.get())
43+
44+
assert not isinstance(request_msg, str)
45+
assert (serverconn := conn())
46+
handshake_request: ControlMessageHandshakeRequest[None] = (
47+
ControlMessageHandshakeRequest(**request_msg.payload)
48+
)
49+
50+
handshake_resp = ControlMessageHandshakeResponse(
51+
status=HandShakeStatus(
52+
ok=True,
53+
),
54+
)
55+
handshake_request.sessionId
56+
57+
msg = TransportMessage(
58+
from_=request_msg.from_,
59+
to=request_msg.to,
60+
streamId=request_msg.streamId,
61+
controlFlags=0,
62+
id=nanoid.generate(),
63+
seq=0,
64+
ack=0,
65+
payload=handshake_resp.model_dump(),
66+
)
67+
packed = msgpack.packb(
68+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
69+
)
70+
await serverconn.send(packed)
71+
72+
async def handle_server_messages() -> None:
73+
request_msg = parse_transport_msg(await recv.get())
74+
assert not isinstance(request_msg, str)
75+
76+
logging.debug("request_msg: %r", repr(request_msg))
77+
78+
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
79+
while msg.payload.get("payload", {}).get("hello") == "world":
80+
logging.debug("Found a hello:world %r", repr(msg))
81+
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
82+
83+
assert msg.controlFlags == STREAM_CANCEL_BIT
84+
85+
server_handler = asyncio.create_task(handle_server_messages())
86+
87+
sent_waiter = asyncio.Event()
88+
89+
async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]:
90+
count = 0
91+
while True:
92+
await asyncio.sleep(0.1)
93+
yield {
94+
"ok": True,
95+
"payload": {
96+
"hello": "world",
97+
},
98+
}
99+
count += 1
100+
if count > 5:
101+
# We've sent enough messages, interrupt the stream.
102+
sent_waiter.set()
103+
104+
upload_task = asyncio.create_task(
105+
client.send_upload(
106+
"test",
107+
"bigstream",
108+
{},
109+
upload_chunks(),
110+
lambda x: x,
111+
lambda x: x,
112+
lambda x: x,
113+
lambda x: x,
114+
)
115+
)
116+
117+
# Wait until we've seen at least a few messages from the upload Task
118+
await sent_waiter.wait()
119+
120+
upload_task.cancel()
121+
try:
122+
await upload_task
123+
except asyncio.CancelledError:
124+
pass
125+
126+
await client.close()
127+
await connecting
128+
129+
# Ensure we're listening to close messages as well
130+
server_handler.cancel()
131+
await server_handler

tests/v2/test_v2_session_lifecycle.py

Lines changed: 2 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,12 @@
11
import asyncio
22
import logging
33
from typing import (
4-
Any,
5-
AsyncIterator,
6-
Awaitable,
7-
Callable,
84
Literal,
9-
TypeAlias,
105
TypedDict,
116
)
127

138
import msgpack
149
import nanoid
15-
import pytest
16-
from websockets import ConnectionClosed, ConnectionClosedOK
17-
from websockets.asyncio.server import ServerConnection, serve
18-
from websockets.typing import Data
1910

2011
from replit_river.common_session import SessionState
2112
from replit_river.messages import parse_transport_msg
@@ -26,9 +17,9 @@
2617
HandShakeStatus,
2718
TransportMessage,
2819
)
29-
from replit_river.transport_options import TransportOptions, UriAndMetadata
20+
from replit_river.transport_options import TransportOptions
3021
from replit_river.v2.client import Client
31-
from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT, Session
22+
from replit_river.v2.session import STREAM_CLOSED_BIT, Session
3223
from tests.v2.fixtures.raw_ws_server import WsServerFixture
3324

3425

@@ -188,106 +179,3 @@ async def handle_server_messages() -> None:
188179
# Ensure we're listening to close messages as well
189180
server_handler.cancel()
190181
await server_handler
191-
192-
193-
async def test_upload_cancel(ws_server: WsServerFixture) -> None:
194-
(urimeta, recv, conn) = ws_server
195-
196-
client = Client(
197-
client_id="CLIENT1",
198-
server_id="SERVER",
199-
transport_options=TransportOptions(),
200-
uri_and_metadata_factory=urimeta,
201-
)
202-
203-
connecting = asyncio.create_task(client.ensure_connected())
204-
request_msg = parse_transport_msg(await recv.get())
205-
206-
assert not isinstance(request_msg, str)
207-
assert (serverconn := conn())
208-
handshake_request: ControlMessageHandshakeRequest[None] = (
209-
ControlMessageHandshakeRequest(**request_msg.payload)
210-
)
211-
212-
handshake_resp = ControlMessageHandshakeResponse(
213-
status=HandShakeStatus(
214-
ok=True,
215-
),
216-
)
217-
handshake_request.sessionId
218-
219-
msg = TransportMessage(
220-
from_=request_msg.from_,
221-
to=request_msg.to,
222-
streamId=request_msg.streamId,
223-
controlFlags=0,
224-
id=nanoid.generate(),
225-
seq=0,
226-
ack=0,
227-
payload=handshake_resp.model_dump(),
228-
)
229-
packed = msgpack.packb(
230-
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
231-
)
232-
await serverconn.send(packed)
233-
234-
async def handle_server_messages() -> None:
235-
request_msg = parse_transport_msg(await recv.get())
236-
assert not isinstance(request_msg, str)
237-
238-
logging.debug("request_msg: %r", repr(request_msg))
239-
240-
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
241-
while msg.payload.get("payload", {}).get("hello") == "world":
242-
logging.debug("Found a hello:world %r", repr(msg))
243-
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
244-
245-
assert msg.controlFlags == STREAM_CANCEL_BIT
246-
247-
server_handler = asyncio.create_task(handle_server_messages())
248-
249-
sent_waiter = asyncio.Event()
250-
251-
async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]:
252-
count = 0
253-
while True:
254-
await asyncio.sleep(0.1)
255-
yield {
256-
"ok": True,
257-
"payload": {
258-
"hello": "world",
259-
},
260-
}
261-
count += 1
262-
if count > 5:
263-
# We've sent enough messages, interrupt the stream.
264-
sent_waiter.set()
265-
266-
upload_task = asyncio.create_task(
267-
client.send_upload(
268-
"test",
269-
"bigstream",
270-
{},
271-
upload_chunks(),
272-
lambda x: x,
273-
lambda x: x,
274-
lambda x: x,
275-
lambda x: x,
276-
)
277-
)
278-
279-
# Wait until we've seen at least a few messages from the upload Task
280-
await sent_waiter.wait()
281-
282-
upload_task.cancel()
283-
try:
284-
await upload_task
285-
except asyncio.CancelledError:
286-
pass
287-
288-
await client.close()
289-
await connecting
290-
291-
# Ensure we're listening to close messages as well
292-
server_handler.cancel()
293-
await server_handler

0 commit comments

Comments
 (0)