@@ -145,7 +145,7 @@ class Session[HandshakeMetadata]:
145145 _space_available : asyncio .Event
146146
147147 # stream for tasks
148- _streams : dict [str , tuple [Channel [Exception | None ], Channel [Any ]]]
148+ _streams : dict [str , tuple [Span , Channel [Exception | None ], Channel [Any ]]]
149149
150150 # book keeping
151151 _ack_buffer : deque [TransportMessage ]
@@ -204,7 +204,9 @@ def __init__(
204204 self ._space_available .set ()
205205
206206 # stream for tasks
207- self ._streams : dict [str , tuple [Channel [Exception | None ], Channel [Any ]]] = {}
207+ self ._streams : dict [
208+ str , tuple [Span , Channel [Exception | None ], Channel [Any ]]
209+ ] = {}
208210
209211 # book keeping
210212 self ._ack_buffer = deque ()
@@ -399,7 +401,7 @@ async def close(self) -> None:
399401
400402 # TODO: unexpected_close should close stream differently here to
401403 # throw exception correctly.
402- for error_channel , stream in self ._streams .values ():
404+ for _ , error_channel , stream in self ._streams .values ():
403405 stream .close ()
404406 # Wake up backpressured writers
405407 await error_channel .put (
@@ -408,7 +410,9 @@ async def close(self) -> None:
408410 )
409411 )
410412 # Before we GC the streams, let's wait for all tasks to be closed gracefully.
411- await asyncio .gather (* [stream .join () for _ , stream in self ._streams .values ()])
413+ await asyncio .gather (
414+ * [stream .join () for _ , _ , stream in self ._streams .values ()]
415+ )
412416 self ._streams .clear ()
413417
414418 if self ._ws :
@@ -441,7 +445,7 @@ async def commit(msg: TransportMessage) -> None:
441445 # Wake up backpressured writer
442446 stream_meta = self ._streams .get (pending .streamId )
443447 if stream_meta :
444- await stream_meta [0 ].put (None )
448+ await stream_meta [1 ].put (None )
445449
446450 def get_next_pending () -> TransportMessage | None :
447451 if self ._send_buffer :
@@ -549,6 +553,7 @@ async def block_until_connected() -> None:
549553 @asynccontextmanager
550554 async def _with_stream (
551555 self ,
556+ span : Span ,
552557 stream_id : str ,
553558 maxsize : int ,
554559 ) -> AsyncIterator [tuple [Channel [Exception | None ], Channel [ResultType ]]]:
@@ -564,19 +569,22 @@ async def _with_stream(
564569 """
565570 output : Channel [Any ] = Channel (maxsize = maxsize )
566571 error_channel : Channel [Exception | None ] = Channel (maxsize = 1 )
567- self ._streams [stream_id ] = (error_channel , output )
572+ self ._streams [stream_id ] = (span , error_channel , output )
568573 try :
569574 yield (error_channel , output )
570575 finally :
571576 stream_meta = self ._streams .get (stream_id )
572577 if not stream_meta :
573- logger .warning ("_with_stream had an entry deleted out from under it" , extra = {
574- "session_id" : self .session_id ,
575- "stream_id" : stream_id ,
576- })
578+ logger .warning (
579+ "_with_stream had an entry deleted out from under it" ,
580+ extra = {
581+ "session_id" : self .session_id ,
582+ "stream_id" : stream_id ,
583+ },
584+ )
577585 return
578586 # We need to signal back to all emitters or waiters that we're gone
579- stream_meta [0 ].close ()
587+ stream_meta [1 ].close ()
580588 del self ._streams [stream_id ]
581589
582590 async def send_rpc [R , A ](
@@ -604,7 +612,7 @@ async def send_rpc[R, A](
604612 span = span ,
605613 )
606614
607- async with self ._with_stream (stream_id , 1 ) as (error_channel , output ):
615+ async with self ._with_stream (span , stream_id , 1 ) as (error_channel , output ):
608616 # Handle potential errors during communication
609617 try :
610618 async with asyncio .timeout (timeout .total_seconds ()):
@@ -665,7 +673,7 @@ async def send_upload[I, R, A](
665673 span = span ,
666674 )
667675
668- async with self ._with_stream (stream_id , 1 ) as (error_channel , output ):
676+ async with self ._with_stream (span , stream_id , 1 ) as (error_channel , output ):
669677 try :
670678 # If this request is not closed and the session is killed, we should
671679 # throw exception here
@@ -764,7 +772,10 @@ async def send_subscription[I, E, A](
764772 span = span ,
765773 )
766774
767- async with self ._with_stream (stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (_ , output ):
775+ async with self ._with_stream (span , stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (
776+ _ ,
777+ output ,
778+ ):
768779 try :
769780 async for item in output :
770781 if item .get ("type" ) == "CLOSE" :
@@ -812,7 +823,7 @@ async def send_stream[I, R, E, A](
812823 span = span ,
813824 )
814825
815- async with self ._with_stream (stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (
826+ async with self ._with_stream (span , stream_id , MAX_MESSAGE_BUFFER_SIZE ) as (
816827 error_channel ,
817828 output ,
818829 ):
@@ -1073,7 +1084,10 @@ async def _recv_from_ws(
10731084 assert_incoming_seq_bookkeeping : Callable [
10741085 [str , int , int ], Literal [True ] | _IgnoreMessage
10751086 ],
1076- get_stream : Callable [[str ], tuple [Channel [Exception | None ], Channel [Any ]] | None ],
1087+ get_stream : Callable [
1088+ [str ],
1089+ tuple [Span , Channel [Exception | None ], Channel [Any ]] | None ,
1090+ ],
10771091 enqueue_message : SendMessage [None ],
10781092) -> None :
10791093 """Serve messages from the websocket.
@@ -1169,16 +1183,16 @@ async def _recv_from_ws(
11691183 )
11701184 continue
11711185
1172- errors_and_stream = get_stream (msg .streamId )
1186+ stream_meta = get_stream (msg .streamId )
11731187
1174- if not errors_and_stream :
1188+ if not stream_meta :
11751189 logger .warning (
11761190 "no stream for %s, ignoring message" ,
11771191 msg .streamId ,
11781192 )
11791193 continue
11801194
1181- error_channel , output = errors_and_stream
1195+ _ , error_channel , output = stream_meta
11821196
11831197 if (
11841198 msg .controlFlags & STREAM_CLOSED_BIT != 0
0 commit comments