Skip to content

Commit 01de8f0

Browse files
Adding Span to _streams
1 parent eeb5ce8 commit 01de8f0

File tree

1 file changed

+33
-19
lines changed

1 file changed

+33
-19
lines changed

src/replit_river/v2/session.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class Session[HandshakeMetadata]:
145145
_space_available: asyncio.Event
146146

147147
# stream for tasks
148-
_streams: dict[str, tuple[Channel[Exception | None], Channel[Any]]]
148+
_streams: dict[str, tuple[Span, Channel[Exception | None], Channel[Any]]]
149149

150150
# book keeping
151151
_ack_buffer: deque[TransportMessage]
@@ -204,7 +204,9 @@ def __init__(
204204
self._space_available.set()
205205

206206
# stream for tasks
207-
self._streams: dict[str, tuple[Channel[Exception | None], Channel[Any]]] = {}
207+
self._streams: dict[
208+
str, tuple[Span, Channel[Exception | None], Channel[Any]]
209+
] = {}
208210

209211
# book keeping
210212
self._ack_buffer = deque()
@@ -399,7 +401,7 @@ async def close(self) -> None:
399401

400402
# TODO: unexpected_close should close stream differently here to
401403
# throw exception correctly.
402-
for error_channel, stream in self._streams.values():
404+
for _, error_channel, stream in self._streams.values():
403405
stream.close()
404406
# Wake up backpressured writers
405407
await error_channel.put(
@@ -408,7 +410,9 @@ async def close(self) -> None:
408410
)
409411
)
410412
# Before we GC the streams, let's wait for all tasks to be closed gracefully.
411-
await asyncio.gather(*[stream.join() for _, stream in self._streams.values()])
413+
await asyncio.gather(
414+
*[stream.join() for _, _, stream in self._streams.values()]
415+
)
412416
self._streams.clear()
413417

414418
if self._ws:
@@ -441,7 +445,7 @@ async def commit(msg: TransportMessage) -> None:
441445
# Wake up backpressured writer
442446
stream_meta = self._streams.get(pending.streamId)
443447
if stream_meta:
444-
await stream_meta[0].put(None)
448+
await stream_meta[1].put(None)
445449

446450
def get_next_pending() -> TransportMessage | None:
447451
if self._send_buffer:
@@ -549,6 +553,7 @@ async def block_until_connected() -> None:
549553
@asynccontextmanager
550554
async def _with_stream(
551555
self,
556+
span: Span,
552557
stream_id: str,
553558
maxsize: int,
554559
) -> AsyncIterator[tuple[Channel[Exception | None], Channel[ResultType]]]:
@@ -564,19 +569,22 @@ async def _with_stream(
564569
"""
565570
output: Channel[Any] = Channel(maxsize=maxsize)
566571
error_channel: Channel[Exception | None] = Channel(maxsize=1)
567-
self._streams[stream_id] = (error_channel, output)
572+
self._streams[stream_id] = (span, error_channel, output)
568573
try:
569574
yield (error_channel, output)
570575
finally:
571576
stream_meta = self._streams.get(stream_id)
572577
if not stream_meta:
573-
logger.warning("_with_stream had an entry deleted out from under it", extra={
574-
"session_id": self.session_id,
575-
"stream_id": stream_id,
576-
})
578+
logger.warning(
579+
"_with_stream had an entry deleted out from under it",
580+
extra={
581+
"session_id": self.session_id,
582+
"stream_id": stream_id,
583+
},
584+
)
577585
return
578586
# We need to signal back to all emitters or waiters that we're gone
579-
stream_meta[0].close()
587+
stream_meta[1].close()
580588
del self._streams[stream_id]
581589

582590
async def send_rpc[R, A](
@@ -604,7 +612,7 @@ async def send_rpc[R, A](
604612
span=span,
605613
)
606614

607-
async with self._with_stream(stream_id, 1) as (error_channel, output):
615+
async with self._with_stream(span, stream_id, 1) as (error_channel, output):
608616
# Handle potential errors during communication
609617
try:
610618
async with asyncio.timeout(timeout.total_seconds()):
@@ -665,7 +673,7 @@ async def send_upload[I, R, A](
665673
span=span,
666674
)
667675

668-
async with self._with_stream(stream_id, 1) as (error_channel, output):
676+
async with self._with_stream(span, stream_id, 1) as (error_channel, output):
669677
try:
670678
# If this request is not closed and the session is killed, we should
671679
# throw exception here
@@ -764,7 +772,10 @@ async def send_subscription[I, E, A](
764772
span=span,
765773
)
766774

767-
async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (_, output):
775+
async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as (
776+
_,
777+
output,
778+
):
768779
try:
769780
async for item in output:
770781
if item.get("type") == "CLOSE":
@@ -812,7 +823,7 @@ async def send_stream[I, R, E, A](
812823
span=span,
813824
)
814825

815-
async with self._with_stream(stream_id, MAX_MESSAGE_BUFFER_SIZE) as (
826+
async with self._with_stream(span, stream_id, MAX_MESSAGE_BUFFER_SIZE) as (
816827
error_channel,
817828
output,
818829
):
@@ -1073,7 +1084,10 @@ async def _recv_from_ws(
10731084
assert_incoming_seq_bookkeeping: Callable[
10741085
[str, int, int], Literal[True] | _IgnoreMessage
10751086
],
1076-
get_stream: Callable[[str], tuple[Channel[Exception | None], Channel[Any]] | None],
1087+
get_stream: Callable[
1088+
[str],
1089+
tuple[Span, Channel[Exception | None], Channel[Any]] | None,
1090+
],
10771091
enqueue_message: SendMessage[None],
10781092
) -> None:
10791093
"""Serve messages from the websocket.
@@ -1169,16 +1183,16 @@ async def _recv_from_ws(
11691183
)
11701184
continue
11711185

1172-
errors_and_stream = get_stream(msg.streamId)
1186+
stream_meta = get_stream(msg.streamId)
11731187

1174-
if not errors_and_stream:
1188+
if not stream_meta:
11751189
logger.warning(
11761190
"no stream for %s, ignoring message",
11771191
msg.streamId,
11781192
)
11791193
continue
11801194

1181-
error_channel, output = errors_and_stream
1195+
_, error_channel, output = stream_meta
11821196

11831197
if (
11841198
msg.controlFlags & STREAM_CLOSED_BIT != 0

0 commit comments

Comments
 (0)