Skip to content

Commit d12efb8

Browse files
[bug] Adding explicit CancelledError handlers during async waiting loops (#157)
1 parent fb8d312 commit d12efb8

File tree

10 files changed

+651
-89
lines changed

10 files changed

+651
-89
lines changed

src/replit_river/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,11 @@ def _trace_procedure(
235235
except RiverException as e:
236236
span.record_exception(e, escaped=True)
237237
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
238-
raise e
238+
raise
239239
except BaseException as e:
240240
span.record_exception(e, escaped=True)
241241
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
242-
raise e
242+
raise
243243
finally:
244244
span.end()
245245

src/replit_river/error_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,10 @@ class SessionClosedRiverServiceException(RiverException):
8686
def __init__(
8787
self,
8888
message: str,
89+
streamId: str,
8990
) -> None:
9091
super().__init__(SYNTHETIC_ERROR_CODE_SESSION_CLOSED, message)
92+
self.streamId = streamId
9193

9294

9395
def exception_from_message(code: str) -> type[RiverServiceException]:

src/replit_river/task_manager.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import logging
33
from typing import Coroutine, Set
44

5-
from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException
5+
from replit_river.error_schema import (
6+
ERROR_CODE_STREAM_CLOSED,
7+
RiverException,
8+
SessionClosedRiverServiceException,
9+
)
610

711
logger = logging.getLogger(__name__)
812

@@ -37,6 +41,13 @@ async def cancel_task(
3741
# If we cancel the task manager we will get called here as well,
3842
# if we want to handle the cancellation differently we can do it here.
3943
logger.debug("Task was cancelled %r", task_to_remove)
44+
except SessionClosedRiverServiceException as e:
45+
logger.warning(
46+
"Session was closed",
47+
extra={
48+
"stream_id": e.streamId,
49+
},
50+
)
4051
except RiverException as e:
4152
if e.code == ERROR_CODE_STREAM_CLOSED:
4253
# Task is cancelled
@@ -76,6 +87,14 @@ def _task_done_callback(
7687
):
7788
# Task is cancelled
7889
pass
90+
elif isinstance(exception, SessionClosedRiverServiceException):
91+
# Session is closed, don't bother logging
92+
logger.info(
93+
"Session closed",
94+
extra={
95+
"stream_id": exception.streamId,
96+
},
97+
)
7998
else:
8099
logger.error(
81100
"Exception on cancelling task",

src/replit_river/v2/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,11 @@ def _trace_procedure(
190190
except RiverException as e:
191191
span.record_exception(e, escaped=True)
192192
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
193-
raise e
193+
raise
194194
except BaseException as e:
195195
span.record_exception(e, escaped=True)
196196
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
197-
raise e
197+
raise
198198
finally:
199199
span.end()
200200

src/replit_river/v2/session.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ async def _enqueue_message(
353353
# session is closing / closed, raise
354354
raise SessionClosedRiverServiceException(
355355
"river session is closed, dropping message",
356+
stream_id,
356357
)
357358

358359
# Begin critical section: Avoid any await between here and _send_buffer.append
@@ -448,14 +449,15 @@ async def do_close() -> None:
448449

449450
await self._task_manager.cancel_all_tasks()
450451

451-
for stream_meta in self._streams.values():
452+
for stream_id, stream_meta in self._streams.items():
452453
stream_meta["output"].close()
453454
# Wake up backpressured writers
454455
try:
455456
stream_meta["error_channel"].put_nowait(
456457
reason
457458
or SessionClosedRiverServiceException(
458459
"river session is closed",
460+
stream_id,
459461
)
460462
)
461463
except ChannelFull:
@@ -751,6 +753,13 @@ async def send_rpc[R, A](
751753
# Block for backpressure and emission errors from the ws
752754
await backpressured_waiter()
753755
result = await anext(output)
756+
except asyncio.CancelledError:
757+
await self._send_cancel_stream(
758+
stream_id=stream_id,
759+
message="RPC cancelled",
760+
span=span,
761+
)
762+
raise
754763
except asyncio.TimeoutError as e:
755764
await self._send_cancel_stream(
756765
stream_id=stream_id,
@@ -835,6 +844,13 @@ async def send_upload[I, R, A](
835844
payload=payload,
836845
span=span,
837846
)
847+
except asyncio.CancelledError:
848+
await self._send_cancel_stream(
849+
stream_id=stream_id,
850+
message="Upload cancelled",
851+
span=span,
852+
)
853+
raise
838854
except Exception as e:
839855
# If we get any exception other than WebsocketClosedException,
840856
# cancel the stream.
@@ -916,6 +932,13 @@ async def send_subscription[I, E, A](
916932
continue
917933
yield response_deserializer(item["payload"])
918934
await self._send_close_stream(stream_id, span)
935+
except asyncio.CancelledError:
936+
await self._send_cancel_stream(
937+
stream_id=stream_id,
938+
message="Subscription cancelled",
939+
span=span,
940+
)
941+
raise
919942
except Exception as e:
920943
await self._send_cancel_stream(
921944
stream_id=stream_id,
@@ -1002,6 +1025,15 @@ async def _encode_stream() -> None:
10021025
# ... block the outer function until the emitter is finished emitting,
10031026
# possibly raising a terminal exception.
10041027
await emitter_task
1028+
except asyncio.CancelledError as e:
1029+
await self._send_cancel_stream(
1030+
stream_id=stream_id,
1031+
message="Stream cancelled",
1032+
span=span,
1033+
)
1034+
if emitter_task.done() and (err := emitter_task.exception()):
1035+
raise e from err
1036+
raise
10051037
except Exception as e:
10061038
await self._send_cancel_stream(
10071039
stream_id=stream_id,
@@ -1288,6 +1320,7 @@ async def _recv_from_ws(
12881320
# the outer loop.
12891321
await transition_no_connection()
12901322
break
1323+
msg: TransportMessage | str | None = None
12911324
try:
12921325
msg = parse_transport_msg(message)
12931326
logger.debug(
@@ -1367,19 +1400,27 @@ async def _recv_from_ws(
13671400
stream_meta["output"].close()
13681401
except OutOfOrderMessageException:
13691402
logger.exception("Out of order message, closing connection")
1403+
stream_id = "unknown"
1404+
if isinstance(msg, TransportMessage):
1405+
stream_id = msg.streamId
13701406
close_session(
13711407
SessionClosedRiverServiceException(
1372-
"Out of order message, closing connection"
1408+
"Out of order message, closing connection",
1409+
stream_id,
13731410
)
13741411
)
13751412
continue
13761413
except InvalidMessageException:
13771414
logger.exception(
13781415
"Got invalid transport message, closing session",
13791416
)
1417+
stream_id = "unknown"
1418+
if isinstance(msg, TransportMessage):
1419+
stream_id = msg.streamId
13801420
close_session(
13811421
SessionClosedRiverServiceException(
1382-
"Out of order message, closing connection"
1422+
"Out of order message, closing connection",
1423+
stream_id,
13831424
)
13841425
)
13851426
continue

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
pytest_plugins = [
1818
"tests.v1.river_fixtures.logging",
1919
"tests.v1.river_fixtures.clientserver",
20-
"tests.v2.fixtures",
20+
"tests.v2.fixtures.bound_client",
21+
"tests.v2.fixtures.raw_ws_server",
2122
]
2223

2324
HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]

tests/v2/fixtures/raw_ws_server.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import asyncio
2+
from typing import (
3+
AsyncIterator,
4+
Awaitable,
5+
Callable,
6+
Literal,
7+
TypeAlias,
8+
TypedDict,
9+
)
10+
11+
import pytest
12+
from websockets import ConnectionClosed, ConnectionClosedOK, Data
13+
from websockets.asyncio.server import ServerConnection, serve
14+
15+
from replit_river.transport_options import UriAndMetadata
16+
17+
WsServerFixture: TypeAlias = tuple[
18+
Callable[[], Awaitable[UriAndMetadata[None]]],
19+
asyncio.Queue[bytes],
20+
Callable[[], ServerConnection | None],
21+
]
22+
23+
24+
class OuterPayload[A](TypedDict):
25+
ok: Literal[True]
26+
payload: A
27+
28+
29+
class _WsServerState(TypedDict):
30+
ipv4_laddr: tuple[str, int] | None
31+
32+
33+
async def _ws_server_internal(
34+
recv: asyncio.Queue[bytes],
35+
set_conn: Callable[[ServerConnection], None],
36+
state: _WsServerState,
37+
) -> AsyncIterator[None]:
38+
async def handle(websocket: ServerConnection) -> None:
39+
set_conn(websocket)
40+
datagram: Data
41+
try:
42+
while datagram := await websocket.recv(decode=False):
43+
if isinstance(datagram, str):
44+
continue
45+
await recv.put(datagram)
46+
except ConnectionClosedOK:
47+
pass
48+
except ConnectionClosed:
49+
pass
50+
51+
port: int | None = None
52+
if state["ipv4_laddr"]:
53+
port = state["ipv4_laddr"][1]
54+
async with serve(handle, "localhost", port=port) as server:
55+
for sock in server.sockets:
56+
if (pair := sock.getsockname())[0] == "127.0.0.1":
57+
if state["ipv4_laddr"] is None:
58+
state["ipv4_laddr"] = pair
59+
serve_forever = asyncio.create_task(server.serve_forever())
60+
yield None
61+
server.close()
62+
await server.wait_closed()
63+
# "serve_forever" should always be done after wait_closed finishes
64+
assert serve_forever.done()
65+
66+
67+
@pytest.fixture
68+
async def ws_server() -> AsyncIterator[WsServerFixture]:
69+
recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1)
70+
connection: ServerConnection | None = None
71+
state: _WsServerState = {"ipv4_laddr": None}
72+
73+
def set_conn(new_conn: ServerConnection) -> None:
74+
nonlocal connection
75+
connection = new_conn
76+
77+
server_generator = _ws_server_internal(recv, set_conn, state)
78+
await anext(server_generator)
79+
80+
async def urimeta() -> UriAndMetadata[None]:
81+
ipv4_laddr = state["ipv4_laddr"]
82+
assert ipv4_laddr
83+
return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None)
84+
85+
yield (urimeta, recv, lambda: connection)
86+
87+
connection = None
88+
89+
try:
90+
await anext(server_generator)
91+
except StopAsyncIteration:
92+
pass

0 commit comments

Comments
 (0)