Skip to content

Commit 6e2bc1a

Browse files
[chore] feat/dethread clientserver (#147)
Why === It's challenging to reason about how all the bits fit together right now. What changed ============ Dethread server and client sessions and transport. Test plan ========= CI
1 parent 42de571 commit 6e2bc1a

File tree

18 files changed

+667
-424
lines changed

18 files changed

+667
-424
lines changed

src/replit_river/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .client import Client
22
from .error_schema import RiverError
33
from .rpc import (
4-
GenericRpcHandler,
4+
GenericRpcHandlerBuilder,
55
GrpcContext,
66
rpc_method_handler,
77
stream_method_handler,
@@ -15,7 +15,7 @@
1515
"Server",
1616
"GrpcContext",
1717
"RiverError",
18-
"GenericRpcHandler",
18+
"GenericRpcHandlerBuilder",
1919
"rpc_method_handler",
2020
"subscription_method_handler",
2121
"upload_method_handler",

src/replit_river/client_session.py

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import logging
33
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, AsyncGenerator, Callable
5+
from typing import Any, AsyncGenerator, Callable, Coroutine
66

77
import nanoid # type: ignore
8+
import websockets
89
from aiochannel import Channel
910
from aiochannel.errors import ChannelClosed
1011
from opentelemetry.trace import Span
12+
from websockets.exceptions import ConnectionClosed
1113

14+
from replit_river.common_session import add_msg_to_stream
1215
from replit_river.error_schema import (
1316
ERROR_CODE_CANCEL,
1417
ERROR_CODE_STREAM_CLOSED,
@@ -17,10 +20,20 @@
1720
StreamClosedRiverServiceException,
1821
exception_from_message,
1922
)
23+
from replit_river.messages import (
24+
FailedSendingMessageException,
25+
parse_transport_msg,
26+
)
27+
from replit_river.seq_manager import (
28+
IgnoreMessageException,
29+
InvalidMessageException,
30+
OutOfOrderMessageException,
31+
)
2032
from replit_river.session import Session
21-
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
33+
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions
2234

2335
from .rpc import (
36+
ACK_BIT,
2437
STREAM_CLOSED_BIT,
2538
STREAM_OPEN_BIT,
2639
ErrorType,
@@ -33,6 +46,129 @@
3346

3447

3548
class ClientSession(Session):
49+
def __init__(
50+
self,
51+
transport_id: str,
52+
to_id: str,
53+
session_id: str,
54+
websocket: websockets.WebSocketCommonProtocol,
55+
transport_options: TransportOptions,
56+
close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]],
57+
retry_connection_callback: (
58+
Callable[
59+
[],
60+
Coroutine[Any, Any, Any],
61+
]
62+
| None
63+
) = None,
64+
) -> None:
65+
super().__init__(
66+
transport_id=transport_id,
67+
to_id=to_id,
68+
session_id=session_id,
69+
websocket=websocket,
70+
transport_options=transport_options,
71+
close_session_callback=close_session_callback,
72+
retry_connection_callback=retry_connection_callback,
73+
)
74+
75+
async def do_close_websocket() -> None:
76+
await self.close_websocket(
77+
self._ws_wrapper,
78+
should_retry=True,
79+
)
80+
await self._begin_close_session_countdown()
81+
82+
self._setup_heartbeats_task(do_close_websocket)
83+
84+
async def start_serve_responses(self) -> None:
85+
self._task_manager.create_task(self.serve())
86+
87+
async def serve(self) -> None:
88+
"""Serve messages from the websocket."""
89+
self._reset_session_close_countdown()
90+
try:
91+
try:
92+
await self._handle_messages_from_ws()
93+
except ConnectionClosed:
94+
if self._retry_connection_callback:
95+
self._task_manager.create_task(self._retry_connection_callback())
96+
97+
await self._begin_close_session_countdown()
98+
logger.debug("ConnectionClosed while serving", exc_info=True)
99+
except FailedSendingMessageException:
100+
# Expected error if the connection is closed.
101+
logger.debug(
102+
"FailedSendingMessageException while serving", exc_info=True
103+
)
104+
except Exception:
105+
logger.exception("caught exception at message iterator")
106+
except ExceptionGroup as eg:
107+
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
108+
if unhandled:
109+
raise ExceptionGroup(
110+
"Unhandled exceptions on River server", unhandled.exceptions
111+
)
112+
113+
async def _handle_messages_from_ws(self) -> None:
114+
logger.debug(
115+
"%s start handling messages from ws %s",
116+
"client",
117+
self._ws_wrapper.id,
118+
)
119+
try:
120+
ws_wrapper = self._ws_wrapper
121+
async for message in ws_wrapper.ws:
122+
try:
123+
if not await ws_wrapper.is_open():
124+
# We should not process messages if the websocket is closed.
125+
break
126+
msg = parse_transport_msg(message, self._transport_options)
127+
128+
logger.debug(f"{self._transport_id} got a message %r", msg)
129+
130+
# Update bookkeeping
131+
await self._seq_manager.check_seq_and_update(msg)
132+
await self._buffer.remove_old_messages(
133+
self._seq_manager.receiver_ack,
134+
)
135+
self._reset_session_close_countdown()
136+
137+
if msg.controlFlags & ACK_BIT != 0:
138+
continue
139+
async with self._stream_lock:
140+
stream = self._streams.get(msg.streamId, None)
141+
if msg.controlFlags & STREAM_OPEN_BIT == 0:
142+
if not stream:
143+
logger.warning("no stream for %s", msg.streamId)
144+
raise IgnoreMessageException(
145+
"no stream for message, ignoring"
146+
)
147+
await add_msg_to_stream(msg, stream)
148+
else:
149+
raise InvalidMessageException(
150+
"Client should not receive stream open bit"
151+
)
152+
153+
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
154+
if stream:
155+
stream.close()
156+
async with self._stream_lock:
157+
del self._streams[msg.streamId]
158+
except IgnoreMessageException:
159+
logger.debug("Ignoring transport message", exc_info=True)
160+
continue
161+
except OutOfOrderMessageException:
162+
logger.exception("Out of order message, closing connection")
163+
await ws_wrapper.close()
164+
return
165+
except InvalidMessageException:
166+
logger.exception("Got invalid transport message, closing session")
167+
await self.close()
168+
return
169+
except ConnectionClosed as e:
170+
raise e
171+
36172
async def send_rpc(
37173
self,
38174
service_name: str,

src/replit_river/client_transport.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import asyncio
22
import logging
33
from collections.abc import Awaitable, Callable
4-
from typing import Generic
4+
from typing import Generic, assert_never
55

6+
import nanoid
67
import websockets
78
from pydantic import ValidationError
89
from websockets import (
@@ -36,7 +37,7 @@
3637
IgnoreMessageException,
3738
InvalidMessageException,
3839
)
39-
from replit_river.transport import Transport
40+
from replit_river.session import Session
4041
from replit_river.transport_options import (
4142
HandshakeMetadataType,
4243
TransportOptions,
@@ -46,19 +47,21 @@
4647
logger = logging.getLogger(__name__)
4748

4849

49-
class ClientTransport(Transport, Generic[HandshakeMetadataType]):
50+
class ClientTransport(Generic[HandshakeMetadataType]):
51+
_sessions: dict[str, ClientSession]
52+
5053
def __init__(
5154
self,
5255
uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]],
5356
client_id: str,
5457
server_id: str,
5558
transport_options: TransportOptions,
5659
):
57-
super().__init__(
58-
transport_id=client_id,
59-
transport_options=transport_options,
60-
is_server=False,
61-
)
60+
self._sessions = {}
61+
self._transport_id = client_id
62+
self._transport_options = transport_options
63+
self._session_lock = asyncio.Lock()
64+
6265
self._uri_and_metadata_factory = uri_and_metadata_factory
6366
self._client_id = client_id
6467
self._server_id = server_id
@@ -68,6 +71,24 @@ def __init__(
6871
# We want to make sure there's only one session creation at a time
6972
self._create_session_lock = asyncio.Lock()
7073

74+
async def _close_all_sessions(self) -> None:
75+
sessions = self._sessions.values()
76+
logger.info(
77+
f"start closing sessions {self._transport_id}, number sessions : "
78+
f"{len(sessions)}"
79+
)
80+
sessions_to_close = list(sessions)
81+
82+
# closing sessions requires access to the session lock, so we need to close
83+
# them one by one to be safe
84+
for session in sessions_to_close:
85+
await session.close()
86+
87+
logger.info(f"Transport closed {self._transport_id}")
88+
89+
def generate_nanoid(self) -> str:
90+
return str(nanoid.generate())
91+
7192
async def close(self) -> None:
7293
self._rate_limiter.close()
7394
await self._close_all_sessions()
@@ -201,13 +222,11 @@ async def _create_new_session(
201222
session_id=hs_request.sessionId,
202223
websocket=new_ws,
203224
transport_options=self._transport_options,
204-
is_server=False,
205225
close_session_callback=self._delete_session,
206226
retry_connection_callback=self._retry_connection,
207-
handlers={},
208227
)
209228

210-
self._set_session(new_session)
229+
self._sessions[new_session._to_id] = new_session
211230
await new_session.start_serve_responses()
212231
return new_session
213232

@@ -297,24 +316,27 @@ async def _establish_handshake(
297316
ControlMessageHandshakeResponse,
298317
]:
299318
try:
319+
expectedSessionState: ExpectedSessionState
320+
match old_session:
321+
case None:
322+
expectedSessionState = ExpectedSessionState(
323+
nextExpectedSeq=0,
324+
nextSentSeq=0,
325+
)
326+
case ClientSession():
327+
expectedSessionState = ExpectedSessionState(
328+
nextExpectedSeq=await old_session.get_next_expected_seq(),
329+
nextSentSeq=await old_session.get_next_sent_seq(),
330+
)
331+
case other:
332+
assert_never(other)
300333
handshake_request = await self._send_handshake_request(
301334
transport_id=transport_id,
302335
to_id=to_id,
303336
session_id=session_id,
304337
handshake_metadata=handshake_metadata,
305338
websocket=websocket,
306-
expected_session_state=ExpectedSessionState(
307-
nextExpectedSeq=(
308-
await old_session.get_next_expected_seq()
309-
if old_session is not None
310-
else 0
311-
),
312-
nextSentSeq=(
313-
await old_session.get_next_sent_seq()
314-
if old_session is not None
315-
else 0
316-
),
317-
),
339+
expected_session_state=expectedSessionState,
318340
)
319341
except FailedSendingMessageException as e:
320342
raise RiverException(
@@ -352,3 +374,8 @@ async def _establish_handshake(
352374
+ f"{handshake_response.status.reason}",
353375
)
354376
return handshake_request, handshake_response
377+
378+
async def _delete_session(self, session: Session) -> None:
379+
async with self._session_lock:
380+
if session._to_id in self._sessions:
381+
del self._sessions[session._to_id]

src/replit_river/codegen/client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767

6868
FILE_HEADER = dedent(
6969
"""\
70-
# ruff: noqa
7170
# Code generated by river.codegen. DO NOT EDIT.
7271
from collections.abc import AsyncIterable, AsyncIterator
7372
import datetime

src/replit_river/codegen/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def add_{service.name}Servicer_to_server(
342342
) -> None:
343343
rpc_method_handlers: Mapping[
344344
tuple[str, str],
345-
tuple[str, river.GenericRpcHandler]
345+
tuple[str, river.GenericRpcHandlerBuilder]
346346
] = {{
347347
"""
348348
),

0 commit comments

Comments
 (0)