Skip to content

Commit 01b079d

Browse files
Push exception emission up into close_session directly
1 parent 01de8f0 commit 01b079d

File tree

1 file changed

+22
-24
lines changed

1 file changed

+22
-24
lines changed

src/replit_river/v2/session.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import nanoid
2323
import websockets.asyncio.client
24-
from aiochannel import Channel
24+
from aiochannel import Channel, ChannelFull
2525
from aiochannel.errors import ChannelClosed
2626
from opentelemetry.trace import Span, use_span
2727
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
@@ -376,7 +376,7 @@ async def _enqueue_message(
376376
# Wake up buffered_message_sender
377377
self._process_messages.set()
378378

379-
async def close(self) -> None:
379+
async def close(self, reason: Exception | None = None) -> None:
380380
"""Close the session and all associated streams."""
381381
logger.info(
382382
f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}"
@@ -399,16 +399,20 @@ async def close(self) -> None:
399399

400400
await self._task_manager.cancel_all_tasks()
401401

402-
# TODO: unexpected_close should close stream differently here to
403-
# throw exception correctly.
404402
for _, error_channel, stream in self._streams.values():
405403
stream.close()
406404
# Wake up backpressured writers
407-
await error_channel.put(
408-
SessionClosedRiverServiceException(
409-
"river session is closed",
405+
try:
406+
error_channel.put_nowait(
407+
reason
408+
or SessionClosedRiverServiceException(
409+
"river session is closed",
410+
)
411+
)
412+
except ChannelFull:
413+
logger.exception(
414+
"Unable to tell the caller that the session is going away",
410415
)
411-
)
412416
# Before we GC the streams, let's wait for all tasks to be closed gracefully.
413417
await asyncio.gather(
414418
*[stream.join() for _, _, stream in self._streams.values()]
@@ -1080,7 +1084,7 @@ async def _recv_from_ws(
10801084
get_state: Callable[[], SessionState],
10811085
get_ws: Callable[[], ClientConnection | None],
10821086
transition_no_connection: Callable[[], Awaitable[None]],
1083-
close_session: Callable[[], Awaitable[None]],
1087+
close_session: Callable[[Exception | None], Awaitable[None]],
10841088
assert_incoming_seq_bookkeeping: Callable[
10851089
[str, int, int], Literal[True] | _IgnoreMessage
10861090
],
@@ -1120,8 +1124,6 @@ async def _recv_from_ws(
11201124

11211125
logger.debug("client start handling messages from ws %r", ws)
11221126

1123-
error_channel: Channel[Exception | None] | None = None
1124-
11251127
# We should not process messages if the websocket is closed.
11261128
while (ws := get_ws()) and get_state() in ActiveStates:
11271129
connection_attempts = 0
@@ -1192,7 +1194,7 @@ async def _recv_from_ws(
11921194
)
11931195
continue
11941196

1195-
_, error_channel, output = stream_meta
1197+
_, _, output = stream_meta
11961198

11971199
if (
11981200
msg.controlFlags & STREAM_CLOSED_BIT != 0
@@ -1216,25 +1218,21 @@ async def _recv_from_ws(
12161218
output.close()
12171219
except OutOfOrderMessageException:
12181220
logger.exception("Out of order message, closing connection")
1219-
await close_session()
1220-
if error_channel:
1221-
await error_channel.put(
1222-
SessionClosedRiverServiceException(
1223-
"Out of order message, closing connection"
1224-
)
1221+
await close_session(
1222+
SessionClosedRiverServiceException(
1223+
"Out of order message, closing connection"
12251224
)
1225+
)
12261226
continue
12271227
except InvalidMessageException:
12281228
logger.exception(
12291229
"Got invalid transport message, closing session",
12301230
)
1231-
await close_session()
1232-
if error_channel:
1233-
await error_channel.put(
1234-
SessionClosedRiverServiceException(
1235-
"Out of order message, closing connection"
1236-
)
1231+
await close_session(
1232+
SessionClosedRiverServiceException(
1233+
"Out of order message, closing connection"
12371234
)
1235+
)
12381236
continue
12391237
except FailedSendingMessageException:
12401238
# Expected error if the connection is closed.

0 commit comments

Comments
 (0)