2121
2222import nanoid
2323import websockets .asyncio .client
24- from aiochannel import Channel
24+ from aiochannel import Channel , ChannelFull
2525from aiochannel .errors import ChannelClosed
2626from opentelemetry .trace import Span , use_span
2727from opentelemetry .trace .propagation .tracecontext import TraceContextTextMapPropagator
@@ -376,7 +376,7 @@ async def _enqueue_message(
376376 # Wake up buffered_message_sender
377377 self ._process_messages .set ()
378378
379- async def close (self ) -> None :
379+ async def close (self , reason : Exception | None = None ) -> None :
380380 """Close the session and all associated streams."""
381381 logger .info (
382382 f"{ self .session_id } closing session to { self ._server_id } , ws: { self ._ws } "
@@ -399,16 +399,20 @@ async def close(self) -> None:
399399
400400 await self ._task_manager .cancel_all_tasks ()
401401
402- # TODO: unexpected_close should close stream differently here to
403- # throw exception correctly.
404402 for _ , error_channel , stream in self ._streams .values ():
405403 stream .close ()
406404 # Wake up backpressured writers
407- await error_channel .put (
408- SessionClosedRiverServiceException (
409- "river session is closed" ,
405+ try :
406+ error_channel .put_nowait (
407+ reason
408+ or SessionClosedRiverServiceException (
409+ "river session is closed" ,
410+ )
411+ )
412+ except ChannelFull :
413+ logger .exception (
414+ "Unable to tell the caller that the session is going away" ,
410415 )
411- )
412416 # Before we GC the streams, let's wait for all tasks to be closed gracefully.
413417 await asyncio .gather (
414418 * [stream .join () for _ , _ , stream in self ._streams .values ()]
@@ -1080,7 +1084,7 @@ async def _recv_from_ws(
10801084 get_state : Callable [[], SessionState ],
10811085 get_ws : Callable [[], ClientConnection | None ],
10821086 transition_no_connection : Callable [[], Awaitable [None ]],
1083- close_session : Callable [[], Awaitable [None ]],
1087+ close_session : Callable [[Exception | None ], Awaitable [None ]],
10841088 assert_incoming_seq_bookkeeping : Callable [
10851089 [str , int , int ], Literal [True ] | _IgnoreMessage
10861090 ],
@@ -1120,8 +1124,6 @@ async def _recv_from_ws(
11201124
11211125 logger .debug ("client start handling messages from ws %r" , ws )
11221126
1123- error_channel : Channel [Exception | None ] | None = None
1124-
11251127 # We should not process messages if the websocket is closed.
11261128 while (ws := get_ws ()) and get_state () in ActiveStates :
11271129 connection_attempts = 0
@@ -1192,7 +1194,7 @@ async def _recv_from_ws(
11921194 )
11931195 continue
11941196
1195- _ , error_channel , output = stream_meta
1197+ _ , _ , output = stream_meta
11961198
11971199 if (
11981200 msg .controlFlags & STREAM_CLOSED_BIT != 0
@@ -1216,25 +1218,21 @@ async def _recv_from_ws(
12161218 output .close ()
12171219 except OutOfOrderMessageException :
12181220 logger .exception ("Out of order message, closing connection" )
1219- await close_session ()
1220- if error_channel :
1221- await error_channel .put (
1222- SessionClosedRiverServiceException (
1223- "Out of order message, closing connection"
1224- )
1221+ await close_session (
1222+ SessionClosedRiverServiceException (
1223+ "Out of order message, closing connection"
12251224 )
1225+ )
12261226 continue
12271227 except InvalidMessageException :
12281228 logger .exception (
12291229 "Got invalid transport message, closing session" ,
12301230 )
1231- await close_session ()
1232- if error_channel :
1233- await error_channel .put (
1234- SessionClosedRiverServiceException (
1235- "Out of order message, closing connection"
1236- )
1231+ await close_session (
1232+ SessionClosedRiverServiceException (
1233+ "Out of order message, closing connection"
12371234 )
1235+ )
12381236 continue
12391237 except FailedSendingMessageException :
12401238 # Expected error if the connection is closed.
0 commit comments