|
1 | 1 | import asyncio |
2 | 2 | import logging |
3 | 3 | from typing import ( |
4 | | - Any, |
5 | | - AsyncIterator, |
6 | | - Awaitable, |
7 | | - Callable, |
8 | 4 | Literal, |
9 | | - TypeAlias, |
10 | 5 | TypedDict, |
11 | 6 | ) |
12 | 7 |
|
13 | 8 | import msgpack |
14 | 9 | import nanoid |
15 | | -import pytest |
16 | | -from websockets import ConnectionClosed, ConnectionClosedOK |
17 | | -from websockets.asyncio.server import ServerConnection, serve |
18 | | -from websockets.typing import Data |
19 | 10 |
|
20 | 11 | from replit_river.common_session import SessionState |
21 | 12 | from replit_river.messages import parse_transport_msg |
|
26 | 17 | HandShakeStatus, |
27 | 18 | TransportMessage, |
28 | 19 | ) |
29 | | -from replit_river.transport_options import TransportOptions, UriAndMetadata |
| 20 | +from replit_river.transport_options import TransportOptions |
30 | 21 | from replit_river.v2.client import Client |
31 | | -from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT, Session |
| 22 | +from replit_river.v2.session import STREAM_CLOSED_BIT, Session |
32 | 23 | from tests.v2.fixtures.raw_ws_server import WsServerFixture |
33 | 24 |
|
34 | 25 |
|
@@ -188,106 +179,3 @@ async def handle_server_messages() -> None: |
188 | 179 | # Ensure we're listening to close messages as well |
189 | 180 | server_handler.cancel() |
190 | 181 | await server_handler |
191 | | - |
192 | | - |
193 | | -async def test_upload_cancel(ws_server: WsServerFixture) -> None: |
194 | | - (urimeta, recv, conn) = ws_server |
195 | | - |
196 | | - client = Client( |
197 | | - client_id="CLIENT1", |
198 | | - server_id="SERVER", |
199 | | - transport_options=TransportOptions(), |
200 | | - uri_and_metadata_factory=urimeta, |
201 | | - ) |
202 | | - |
203 | | - connecting = asyncio.create_task(client.ensure_connected()) |
204 | | - request_msg = parse_transport_msg(await recv.get()) |
205 | | - |
206 | | - assert not isinstance(request_msg, str) |
207 | | - assert (serverconn := conn()) |
208 | | - handshake_request: ControlMessageHandshakeRequest[None] = ( |
209 | | - ControlMessageHandshakeRequest(**request_msg.payload) |
210 | | - ) |
211 | | - |
212 | | - handshake_resp = ControlMessageHandshakeResponse( |
213 | | - status=HandShakeStatus( |
214 | | - ok=True, |
215 | | - ), |
216 | | - ) |
217 | | - handshake_request.sessionId |
218 | | - |
219 | | - msg = TransportMessage( |
220 | | - from_=request_msg.from_, |
221 | | - to=request_msg.to, |
222 | | - streamId=request_msg.streamId, |
223 | | - controlFlags=0, |
224 | | - id=nanoid.generate(), |
225 | | - seq=0, |
226 | | - ack=0, |
227 | | - payload=handshake_resp.model_dump(), |
228 | | - ) |
229 | | - packed = msgpack.packb( |
230 | | - msg.model_dump(by_alias=True, exclude_none=True), datetime=True |
231 | | - ) |
232 | | - await serverconn.send(packed) |
233 | | - |
234 | | - async def handle_server_messages() -> None: |
235 | | - request_msg = parse_transport_msg(await recv.get()) |
236 | | - assert not isinstance(request_msg, str) |
237 | | - |
238 | | - logging.debug("request_msg: %r", repr(request_msg)) |
239 | | - |
240 | | - msg = TransportMessage(**msgpack.unpackb(await recv.get())) |
241 | | - while msg.payload.get("payload", {}).get("hello") == "world": |
242 | | - logging.debug("Found a hello:world %r", repr(msg)) |
243 | | - msg = TransportMessage(**msgpack.unpackb(await recv.get())) |
244 | | - |
245 | | - assert msg.controlFlags == STREAM_CANCEL_BIT |
246 | | - |
247 | | - server_handler = asyncio.create_task(handle_server_messages()) |
248 | | - |
249 | | - sent_waiter = asyncio.Event() |
250 | | - |
251 | | - async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: |
252 | | - count = 0 |
253 | | - while True: |
254 | | - await asyncio.sleep(0.1) |
255 | | - yield { |
256 | | - "ok": True, |
257 | | - "payload": { |
258 | | - "hello": "world", |
259 | | - }, |
260 | | - } |
261 | | - count += 1 |
262 | | - if count > 5: |
263 | | - # We've sent enough messages, interrupt the stream. |
264 | | - sent_waiter.set() |
265 | | - |
266 | | - upload_task = asyncio.create_task( |
267 | | - client.send_upload( |
268 | | - "test", |
269 | | - "bigstream", |
270 | | - {}, |
271 | | - upload_chunks(), |
272 | | - lambda x: x, |
273 | | - lambda x: x, |
274 | | - lambda x: x, |
275 | | - lambda x: x, |
276 | | - ) |
277 | | - ) |
278 | | - |
279 | | - # Wait until we've seen at least a few messages from the upload Task |
280 | | - await sent_waiter.wait() |
281 | | - |
282 | | - upload_task.cancel() |
283 | | - try: |
284 | | - await upload_task |
285 | | - except asyncio.CancelledError: |
286 | | - pass |
287 | | - |
288 | | - await client.close() |
289 | | - await connecting |
290 | | - |
291 | | - # Ensure we're listening to close messages as well |
292 | | - server_handler.cancel() |
293 | | - await server_handler |
0 commit comments