Skip to content

Commit 4a0b0c6

Browse files
WIP
1 parent 6e44517 commit 4a0b0c6

File tree

9 files changed

+489
-122
lines changed

9 files changed

+489
-122
lines changed

scripts/lint/src/lint/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ def raise_err(code: int) -> None:
1111

1212
def main() -> None:
1313
fix = ["--fix"] if "--fix" in sys.argv else []
14+
watch = ["--watch"] if "--watch" in sys.argv else []
1415
raise_err(os.system(" ".join(["ruff", "check", "src", "scripts", "tests"] + fix)))
1516
raise_err(os.system("ruff format src scripts tests"))
1617
raise_err(os.system("mypy src"))
17-
raise_err(os.system("pyright src"))
18+
raise_err(os.system(" ".join(["pyright"] + watch + ["src"])))

src/replit_river/common_session.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import asyncio
22
import enum
33
import logging
4-
from typing import Any, Awaitable, Callable, Protocol
4+
from typing import Any, Awaitable, Callable, Coroutine, Protocol
55

66
from aiochannel import Channel, ChannelClosed
77
from opentelemetry.trace import Span
8+
from websockets import WebSocketCommonProtocol
89

9-
from replit_river.messages import FailedSendingMessageException
10+
from replit_river.messages import (
11+
FailedSendingMessageException,
12+
WebsocketClosedException,
13+
send_transport_message,
14+
)
1015
from replit_river.rpc import ACK_BIT, TransportMessage
1116
from replit_river.seq_manager import InvalidMessageException
1217

@@ -66,6 +71,7 @@ async def setup_heartbeat(
6671
# TODO: make this a message class
6772
# https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
6873
payload={
74+
"type": "ACK",
6975
"ack": 0,
7076
},
7177
control_flags=ACK_BIT,
@@ -115,6 +121,40 @@ async def check_to_close_session(
115121
return
116122

117123

124+
async def buffered_message_sender(
125+
get_ws: Callable[[], WebSocketCommonProtocol | None],
126+
websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]],
127+
get_next_pending: Callable[[], TransportMessage | None],
128+
commit: Callable[[TransportMessage], None],
129+
) -> None:
130+
while True:
131+
while msg := get_next_pending():
132+
ws = get_ws()
133+
if not ws:
134+
break
135+
try:
136+
await send_transport_message(msg, ws, websocket_closed_callback)
137+
commit(msg)
138+
except WebsocketClosedException as e:
139+
logger.debug(
140+
"Connection closed while sending message %r, waiting for "
141+
"retry from buffer",
142+
type(e),
143+
exc_info=e,
144+
)
145+
break
146+
except FailedSendingMessageException:
147+
logger.error(
148+
"Failed sending message, waiting for retry from buffer",
149+
exc_info=True,
150+
)
151+
break
152+
except Exception:
153+
logger.exception("Error attempting to send buffered messages")
154+
break
155+
await asyncio.sleep(0.25)
156+
157+
118158
async def add_msg_to_stream(
119159
msg: TransportMessage,
120160
stream: Channel[Any],

src/replit_river/session.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,10 @@
3737
trace_setter = TransportMessageTracingSetter()
3838

3939
CloseSessionCallback: TypeAlias = Callable[["Session"], Coroutine[Any, Any, Any]]
40-
RetryConnectionCallback: TypeAlias = (
41-
Callable[
42-
[],
43-
Coroutine[Any, Any, Any],
44-
]
45-
)
40+
RetryConnectionCallback: TypeAlias = Callable[
41+
[],
42+
Coroutine[Any, Any, Any],
43+
]
4644

4745

4846
class Session:
@@ -75,7 +73,6 @@ class Session:
7573
_buffer: MessageBuffer
7674
_task_manager: BackgroundTaskManager
7775

78-
7976
def __init__(
8077
self,
8178
transport_id: str,

src/replit_river/v2/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from replit_river.v2.session import Session
2+
13
from .client import Client
24

35
__all__ = [
46
"Client",
7+
"Session",
58
]

src/replit_river/v2/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,27 @@ def translate_unknown_error(
5858
return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error")
5959

6060

61+
# Client[HandshakeSchema](
62+
# uri_and_metadata_factory=uri_and_metadata_factory,
63+
# client_id=self.client_id,
64+
# server_id="SERVER",
65+
# transport_options=TransportOptions(
66+
# session_disconnect_grace_ms=settings.RIVER_SESSION_DISCONNECT_GRACE_MS,
67+
# heartbeat_ms=settings.RIVER_HEARTBEAT_MS,
68+
# heartbeats_until_dead=settings.RIVER_HEARTBEATS_UNTIL_DEAD,
69+
# connection_retry_options=ConnectionRetryOptions(
70+
# base_interval_ms=settings.RIVER_CONNECTION_BASE_INTERVAL_MS,
71+
# max_jitter_ms=settings.RIVER_CONNECTION_MAX_JITTER_MS,
72+
# max_backoff_ms=settings.RIVER_CONNECTION_MAX_BACKOFF_MS,
73+
# attempt_budget_capacity=self.attempt_budget_capacity,
74+
# budget_restore_interval_ms=
75+
# settings.RIVER_CONNECTION_BUDGET_RESTORE_INTERVAL_MS,
76+
# max_retry=self.max_retry_count,
77+
# ),
78+
# ),
79+
# )
80+
81+
6182
class Client(Generic[HandshakeMetadataType]):
6283
def __init__(
6384
self,

src/replit_river/v2/client_session.py

Lines changed: 100 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22
import logging
33
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, AsyncGenerator, Callable, Coroutine, Literal
5+
from typing import Any, AsyncGenerator, Callable, Literal
66

7-
import nanoid # type: ignore
7+
import nanoid
88
import websockets
99
from aiochannel import Channel
1010
from aiochannel.errors import ChannelClosed
1111
from opentelemetry.trace import Span
1212
from websockets.exceptions import ConnectionClosed
13+
from websockets.frames import CloseCode
1314

14-
from replit_river.common_session import add_msg_to_stream
15+
from replit_river.common_session import buffered_message_sender
1516
from replit_river.error_schema import (
1617
ERROR_CODE_CANCEL,
1718
ERROR_CODE_STREAM_CLOSED,
@@ -28,14 +29,19 @@
2829
from replit_river.rpc import (
2930
ACK_BIT,
3031
STREAM_OPEN_BIT,
32+
TransportMessage,
3133
)
3234
from replit_river.seq_manager import (
3335
IgnoreMessageException,
3436
InvalidMessageException,
3537
OutOfOrderMessageException,
3638
)
37-
from replit_river.session import CloseSessionCallback, RetryConnectionCallback, Session
3839
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions
40+
from replit_river.v2.session import (
41+
CloseSessionCallback,
42+
RetryConnectionCallback,
43+
Session,
44+
)
3945

4046
STREAM_CANCEL_BIT_TYPE = Literal[0b00100]
4147
STREAM_CANCEL_BIT: STREAM_CANCEL_BIT_TYPE = 0b00100
@@ -68,18 +74,42 @@ def __init__(
6874
)
6975

7076
async def do_close_websocket() -> None:
71-
await self.close_websocket(
72-
self._ws_wrapper,
73-
should_retry=True,
74-
)
77+
if self._ws_unwrapped:
78+
self._task_manager.create_task(self._ws_unwrapped.close())
79+
if self._retry_connection_callback:
80+
self._task_manager.create_task(self._retry_connection_callback())
7581
await self._begin_close_session_countdown()
7682

7783
self._setup_heartbeats_task(do_close_websocket)
7884

85+
def commit(msg: TransportMessage) -> None:
86+
pending = self._send_buffer.popleft()
87+
if msg.seq != pending.seq:
88+
logger.error("Out of sequence error")
89+
self._ack_buffer.append(pending)
90+
91+
# On commit, release pending writers waiting for more buffer space
92+
if self._queue_full_lock.locked():
93+
self._queue_full_lock.release()
94+
95+
def get_next_pending() -> TransportMessage | None:
96+
if self._send_buffer:
97+
return self._send_buffer[0]
98+
return None
99+
100+
self._task_manager.create_task(
101+
buffered_message_sender(
102+
get_ws=lambda: self._ws_unwrapped,
103+
websocket_closed_callback=self._begin_close_session_countdown,
104+
get_next_pending=get_next_pending,
105+
commit=commit,
106+
)
107+
)
108+
79109
async def start_serve_responses(self) -> None:
80-
self._task_manager.create_task(self.serve())
110+
self._task_manager.create_task(self._serve())
81111

82-
async def serve(self) -> None:
112+
async def _serve(self) -> None:
83113
"""Serve messages from the websocket."""
84114
self._reset_session_close_countdown()
85115
try:
@@ -106,64 +136,95 @@ async def serve(self) -> None:
106136
)
107137

108138
async def _handle_messages_from_ws(self) -> None:
139+
while self._ws_unwrapped is None:
140+
await asyncio.sleep(1)
109141
logger.debug(
110142
"%s start handling messages from ws %s",
111143
"client",
112-
self._ws_wrapper.id,
144+
self._ws_unwrapped.id,
113145
)
114146
try:
115-
ws_wrapper = self._ws_wrapper
116-
async for message in ws_wrapper.ws:
147+
ws = self._ws_unwrapped
148+
async for message in ws:
117149
try:
118-
if not await ws_wrapper.is_open():
150+
if not self._ws_unwrapped:
119151
# We should not process messages if the websocket is closed.
120152
break
121153
msg = parse_transport_msg(message, self._transport_options)
122154

123155
logger.debug(f"{self._transport_id} got a message %r", msg)
124156

125157
# Update bookkeeping
126-
await self._seq_manager.check_seq_and_update(msg)
127-
await self._buffer.remove_old_messages(
128-
self._seq_manager.receiver_ack,
129-
)
158+
if msg.seq < self.ack:
159+
raise IgnoreMessageException(
160+
f"{msg.from_} received duplicate msg, got {msg.seq}"
161+
f" expected {self.ack}"
162+
)
163+
elif msg.seq > self.ack:
164+
logger.warning(
165+
f"Out of order message received got {msg.seq} expected "
166+
f"{self.ack}"
167+
)
168+
169+
raise OutOfOrderMessageException(
170+
f"Out of order message received got {msg.seq} expected "
171+
f"{self.ack}"
172+
)
173+
174+
assert msg.seq == self.ack, "Safety net, redundant assertion"
175+
176+
# Set our next expected ack number
177+
self.ack = msg.seq + 1
178+
179+
# Discard old messages from the buffer
180+
while self._ack_buffer and self._ack_buffer[0].seq < msg.ack:
181+
self._ack_buffer.popleft()
182+
130183
self._reset_session_close_countdown()
131184

132185
if msg.controlFlags & ACK_BIT != 0:
133186
continue
134-
async with self._stream_lock:
135-
stream = self._streams.get(msg.streamId, None)
136-
if msg.controlFlags & STREAM_OPEN_BIT == 0:
137-
if not stream:
138-
logger.warning("no stream for %s", msg.streamId)
139-
raise IgnoreMessageException(
140-
"no stream for message, ignoring"
141-
)
142-
143-
if (
144-
msg.controlFlags & STREAM_CLOSED_BIT != 0
145-
and msg.payload.get("type", None) == "CLOSE"
146-
):
147-
# close message is not sent to the stream
148-
pass
149-
else:
150-
await add_msg_to_stream(msg, stream)
151-
else:
187+
stream = self._streams.get(msg.streamId, None)
188+
if msg.controlFlags & STREAM_OPEN_BIT != 0:
152189
raise InvalidMessageException(
153190
"Client should not receive stream open bit"
154191
)
155192

193+
if not stream:
194+
logger.warning("no stream for %s", msg.streamId)
195+
raise IgnoreMessageException("no stream for message, ignoring")
196+
197+
if (
198+
msg.controlFlags & STREAM_CLOSED_BIT != 0
199+
and msg.payload.get("type", None) == "CLOSE"
200+
):
201+
# close message is not sent to the stream
202+
pass
203+
else:
204+
try:
205+
await stream.put(msg.payload)
206+
except ChannelClosed:
207+
# The client is no longer interested in this stream,
208+
# just drop the message.
209+
pass
210+
except RuntimeError as e:
211+
raise InvalidMessageException(e) from e
212+
156213
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
157214
if stream:
158215
stream.close()
159-
async with self._stream_lock:
160-
del self._streams[msg.streamId]
216+
del self._streams[msg.streamId]
161217
except IgnoreMessageException:
162218
logger.debug("Ignoring transport message", exc_info=True)
163219
continue
164220
except OutOfOrderMessageException:
165221
logger.exception("Out of order message, closing connection")
166-
await ws_wrapper.close()
222+
self._task_manager.create_task(
223+
self._ws_unwrapped.close(
224+
code=CloseCode.INVALID_DATA,
225+
reason="Out of order message",
226+
)
227+
)
167228
return
168229
except InvalidMessageException:
169230
logger.exception("Got invalid transport message, closing session")

0 commit comments

Comments
 (0)