Skip to content

Commit d6cda64

Browse files
Adding an RPC cancellation test
1 parent a93b0e9 commit d6cda64

File tree

1 file changed

+95
-1
lines changed

1 file changed

+95
-1
lines changed

tests/v2/test_v2_cancellation.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
from datetime import timedelta
34
from typing import (
45
Any,
56
AsyncIterator,
@@ -10,17 +11,110 @@
1011

1112
from replit_river.messages import parse_transport_msg
1213
from replit_river.rpc import (
14+
STREAM_OPEN_BIT,
1315
ControlMessageHandshakeRequest,
1416
ControlMessageHandshakeResponse,
1517
HandShakeStatus,
1618
TransportMessage,
1719
)
1820
from replit_river.transport_options import TransportOptions
1921
from replit_river.v2.client import Client
20-
from replit_river.v2.session import STREAM_CANCEL_BIT
22+
from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT
2123
from tests.v2.fixtures.raw_ws_server import OuterPayload, WsServerFixture
2224

2325

26+
async def test_rpc_cancel(ws_server: WsServerFixture) -> None:
27+
(urimeta, recv, conn) = ws_server
28+
29+
client = Client(
30+
client_id="CLIENT1",
31+
server_id="SERVER",
32+
transport_options=TransportOptions(),
33+
uri_and_metadata_factory=urimeta,
34+
)
35+
36+
connecting = asyncio.create_task(client.ensure_connected())
37+
request_msg = parse_transport_msg(await recv.get())
38+
39+
assert not isinstance(request_msg, str)
40+
assert (serverconn := conn())
41+
handshake_request: ControlMessageHandshakeRequest[None] = (
42+
ControlMessageHandshakeRequest(**request_msg.payload)
43+
)
44+
45+
handshake_resp = ControlMessageHandshakeResponse(
46+
status=HandShakeStatus(
47+
ok=True,
48+
),
49+
)
50+
handshake_request.sessionId
51+
52+
msg = TransportMessage(
53+
from_=request_msg.from_,
54+
to=request_msg.to,
55+
streamId=request_msg.streamId,
56+
controlFlags=0,
57+
id=nanoid.generate(),
58+
seq=0,
59+
ack=0,
60+
payload=handshake_resp.model_dump(),
61+
)
62+
packed = msgpack.packb(
63+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
64+
)
65+
await serverconn.send(packed)
66+
67+
sent_waiter = asyncio.Event()
68+
69+
async def handle_server_messages() -> None:
70+
request_msg = parse_transport_msg(await recv.get())
71+
assert not isinstance(request_msg, str)
72+
73+
logging.debug("request_msg: %r", repr(request_msg))
74+
75+
assert request_msg.payload.get("payload", {}).get("hello") == "world"
76+
logging.debug("Found a hello:world %r", repr(request_msg))
77+
78+
sent_waiter.set()
79+
80+
assert request_msg.controlFlags == STREAM_OPEN_BIT | STREAM_CLOSED_BIT
81+
82+
cancel_msg = parse_transport_msg(await recv.get())
83+
assert not isinstance(cancel_msg, str)
84+
assert cancel_msg.controlFlags == STREAM_CANCEL_BIT
85+
86+
server_handler = asyncio.create_task(handle_server_messages())
87+
88+
rpc_task = asyncio.create_task(
89+
client.send_rpc(
90+
"test",
91+
"bigstream",
92+
{"ok": True, "payload": {"hello": "world"}},
93+
lambda x: x,
94+
lambda x: x,
95+
lambda x: x,
96+
timedelta(seconds=2),
97+
)
98+
)
99+
100+
# Wait until we've seen at least a few messages from the upload Task
101+
await sent_waiter.wait()
102+
103+
rpc_task.cancel()
104+
105+
try:
106+
await rpc_task
107+
except asyncio.CancelledError:
108+
pass
109+
110+
await client.close()
111+
await connecting
112+
113+
# Ensure we're listening to close messages as well
114+
server_handler.cancel()
115+
await server_handler
116+
117+
24118
async def test_upload_cancel(ws_server: WsServerFixture) -> None:
25119
(urimeta, recv, conn) = ws_server
26120

0 commit comments

Comments
 (0)