Skip to content

Commit 6ad1adc

Browse files
Split serve() functionality between client and server
1 parent fb6c482 commit 6ad1adc

File tree

3 files changed

+273
-158
lines changed

3 files changed

+273
-158
lines changed

src/replit_river/client_session.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aiochannel import Channel
99
from aiochannel.errors import ChannelClosed
1010
from opentelemetry.trace import Span
11+
from websockets.exceptions import ConnectionClosed
1112

1213
from replit_river.error_schema import (
1314
ERROR_CODE_CANCEL,
@@ -17,22 +18,125 @@
1718
StreamClosedRiverServiceException,
1819
exception_from_message,
1920
)
21+
from replit_river.messages import (
22+
FailedSendingMessageException,
23+
parse_transport_msg,
24+
)
25+
from replit_river.seq_manager import (
26+
IgnoreMessageException,
27+
InvalidMessageException,
28+
OutOfOrderMessageException,
29+
)
2030
from replit_river.session import Session
2131
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
2232

2333
from .rpc import (
34+
ACK_BIT,
2435
STREAM_CLOSED_BIT,
2536
STREAM_OPEN_BIT,
2637
ErrorType,
2738
InitType,
2839
RequestType,
2940
ResponseType,
41+
TransportMessage,
3042
)
3143

3244
logger = logging.getLogger(__name__)
3345

3446

3547
class ClientSession(Session):
48+
async def start_serve_responses(self) -> None:
49+
self._task_manager.create_task(self.serve())
50+
51+
async def serve(self) -> None:
52+
"""Serve messages from the websocket."""
53+
self._reset_session_close_countdown()
54+
try:
55+
async with asyncio.TaskGroup() as tg:
56+
try:
57+
await self._handle_messages_from_ws(tg)
58+
except ConnectionClosed:
59+
if self._retry_connection_callback:
60+
self._task_manager.create_task(
61+
self._retry_connection_callback()
62+
)
63+
64+
await self._begin_close_session_countdown()
65+
logger.debug("ConnectionClosed while serving", exc_info=True)
66+
except FailedSendingMessageException:
67+
# Expected error if the connection is closed.
68+
logger.debug(
69+
"FailedSendingMessageException while serving", exc_info=True
70+
)
71+
except Exception:
72+
logger.exception("caught exception at message iterator")
73+
except ExceptionGroup as eg:
74+
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
75+
if unhandled:
76+
raise ExceptionGroup(
77+
"Unhandled exceptions on River server", unhandled.exceptions
78+
)
79+
80+
async def _update_book_keeping(self, msg: TransportMessage) -> None:
81+
await self._seq_manager.check_seq_and_update(msg)
82+
await self._remove_acked_messages_in_buffer()
83+
self._reset_session_close_countdown()
84+
85+
async def _handle_messages_from_ws(
86+
self, tg: asyncio.TaskGroup | None = None
87+
) -> None:
88+
logger.debug(
89+
"%s start handling messages from ws %s",
90+
"client",
91+
self._ws_wrapper.id,
92+
)
93+
try:
94+
ws_wrapper = self._ws_wrapper
95+
async for message in ws_wrapper.ws:
96+
try:
97+
if not await ws_wrapper.is_open():
98+
# We should not process messages if the websocket is closed.
99+
break
100+
msg = parse_transport_msg(message, self._transport_options)
101+
102+
logger.debug(f"{self._transport_id} got a message %r", msg)
103+
104+
await self._update_book_keeping(msg)
105+
if msg.controlFlags & ACK_BIT != 0:
106+
continue
107+
async with self._stream_lock:
108+
stream = self._streams.get(msg.streamId, None)
109+
if msg.controlFlags & STREAM_OPEN_BIT == 0:
110+
if not stream:
111+
logger.warning("no stream for %s", msg.streamId)
112+
raise IgnoreMessageException(
113+
"no stream for message, ignoring"
114+
)
115+
await self._add_msg_to_stream(msg, stream)
116+
else:
117+
raise InvalidMessageException(
118+
"Client should not receive stream open bit"
119+
)
120+
121+
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
122+
if stream:
123+
stream.close()
124+
async with self._stream_lock:
125+
del self._streams[msg.streamId]
126+
except IgnoreMessageException:
127+
logger.debug("Ignoring transport message", exc_info=True)
128+
continue
129+
except OutOfOrderMessageException:
130+
logger.exception("Out of order message, closing connection")
131+
await ws_wrapper.close()
132+
return
133+
except InvalidMessageException:
134+
logger.exception("Got invalid transport message, closing session")
135+
await self.close()
136+
return
137+
except ConnectionClosed as e:
138+
raise e
139+
36140
async def send_rpc(
37141
self,
38142
service_name: str,

src/replit_river/server_session.py

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,34 @@
1+
import asyncio
12
import logging
3+
from typing import Any
24

5+
from aiochannel import Channel, ChannelClosed
36
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
7+
from websockets.exceptions import ConnectionClosed
48

9+
from replit_river.messages import (
10+
FailedSendingMessageException,
11+
parse_transport_msg,
12+
)
13+
from replit_river.seq_manager import (
14+
IgnoreMessageException,
15+
InvalidMessageException,
16+
OutOfOrderMessageException,
17+
)
518
from replit_river.session import Session
19+
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
620

721
from .rpc import (
22+
ACK_BIT,
23+
STREAM_CLOSED_BIT,
24+
STREAM_OPEN_BIT,
25+
TransportMessage,
826
TransportMessageTracingSetter,
927
)
1028

29+
logger = logging.getLogger(__name__)
30+
31+
1132
logger = logging.getLogger(__name__)
1233

1334
trace_propagator = TraceContextTextMapPropagator()
@@ -17,4 +38,150 @@
1738
class ServerSession(Session):
1839
"""A transport object that handles the websocket connection with a client."""
1940

20-
pass
41+
async def start_serve_responses(self) -> None:
42+
self._task_manager.create_task(self.serve())
43+
44+
async def serve(self) -> None:
45+
"""Serve messages from the websocket."""
46+
self._reset_session_close_countdown()
47+
try:
48+
async with asyncio.TaskGroup() as tg:
49+
try:
50+
await self._handle_messages_from_ws(tg)
51+
except ConnectionClosed:
52+
if self._retry_connection_callback:
53+
self._task_manager.create_task(
54+
self._retry_connection_callback()
55+
)
56+
57+
await self._begin_close_session_countdown()
58+
logger.debug("ConnectionClosed while serving", exc_info=True)
59+
except FailedSendingMessageException:
60+
# Expected error if the connection is closed.
61+
logger.debug(
62+
"FailedSendingMessageException while serving", exc_info=True
63+
)
64+
except Exception:
65+
logger.exception("caught exception at message iterator")
66+
except ExceptionGroup as eg:
67+
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
68+
if unhandled:
69+
raise ExceptionGroup(
70+
"Unhandled exceptions on River server", unhandled.exceptions
71+
)
72+
73+
async def _update_book_keeping(self, msg: TransportMessage) -> None:
74+
await self._seq_manager.check_seq_and_update(msg)
75+
await self._remove_acked_messages_in_buffer()
76+
self._reset_session_close_countdown()
77+
78+
async def _handle_messages_from_ws(
79+
self, tg: asyncio.TaskGroup | None = None
80+
) -> None:
81+
logger.debug(
82+
"%s start handling messages from ws %s",
83+
"server",
84+
self._ws_wrapper.id,
85+
)
86+
try:
87+
ws_wrapper = self._ws_wrapper
88+
async for message in ws_wrapper.ws:
89+
try:
90+
if not await ws_wrapper.is_open():
91+
# We should not process messages if the websocket is closed.
92+
break
93+
msg = parse_transport_msg(message, self._transport_options)
94+
95+
logger.debug(f"{self._transport_id} got a message %r", msg)
96+
97+
await self._update_book_keeping(msg)
98+
if msg.controlFlags & ACK_BIT != 0:
99+
continue
100+
async with self._stream_lock:
101+
stream = self._streams.get(msg.streamId, None)
102+
if msg.controlFlags & STREAM_OPEN_BIT == 0:
103+
if not stream:
104+
logger.warning("no stream for %s", msg.streamId)
105+
raise IgnoreMessageException(
106+
"no stream for message, ignoring"
107+
)
108+
await self._add_msg_to_stream(msg, stream)
109+
else:
110+
# TODO(dstewart) This looks like it opens a new call to handler
111+
# on ever ws message, instead of demuxing and
112+
# routing.
113+
_stream = await self._open_stream_and_call_handler(msg, tg)
114+
if not stream:
115+
async with self._stream_lock:
116+
self._streams[msg.streamId] = _stream
117+
stream = _stream
118+
119+
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
120+
if stream:
121+
stream.close()
122+
async with self._stream_lock:
123+
del self._streams[msg.streamId]
124+
except IgnoreMessageException:
125+
logger.debug("Ignoring transport message", exc_info=True)
126+
continue
127+
except OutOfOrderMessageException:
128+
logger.exception("Out of order message, closing connection")
129+
await ws_wrapper.close()
130+
return
131+
except InvalidMessageException:
132+
logger.exception("Got invalid transport message, closing session")
133+
await self.close()
134+
return
135+
except ConnectionClosed as e:
136+
raise e
137+
138+
async def _open_stream_and_call_handler(
139+
self,
140+
msg: TransportMessage,
141+
tg: asyncio.TaskGroup | None,
142+
) -> Channel:
143+
if not msg.serviceName or not msg.procedureName:
144+
raise IgnoreMessageException(
145+
f"Service name or procedure name is missing in the message {msg}"
146+
)
147+
key = (msg.serviceName, msg.procedureName)
148+
handler = self._handlers.get(key, None)
149+
if not handler:
150+
raise IgnoreMessageException(
151+
f"No handler for {key} handlers : {self._handlers.keys()}"
152+
)
153+
method_type, handler_func = handler
154+
is_streaming_output = method_type in (
155+
"subscription-stream", # subscription
156+
"stream",
157+
)
158+
is_streaming_input = method_type in (
159+
"upload-stream", # subscription
160+
"stream",
161+
)
162+
# New channel pair.
163+
input_stream: Channel[Any] = Channel(
164+
MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1
165+
)
166+
output_stream: Channel[Any] = Channel(
167+
MAX_MESSAGE_BUFFER_SIZE if is_streaming_output else 1
168+
)
169+
if (
170+
msg.controlFlags & STREAM_CLOSED_BIT == 0
171+
or msg.payload.get("type", None) != "CLOSE"
172+
):
173+
try:
174+
await input_stream.put(msg.payload)
175+
except (RuntimeError, ChannelClosed) as e:
176+
raise InvalidMessageException(e) from e
177+
# Start the handler.
178+
self._task_manager.create_task(
179+
handler_func(msg.from_, input_stream, output_stream), tg
180+
)
181+
self._task_manager.create_task(
182+
self._send_responses_from_output_stream(
183+
msg.streamId, output_stream, is_streaming_output
184+
),
185+
tg,
186+
)
187+
return input_stream

0 commit comments

Comments
 (0)