Skip to content

Commit 58e6eab

Browse files
Avoid exceptions for flow control
1 parent 74828c1 commit 58e6eab

File tree

7 files changed

+67
-49
lines changed

7 files changed

+67
-49
lines changed

src/replit_river/client_session.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, AsyncGenerator, Callable, Coroutine
5+
from typing import Any, AsyncGenerator, Callable, Coroutine, assert_never
66

77
import nanoid
88
import websockets
@@ -24,7 +24,7 @@
2424
parse_transport_msg,
2525
)
2626
from replit_river.seq_manager import (
27-
IgnoreMessageException,
27+
IgnoreMessage,
2828
InvalidMessageException,
2929
OutOfOrderMessageException,
3030
)
@@ -125,11 +125,21 @@ async def _handle_messages_from_ws(self) -> None:
125125
# We should not process messages if the websocket is closed.
126126
break
127127
msg = parse_transport_msg(message)
128+
if isinstance(msg, str):
129+
logger.debug("Ignoring transport message", exc_info=True)
130+
continue
128131

129132
logger.debug(f"{self._transport_id} got a message %r", msg)
130133

131134
# Update bookkeeping
132-
await self._seq_manager.check_seq_and_update(msg)
135+
match await self._seq_manager.check_seq_and_update(msg):
136+
case IgnoreMessage():
137+
continue
138+
case None:
139+
pass
140+
case other:
141+
assert_never(other)
142+
133143
await self._buffer.remove_old_messages(
134144
self._seq_manager.receiver_ack,
135145
)
@@ -142,9 +152,7 @@ async def _handle_messages_from_ws(self) -> None:
142152
if msg.controlFlags & STREAM_OPEN_BIT == 0:
143153
if not stream:
144154
logger.warning("no stream for %s", msg.streamId)
145-
raise IgnoreMessageException(
146-
"no stream for message, ignoring"
147-
)
155+
continue
148156

149157
if (
150158
msg.controlFlags & STREAM_CLOSED_BIT != 0
@@ -171,9 +179,6 @@ async def _handle_messages_from_ws(self) -> None:
171179
stream.close()
172180
async with self._stream_lock:
173181
del self._streams[msg.streamId]
174-
except IgnoreMessageException:
175-
logger.debug("Ignoring transport message", exc_info=True)
176-
continue
177182
except OutOfOrderMessageException:
178183
logger.exception("Out of order message, closing connection")
179184
await ws_wrapper.close()

src/replit_river/client_transport.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
TransportMessage,
3434
)
3535
from replit_river.seq_manager import (
36-
IgnoreMessageException,
3736
InvalidMessageException,
3837
)
3938
from replit_river.session import Session
@@ -296,10 +295,11 @@ async def _get_handshake_response_msg(
296295
"Handshake failed, conn closed while waiting for response",
297296
) from e
298297
try:
299-
return parse_transport_msg(data)
300-
except IgnoreMessageException:
301-
logger.debug("Ignoring transport message", exc_info=True)
302-
continue
298+
msg = parse_transport_msg(data)
299+
if isinstance(msg, str):
300+
logger.debug("Ignoring transport message", exc_info=True)
301+
continue
302+
return msg
303303
except InvalidMessageException as e:
304304
raise RiverException(
305305
ERROR_HANDSHAKE,

src/replit_river/messages.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
TransportMessage,
1515
)
1616
from replit_river.seq_manager import (
17-
IgnoreMessageException,
1817
InvalidMessageException,
1918
)
2019

@@ -59,11 +58,9 @@ def formatted_bytes(message: bytes) -> str:
5958
return " ".join(f"{b:02x}" for b in message)
6059

6160

62-
def parse_transport_msg(message: str | bytes) -> TransportMessage:
61+
def parse_transport_msg(message: str | bytes) -> TransportMessage | str:
6362
if isinstance(message, str):
64-
raise IgnoreMessageException(
65-
f"ignored a message beacuse it was a text frame: {message}"
66-
)
63+
return message
6764
try:
6865
# :param int timestamp:
6966
# Control how timestamp type is unpacked:

src/replit_river/seq_manager.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
from dataclasses import dataclass
34

45
from replit_river.rpc import TransportMessage
56

@@ -34,6 +35,11 @@ class SessionStateMismatchException(Exception):
3435
pass
3536

3637

38+
@dataclass
39+
class IgnoreMessage:
40+
pass
41+
42+
3743
class SeqManager:
3844
"""Manages the sequence number and ack number for a connection."""
3945

@@ -68,14 +74,11 @@ async def get_ack(self) -> int:
6874
async with self._ack_lock:
6975
return self.ack
7076

71-
async def check_seq_and_update(self, msg: TransportMessage) -> None:
77+
async def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None:
7278
async with self._ack_lock:
7379
if msg.seq != self.ack:
7480
if msg.seq < self.ack:
75-
raise IgnoreMessageException(
76-
f"{msg.from_} received duplicate msg, got {msg.seq}"
77-
f" expected {self.ack}"
78-
)
81+
return IgnoreMessage()
7982
else:
8083
logger.warn(
8184
f"Out of order message received got {msg.seq} expected "

src/replit_river/server_session.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Any, Callable, Coroutine
3+
from typing import Any, Callable, Coroutine, assert_never
44

55
import websockets
66
from aiochannel import Channel, ChannelClosed
@@ -12,7 +12,7 @@
1212
parse_transport_msg,
1313
)
1414
from replit_river.seq_manager import (
15-
IgnoreMessageException,
15+
IgnoreMessage,
1616
InvalidMessageException,
1717
OutOfOrderMessageException,
1818
)
@@ -122,11 +122,20 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None:
122122
# We should not process messages if the websocket is closed.
123123
break
124124
msg = parse_transport_msg(message)
125+
if isinstance(msg, str):
126+
logger.debug("Ignoring transport message", exc_info=True)
127+
continue
125128

126129
logger.debug(f"{self._transport_id} got a message %r", msg)
127130

128131
# Update bookkeeping
129-
await self._seq_manager.check_seq_and_update(msg)
132+
match self._seq_manager.check_seq_and_update(msg):
133+
case IgnoreMessage():
134+
continue
135+
case None:
136+
pass
137+
case other:
138+
assert_never(other)
130139
await self._buffer.remove_old_messages(
131140
self._seq_manager.receiver_ack,
132141
)
@@ -135,13 +144,11 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None:
135144
if msg.controlFlags & ACK_BIT != 0:
136145
continue
137146
async with self._stream_lock:
138-
stream = self._streams.get(msg.streamId, None)
147+
stream = self._streams.get(msg.streamId)
139148
if msg.controlFlags & STREAM_OPEN_BIT == 0:
140149
if not stream:
141150
logger.warning("no stream for %s", msg.streamId)
142-
raise IgnoreMessageException(
143-
"no stream for message, ignoring"
144-
)
151+
continue
145152

146153
if (
147154
msg.controlFlags & STREAM_CLOSED_BIT != 0
@@ -160,6 +167,8 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None:
160167
raise InvalidMessageException(e) from e
161168
else:
162169
_stream = await self._open_stream_and_call_handler(msg, tg)
170+
if isinstance(_stream, IgnoreMessage):
171+
continue
163172
if not stream:
164173
async with self._stream_lock:
165174
self._streams[msg.streamId] = _stream
@@ -170,9 +179,6 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None:
170179
stream.close()
171180
async with self._stream_lock:
172181
del self._streams[msg.streamId]
173-
except IgnoreMessageException:
174-
logger.debug("Ignoring transport message", exc_info=True)
175-
continue
176182
except OutOfOrderMessageException:
177183
logger.exception("Out of order message, closing connection")
178184
await ws_wrapper.close()
@@ -188,17 +194,22 @@ async def _open_stream_and_call_handler(
188194
self,
189195
msg: TransportMessage,
190196
tg: asyncio.TaskGroup | None,
191-
) -> Channel:
197+
) -> Channel | IgnoreMessage:
192198
if not msg.serviceName or not msg.procedureName:
193-
raise IgnoreMessageException(
194-
f"Service name or procedure name is missing in the message {msg}"
199+
logger.warning(
200+
"Service name or procedure name is missing in the message %r",
201+
msg,
195202
)
203+
return IgnoreMessage()
196204
key = (msg.serviceName, msg.procedureName)
197205
handler = self._handlers.get(key, None)
198206
if not handler:
199-
raise IgnoreMessageException(
200-
f"No handler for {key} handlers : {self._handlers.keys()}"
207+
logger.warning(
208+
"No handler for %r handlers: %r",
209+
key,
210+
self._handlers.keys(),
201211
)
212+
return IgnoreMessage()
202213
method_type, handler_func = handler
203214
is_streaming_output = method_type in (
204215
"subscription-stream", # subscription

src/replit_river/server_transport.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
TransportMessage,
2525
)
2626
from replit_river.seq_manager import (
27-
IgnoreMessageException,
2827
InvalidMessageException,
2928
SessionStateMismatchException,
3029
)
@@ -74,12 +73,12 @@ async def handshake_to_get_session(
7473
async for message in websocket:
7574
try:
7675
msg = parse_transport_msg(message)
76+
if isinstance(msg, str):
77+
continue
7778
(
7879
handshake_request,
7980
handshake_response,
8081
) = await self._establish_handshake(msg, websocket)
81-
except IgnoreMessageException:
82-
continue
8382
except InvalidMessageException as e:
8483
error_msg = "Got invalid transport message, closing connection"
8584
raise InvalidMessageException(error_msg) from e

src/replit_river/v2/session.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
TransportMessageTracingSetter,
6161
)
6262
from replit_river.seq_manager import (
63-
IgnoreMessageException,
6463
InvalidMessageException,
6564
OutOfOrderMessageException,
6665
)
@@ -1059,13 +1058,14 @@ async def websocket_closed_callback() -> None:
10591058

10601059
try:
10611060
response_msg = parse_transport_msg(data)
1061+
if isinstance(response_msg, str):
1062+
logger.debug(
1063+
"_do_ensure_connected: Ignoring transport message",
1064+
exc_info=True,
1065+
)
1066+
continue
1067+
10621068
break
1063-
except IgnoreMessageException:
1064-
logger.debug(
1065-
"_do_ensure_connected: Ignoring transport message",
1066-
exc_info=True,
1067-
)
1068-
continue
10691069
except InvalidMessageException as e:
10701070
raise RiverException(
10711071
ERROR_HANDSHAKE,
@@ -1217,6 +1217,9 @@ async def _serve(
12171217
transport_id,
12181218
msg,
12191219
)
1220+
if isinstance(msg, str):
1221+
logger.debug("Ignoring transport message", exc_info=True)
1222+
continue
12201223

12211224
if msg.controlFlags & STREAM_OPEN_BIT != 0:
12221225
raise InvalidMessageException(

0 commit comments

Comments
 (0)