Skip to content

Commit b5971b6

Browse files
Add a subscription test
1 parent d6cda64 commit b5971b6

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

tests/v2/test_v2_cancellation.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,123 @@ async def handle_server_messages() -> None:
115115
await server_handler
116116

117117

118+
async def test_subscription_cancel(ws_server: WsServerFixture) -> None:
119+
(urimeta, recv, conn) = ws_server
120+
121+
client = Client(
122+
client_id="CLIENT1",
123+
server_id="SERVER",
124+
transport_options=TransportOptions(),
125+
uri_and_metadata_factory=urimeta,
126+
)
127+
128+
connecting = asyncio.create_task(client.ensure_connected())
129+
request_msg = parse_transport_msg(await recv.get())
130+
131+
assert not isinstance(request_msg, str)
132+
assert (serverconn := conn())
133+
handshake_request: ControlMessageHandshakeRequest[None] = (
134+
ControlMessageHandshakeRequest(**request_msg.payload)
135+
)
136+
137+
handshake_resp = ControlMessageHandshakeResponse(
138+
status=HandShakeStatus(
139+
ok=True,
140+
),
141+
)
142+
handshake_request.sessionId
143+
144+
msg = TransportMessage(
145+
from_=request_msg.from_,
146+
to=request_msg.to,
147+
streamId=request_msg.streamId,
148+
controlFlags=0,
149+
id=nanoid.generate(),
150+
seq=0,
151+
ack=0,
152+
payload=handshake_resp.model_dump(),
153+
)
154+
packed = msgpack.packb(
155+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
156+
)
157+
await serverconn.send(packed)
158+
159+
received_waiter = asyncio.Event()
160+
161+
async def handle_server_messages() -> None:
162+
request_msg = parse_transport_msg(await recv.get())
163+
assert not isinstance(request_msg, str)
164+
165+
logging.debug("request_msg: %r", repr(request_msg))
166+
seq = 0
167+
168+
while True:
169+
try:
170+
cancel_msg = parse_transport_msg(recv.get_nowait())
171+
break
172+
except asyncio.queues.QueueEmpty:
173+
pass
174+
175+
msg = TransportMessage(
176+
from_=request_msg.from_,
177+
to=request_msg.to,
178+
streamId=request_msg.streamId,
179+
controlFlags=0,
180+
id=nanoid.generate(),
181+
seq=seq,
182+
ack=0,
183+
payload={
184+
"ok": True,
185+
"payload": {
186+
"hello": "world",
187+
},
188+
},
189+
)
190+
seq += 1
191+
packed = msgpack.packb(
192+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
193+
)
194+
await serverconn.send(packed)
195+
await asyncio.sleep(0.1)
196+
197+
if seq > 5:
198+
received_waiter.set()
199+
200+
assert not isinstance(cancel_msg, str)
201+
assert cancel_msg.controlFlags == STREAM_CANCEL_BIT
202+
203+
server_handler = asyncio.create_task(handle_server_messages())
204+
205+
async def receive_chunks() -> None:
206+
async for chunk in client.send_subscription(
207+
"test",
208+
"bigstream",
209+
{},
210+
lambda x: x,
211+
lambda x: x,
212+
lambda x: x,
213+
):
214+
print(repr(chunk))
215+
216+
receive_task = asyncio.create_task(receive_chunks())
217+
218+
# Wait until we've seen at least a few messages from the upload Task
219+
await received_waiter.wait()
220+
221+
receive_task.cancel()
222+
try:
223+
await receive_task
224+
except asyncio.CancelledError:
225+
pass
226+
227+
await client.close()
228+
await connecting
229+
230+
# Ensure we're listening to close messages as well
231+
server_handler.cancel()
232+
await server_handler
233+
234+
118235
async def test_upload_cancel(ws_server: WsServerFixture) -> None:
119236
(urimeta, recv, conn) = ws_server
120237

0 commit comments

Comments
 (0)