22import logging
33from collections .abc import AsyncIterable
44from datetime import timedelta
5- from typing import Any , AsyncGenerator , Callable , Coroutine , Literal
5+ from typing import Any , AsyncGenerator , Callable , Literal
66
7- import nanoid # type: ignore
7+ import nanoid
88import websockets
99from aiochannel import Channel
1010from aiochannel .errors import ChannelClosed
1111from opentelemetry .trace import Span
1212from websockets .exceptions import ConnectionClosed
13+ from websockets .frames import CloseCode
1314
14- from replit_river .common_session import add_msg_to_stream
15+ from replit_river .common_session import buffered_message_sender
1516from replit_river .error_schema import (
1617 ERROR_CODE_CANCEL ,
1718 ERROR_CODE_STREAM_CLOSED ,
2829from replit_river .rpc import (
2930 ACK_BIT ,
3031 STREAM_OPEN_BIT ,
32+ TransportMessage ,
3133)
3234from replit_river .seq_manager import (
3335 IgnoreMessageException ,
3436 InvalidMessageException ,
3537 OutOfOrderMessageException ,
3638)
37- from replit_river .session import CloseSessionCallback , RetryConnectionCallback , Session
3839from replit_river .transport_options import MAX_MESSAGE_BUFFER_SIZE , TransportOptions
40+ from replit_river .v2 .session import (
41+ CloseSessionCallback ,
42+ RetryConnectionCallback ,
43+ Session ,
44+ )
3945
4046STREAM_CANCEL_BIT_TYPE = Literal [0b00100 ]
4147STREAM_CANCEL_BIT : STREAM_CANCEL_BIT_TYPE = 0b00100
@@ -68,18 +74,42 @@ def __init__(
6874 )
6975
7076 async def do_close_websocket () -> None :
71- await self .close_websocket (
72- self ._ws_wrapper ,
73- should_retry = True ,
74- )
77+ if self ._ws_unwrapped :
78+ self ._task_manager . create_task ( self . _ws_unwrapped . close ())
79+ if self . _retry_connection_callback :
80+ self . _task_manager . create_task ( self . _retry_connection_callback () )
7581 await self ._begin_close_session_countdown ()
7682
7783 self ._setup_heartbeats_task (do_close_websocket )
7884
85+ def commit (msg : TransportMessage ) -> None :
86+ pending = self ._send_buffer .popleft ()
87+ if msg .seq != pending .seq :
88+ logger .error ("Out of sequence error" )
89+ self ._ack_buffer .append (pending )
90+
91+ # On commit, release pending writers waiting for more buffer space
92+ if self ._queue_full_lock .locked ():
93+ self ._queue_full_lock .release ()
94+
95+ def get_next_pending () -> TransportMessage | None :
96+ if self ._send_buffer :
97+ return self ._send_buffer [0 ]
98+ return None
99+
100+ self ._task_manager .create_task (
101+ buffered_message_sender (
102+ get_ws = lambda : self ._ws_unwrapped ,
103+ websocket_closed_callback = self ._begin_close_session_countdown ,
104+ get_next_pending = get_next_pending ,
105+ commit = commit ,
106+ )
107+ )
108+
79109 async def start_serve_responses (self ) -> None :
80- self ._task_manager .create_task (self .serve ())
110+ self ._task_manager .create_task (self ._serve ())
81111
82- async def serve (self ) -> None :
112+ async def _serve (self ) -> None :
83113 """Serve messages from the websocket."""
84114 self ._reset_session_close_countdown ()
85115 try :
@@ -106,64 +136,95 @@ async def serve(self) -> None:
106136 )
107137
108138 async def _handle_messages_from_ws (self ) -> None :
139+ while self ._ws_unwrapped is None :
140+ await asyncio .sleep (1 )
109141 logger .debug (
110142 "%s start handling messages from ws %s" ,
111143 "client" ,
112- self ._ws_wrapper .id ,
144+ self ._ws_unwrapped .id ,
113145 )
114146 try :
115- ws_wrapper = self ._ws_wrapper
116- async for message in ws_wrapper . ws :
147+ ws = self ._ws_unwrapped
148+ async for message in ws :
117149 try :
118- if not await ws_wrapper . is_open () :
150+ if not self . _ws_unwrapped :
119151 # We should not process messages if the websocket is closed.
120152 break
121153 msg = parse_transport_msg (message , self ._transport_options )
122154
123155 logger .debug (f"{ self ._transport_id } got a message %r" , msg )
124156
125157 # Update bookkeeping
126- await self ._seq_manager .check_seq_and_update (msg )
127- await self ._buffer .remove_old_messages (
128- self ._seq_manager .receiver_ack ,
129- )
158+ if msg .seq < self .ack :
159+ raise IgnoreMessageException (
160+ f"{ msg .from_ } received duplicate msg, got { msg .seq } "
161+ f" expected { self .ack } "
162+ )
163+ elif msg .seq > self .ack :
164+ logger .warning (
165+ f"Out of order message received got { msg .seq } expected "
166+ f"{ self .ack } "
167+ )
168+
169+ raise OutOfOrderMessageException (
170+ f"Out of order message received got { msg .seq } expected "
171+ f"{ self .ack } "
172+ )
173+
174+ assert msg .seq == self .ack , "Safety net, redundant assertion"
175+
176+ # Set our next expected ack number
177+ self .ack = msg .seq + 1
178+
179+ # Discard old messages from the buffer
180+ while self ._ack_buffer and self ._ack_buffer [0 ].seq < msg .ack :
181+ self ._ack_buffer .popleft ()
182+
130183 self ._reset_session_close_countdown ()
131184
132185 if msg .controlFlags & ACK_BIT != 0 :
133186 continue
134- async with self ._stream_lock :
135- stream = self ._streams .get (msg .streamId , None )
136- if msg .controlFlags & STREAM_OPEN_BIT == 0 :
137- if not stream :
138- logger .warning ("no stream for %s" , msg .streamId )
139- raise IgnoreMessageException (
140- "no stream for message, ignoring"
141- )
142-
143- if (
144- msg .controlFlags & STREAM_CLOSED_BIT != 0
145- and msg .payload .get ("type" , None ) == "CLOSE"
146- ):
147- # close message is not sent to the stream
148- pass
149- else :
150- await add_msg_to_stream (msg , stream )
151- else :
187+ stream = self ._streams .get (msg .streamId , None )
188+ if msg .controlFlags & STREAM_OPEN_BIT != 0 :
152189 raise InvalidMessageException (
153190 "Client should not receive stream open bit"
154191 )
155192
193+ if not stream :
194+ logger .warning ("no stream for %s" , msg .streamId )
195+ raise IgnoreMessageException ("no stream for message, ignoring" )
196+
197+ if (
198+ msg .controlFlags & STREAM_CLOSED_BIT != 0
199+ and msg .payload .get ("type" , None ) == "CLOSE"
200+ ):
201+ # close message is not sent to the stream
202+ pass
203+ else :
204+ try :
205+ await stream .put (msg .payload )
206+ except ChannelClosed :
207+ # The client is no longer interested in this stream,
208+ # just drop the message.
209+ pass
210+ except RuntimeError as e :
211+ raise InvalidMessageException (e ) from e
212+
156213 if msg .controlFlags & STREAM_CLOSED_BIT != 0 :
157214 if stream :
158215 stream .close ()
159- async with self ._stream_lock :
160- del self ._streams [msg .streamId ]
216+ del self ._streams [msg .streamId ]
161217 except IgnoreMessageException :
162218 logger .debug ("Ignoring transport message" , exc_info = True )
163219 continue
164220 except OutOfOrderMessageException :
165221 logger .exception ("Out of order message, closing connection" )
166- await ws_wrapper .close ()
222+ self ._task_manager .create_task (
223+ self ._ws_unwrapped .close (
224+ code = CloseCode .INVALID_DATA ,
225+ reason = "Out of order message" ,
226+ )
227+ )
167228 return
168229 except InvalidMessageException :
169230 logger .exception ("Got invalid transport message, closing session" )
0 commit comments