@@ -118,6 +118,12 @@ class _IgnoreMessage:
118118 pass
119119
120120
121+ class StreamMeta (TypedDict ):
122+ span : Span
123+ error_channel : Channel [None | Exception ]
124+ output : Channel [Any ]
125+
126+
121127class Session [HandshakeMetadata ]:
122128 _server_id : str
123129 session_id : str
@@ -145,7 +151,7 @@ class Session[HandshakeMetadata]:
145151 _space_available : asyncio .Event
146152
147153 # stream for tasks
148- _streams : dict [str , tuple [ Span , Channel [ Exception | None ], Channel [ Any ]] ]
154+ _streams : dict [str , StreamMeta ]
149155
150156 # book keeping
151157 _ack_buffer : deque [TransportMessage ]
@@ -204,9 +210,7 @@ def __init__(
204210 self ._space_available .set ()
205211
206212 # stream for tasks
207- self ._streams : dict [
208- str , tuple [Span , Channel [Exception | None ], Channel [Any ]]
209- ] = {}
213+ self ._streams : dict [str , StreamMeta ] = {}
210214
211215 # book keeping
212216 self ._ack_buffer = deque ()
@@ -399,11 +403,11 @@ async def close(self, reason: Exception | None = None) -> None:
399403
400404 await self ._task_manager .cancel_all_tasks ()
401405
402- for _ , error_channel , stream in self ._streams .values ():
403- stream .close ()
406+ for stream_meta in self ._streams .values ():
407+ stream_meta [ "output" ] .close ()
404408 # Wake up backpressured writers
405409 try :
406- error_channel .put_nowait (
410+ stream_meta [ " error_channel" ] .put_nowait (
407411 reason
408412 or SessionClosedRiverServiceException (
409413 "river session is closed" ,
@@ -415,7 +419,7 @@ async def close(self, reason: Exception | None = None) -> None:
415419 )
416420 # Before we GC the streams, let's wait for all tasks to be closed gracefully.
417421 await asyncio .gather (
418- * [stream .join () for _ , _ , stream in self ._streams .values ()]
422+ * [stream_meta [ "output" ] .join () for stream_meta in self ._streams .values ()]
419423 )
420424 self ._streams .clear ()
421425
@@ -469,7 +473,7 @@ async def commit(msg: TransportMessage) -> None:
469473 # Wake up backpressured writer
470474 stream_meta = self ._streams .get (pending .streamId )
471475 if stream_meta :
472- await stream_meta [1 ].put (None )
476+ await stream_meta ["error_channel" ].put (None )
473477
474478 def get_next_pending () -> TransportMessage | None :
475479 if self ._send_buffer :
@@ -580,7 +584,7 @@ async def _with_stream(
580584 span : Span ,
581585 stream_id : str ,
582586 maxsize : int ,
583- ) -> AsyncIterator [tuple [Channel [Exception | None ], Channel [ResultType ]]]:
587+ ) -> AsyncIterator [tuple [Channel [None | Exception ], Channel [ResultType ]]]:
584588 """
585589 _with_stream
586590
@@ -592,8 +596,12 @@ async def _with_stream(
592596 emitted should call await error_channel.wait() prior to emission.
593597 """
594598 output : Channel [Any ] = Channel (maxsize = maxsize )
595- error_channel : Channel [Exception | None ] = Channel (maxsize = 1 )
596- self ._streams [stream_id ] = (span , error_channel , output )
599+ error_channel : Channel [None | Exception ] = Channel (maxsize = 1 )
600+ self ._streams [stream_id ] = {
601+ "span" : span ,
602+ "error_channel" : error_channel ,
603+ "output" : output ,
604+ }
597605 try :
598606 yield (error_channel , output )
599607 finally :
@@ -608,7 +616,7 @@ async def _with_stream(
608616 )
609617 return
610618 # We need to signal back to all emitters or waiters that we're gone
611- stream_meta [ 1 ] .close ()
619+ output .close ()
612620 del self ._streams [stream_id ]
613621
614622 async def send_rpc [R , A ](
@@ -1111,7 +1119,7 @@ async def _recv_from_ws(
11111119 ],
11121120 get_stream : Callable [
11131121 [str ],
1114- tuple [ Span , Channel [ Exception | None ], Channel [ Any ]] | None ,
1122+ StreamMeta | None ,
11151123 ],
11161124 enqueue_message : SendMessage [None ],
11171125) -> None :
@@ -1215,8 +1223,6 @@ async def _recv_from_ws(
12151223 )
12161224 continue
12171225
1218- _ , _ , output = stream_meta
1219-
12201226 if (
12211227 msg .controlFlags & STREAM_CLOSED_BIT != 0
12221228 and msg .payload .get ("type" , None ) == "CLOSE"
@@ -1226,7 +1232,7 @@ async def _recv_from_ws(
12261232 pass
12271233 else :
12281234 try :
1229- await output .put (msg .payload )
1235+ await stream_meta [ " output" ] .put (msg .payload )
12301236 except ChannelClosed :
12311237 # The client is no longer interested in this stream,
12321238 # just drop the message.
@@ -1236,7 +1242,7 @@ async def _recv_from_ws(
12361242 # Communicate that we're going down
12371243 #
12381244 # This implements the receive side of the half-closed strategy.
1239- output .close ()
1245+ stream_meta [ " output" ] .close ()
12401246 except OutOfOrderMessageException :
12411247 logger .exception ("Out of order message, closing connection" )
12421248 await close_session (
0 commit comments