Skip to content

Commit 1e206f5

Browse files
Just describe StreamMeta instead of ever-embiggening tuples
1 parent 70684e4 commit 1e206f5

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

src/replit_river/v2/session.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ class _IgnoreMessage:
118118
pass
119119

120120

121+
class StreamMeta(TypedDict):
122+
span: Span
123+
error_channel: Channel[None | Exception]
124+
output: Channel[Any]
125+
126+
121127
class Session[HandshakeMetadata]:
122128
_server_id: str
123129
session_id: str
@@ -145,7 +151,7 @@ class Session[HandshakeMetadata]:
145151
_space_available: asyncio.Event
146152

147153
# stream for tasks
148-
_streams: dict[str, tuple[Span, Channel[Exception | None], Channel[Any]]]
154+
_streams: dict[str, StreamMeta]
149155

150156
# book keeping
151157
_ack_buffer: deque[TransportMessage]
@@ -204,9 +210,7 @@ def __init__(
204210
self._space_available.set()
205211

206212
# stream for tasks
207-
self._streams: dict[
208-
str, tuple[Span, Channel[Exception | None], Channel[Any]]
209-
] = {}
213+
self._streams: dict[str, StreamMeta] = {}
210214

211215
# book keeping
212216
self._ack_buffer = deque()
@@ -399,11 +403,11 @@ async def close(self, reason: Exception | None = None) -> None:
399403

400404
await self._task_manager.cancel_all_tasks()
401405

402-
for _, error_channel, stream in self._streams.values():
403-
stream.close()
406+
for stream_meta in self._streams.values():
407+
stream_meta["output"].close()
404408
# Wake up backpressured writers
405409
try:
406-
error_channel.put_nowait(
410+
stream_meta["error_channel"].put_nowait(
407411
reason
408412
or SessionClosedRiverServiceException(
409413
"river session is closed",
@@ -415,7 +419,7 @@ async def close(self, reason: Exception | None = None) -> None:
415419
)
416420
# Before we GC the streams, let's wait for all tasks to be closed gracefully.
417421
await asyncio.gather(
418-
*[stream.join() for _, _, stream in self._streams.values()]
422+
*[stream_meta["output"].join() for stream_meta in self._streams.values()]
419423
)
420424
self._streams.clear()
421425

@@ -469,7 +473,7 @@ async def commit(msg: TransportMessage) -> None:
469473
# Wake up backpressured writer
470474
stream_meta = self._streams.get(pending.streamId)
471475
if stream_meta:
472-
await stream_meta[1].put(None)
476+
await stream_meta["error_channel"].put(None)
473477

474478
def get_next_pending() -> TransportMessage | None:
475479
if self._send_buffer:
@@ -580,7 +584,7 @@ async def _with_stream(
580584
span: Span,
581585
stream_id: str,
582586
maxsize: int,
583-
) -> AsyncIterator[tuple[Channel[Exception | None], Channel[ResultType]]]:
587+
) -> AsyncIterator[tuple[Channel[None | Exception], Channel[ResultType]]]:
584588
"""
585589
_with_stream
586590
@@ -592,8 +596,12 @@ async def _with_stream(
592596
emitted should call await error_channel.wait() prior to emission.
593597
"""
594598
output: Channel[Any] = Channel(maxsize=maxsize)
595-
error_channel: Channel[Exception | None] = Channel(maxsize=1)
596-
self._streams[stream_id] = (span, error_channel, output)
599+
error_channel: Channel[None | Exception] = Channel(maxsize=1)
600+
self._streams[stream_id] = {
601+
"span": span,
602+
"error_channel": error_channel,
603+
"output": output,
604+
}
597605
try:
598606
yield (error_channel, output)
599607
finally:
@@ -608,7 +616,7 @@ async def _with_stream(
608616
)
609617
return
610618
# We need to signal back to all emitters or waiters that we're gone
611-
stream_meta[1].close()
619+
output.close()
612620
del self._streams[stream_id]
613621

614622
async def send_rpc[R, A](
@@ -1111,7 +1119,7 @@ async def _recv_from_ws(
11111119
],
11121120
get_stream: Callable[
11131121
[str],
1114-
tuple[Span, Channel[Exception | None], Channel[Any]] | None,
1122+
StreamMeta | None,
11151123
],
11161124
enqueue_message: SendMessage[None],
11171125
) -> None:
@@ -1215,8 +1223,6 @@ async def _recv_from_ws(
12151223
)
12161224
continue
12171225

1218-
_, _, output = stream_meta
1219-
12201226
if (
12211227
msg.controlFlags & STREAM_CLOSED_BIT != 0
12221228
and msg.payload.get("type", None) == "CLOSE"
@@ -1226,7 +1232,7 @@ async def _recv_from_ws(
12261232
pass
12271233
else:
12281234
try:
1229-
await output.put(msg.payload)
1235+
await stream_meta["output"].put(msg.payload)
12301236
except ChannelClosed:
12311237
# The client is no longer interested in this stream,
12321238
# just drop the message.
@@ -1236,7 +1242,7 @@ async def _recv_from_ws(
12361242
# Communicate that we're going down
12371243
#
12381244
# This implements the receive side of the half-closed strategy.
1239-
output.close()
1245+
stream_meta["output"].close()
12401246
except OutOfOrderMessageException:
12411247
logger.exception("Out of order message, closing connection")
12421248
await close_session(

0 commit comments

Comments
 (0)