Skip to content

Commit 379084a

Browse files
Fix bugs
1 parent 0374959 commit 379084a

File tree

4 files changed

+42
-16
lines changed

4 files changed

+42
-16
lines changed

src/replit_river/common_session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ async def check_to_close_session(
134134

135135

136136
async def buffered_message_sender(
137+
connection_condition: asyncio.Condition,
137138
message_enqueued: asyncio.Semaphore,
138139
get_ws: Callable[[], WebSocketCommonProtocol | ClientConnection | None],
139140
websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]],
@@ -143,16 +144,19 @@ async def buffered_message_sender(
143144
logger.debug("Entering buffered_message_sender")
144145
while True:
145146
await message_enqueued.acquire()
147+
while (ws := get_ws()) is None:
148+
# Block until we have a handle
149+
logger.debug("buffered_message_sender: connection_condition.acquire() %r %r %r", ws, get_ws(), connection_condition)
150+
async with connection_condition:
151+
await connection_condition.wait()
152+
logger.debug("buffered_message_sender: connection_condition UNLOCKED")
146153
logger.debug("buffered_message_sender: acquired")
147154
if msg := get_next_pending():
148-
ws = get_ws()
149155
logger.debug(
150156
"buffered_message_sender: Dequeued %r to send over %r",
151157
msg,
152158
ws,
153159
)
154-
if not ws:
155-
break
156160
try:
157161
logger.debug("buffered_message_sender: Sending...")
158162
await send_transport_message(msg, ws, websocket_closed_callback)
@@ -175,6 +179,7 @@ async def buffered_message_sender(
175179
except Exception:
176180
logger.exception("Error attempting to send buffered messages")
177181
break
182+
print("buffered_message_sender exit")
178183

179184

180185
async def add_msg_to_stream(

src/replit_river/rate_limiter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def has_budget_or_throw(
9191
Returns:
9292
bool: True if budget is available, False otherwise.
9393
"""
94-
if self.get_budget_consumed(user) < self.options.attempt_budget_capacity:
94+
logger.debug("self.get_budget_consumed(user)=%r < self.options.attempt_budget_capacity=%r", self.get_budget_consumed(user), self.options.attempt_budget_capacity)
95+
if self.get_budget_consumed(user) > self.options.attempt_budget_capacity:
9596
logger.debug("No retry budget for %s.", user)
9697
raise BudgetExhaustedException(
9798
error_code,

src/replit_river/v2/client_transport.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ async def get_or_create_session(self) -> Session:
6060
call ensure_connected on whatever session is active.
6161
"""
6262
existing_session = self._session
63+
logger.debug(f"if not existing_session={existing_session} or existing_session.is_closed()={existing_session and existing_session.is_closed()}:")
6364
if not existing_session or existing_session.is_closed():
6465
logger.info("Creating new session")
6566
new_session = Session(

src/replit_river/v2/session.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class Session:
100100
_close_session_callback: CloseSessionCallback
101101
_close_session_after_time_secs: float | None
102102
_connecting_task: asyncio.Task[Literal[True]] | None
103+
_connection_condition: asyncio.Condition
103104

104105
# ws state
105106
_ws_unwrapped: ClientConnection | None
@@ -135,6 +136,7 @@ def __init__(
135136
self._close_session_callback = close_session_callback
136137
self._close_session_after_time_secs: float | None = None
137138
self._connecting_task = None
139+
self._connection_condition = asyncio.Condition()
138140

139141
# ws state
140142
self._ws_unwrapped = None
@@ -162,11 +164,13 @@ async def do_close_websocket() -> None:
162164
self._state,
163165
self._ws_unwrapped,
164166
)
165-
self._state = SessionState.CLOSING
166167
if self._ws_unwrapped:
167168
self._task_manager.create_task(self._ws_unwrapped.close())
168169
if self._retry_connection_callback:
169170
self._task_manager.create_task(self._retry_connection_callback())
171+
self._ws_unwrapped = None
172+
else:
173+
self._state = SessionState.CLOSING
170174
await self._begin_close_session_countdown()
171175

172176
def increment_and_get_heartbeat_misses() -> int:
@@ -211,14 +215,18 @@ def get_next_pending() -> TransportMessage | None:
211215
return self._send_buffer[0]
212216
return None
213217

218+
# TODO: Just return _ws_unwrapped once we are no longer using the legacy client
219+
def get_ws() -> WebSocketCommonProtocol | ClientConnection | None:
220+
logger.debug("get_ws: %r %r", self.is_connected(), self._ws_unwrapped)
221+
if self.is_connected():
222+
return self._ws_unwrapped
223+
return None
224+
214225
self._task_manager.create_task(
215226
buffered_message_sender(
227+
self._connection_condition,
216228
self._message_enqueued,
217-
get_ws=lambda: (
218-
cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped)
219-
if self.is_connected()
220-
else None
221-
),
229+
get_ws=get_ws,
222230
websocket_closed_callback=self._begin_close_session_countdown,
223231
get_next_pending=get_next_pending,
224232
commit=commit,
@@ -242,6 +250,7 @@ async def ensure_connected[HandshakeMetadata](
242250
logic that actually establishes the connection.
243251
"""
244252

253+
logger.debug("ensure_connected: %r", self.is_connected())
245254
if self.is_connected():
246255
return
247256

@@ -255,7 +264,9 @@ async def ensure_connected[HandshakeMetadata](
255264
)
256265
)
257266

267+
logger.debug("BEFORE await _do_ensure_connected")
258268
await self._connecting_task
269+
logger.debug("AFTER await _do_ensure_connected")
259270

260271
async def _do_ensure_connected[HandshakeMetadata](
261272
self,
@@ -271,6 +282,7 @@ async def _do_ensure_connected[HandshakeMetadata](
271282

272283
last_error: Exception | None = None
273284
i = 0
285+
await self._connection_condition.acquire()
274286
while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error):
275287
if i > 0:
276288
logger.info(f"Retrying build handshake number {i} times")
@@ -378,6 +390,11 @@ async def websocket_closed_callback() -> None:
378390
last_error = None
379391
rate_limiter.start_restoring_budget(client_id)
380392
self._state = SessionState.ACTIVE
393+
self._ws_unwrapped = ws
394+
logger.debug("Before notify_all: %r %r %r", self._state, self._ws_unwrapped, self._connection_condition)
395+
self._connection_condition.notify_all()
396+
self._connection_condition.release()
397+
break
381398
except RiverException as e:
382399
await ws.close()
383400
raise e
@@ -411,6 +428,7 @@ async def websocket_closed_callback() -> None:
411428
f"Failed to create ws after retrying {max_retry} number of times",
412429
) from last_error
413430

431+
logger.debug("EXITING _do_ensure_connected")
414432
return True
415433

416434
def is_closed(self) -> bool:
@@ -419,7 +437,7 @@ def is_closed(self) -> bool:
419437
Do not send messages, do not expect any more messages to be emitted,
420438
the state is expected to be stale.
421439
"""
422-
return self._state not in TerminalStates
440+
return self._state in TerminalStates
423441

424442
def is_connected(self) -> bool:
425443
return self._state == SessionState.ACTIVE
@@ -477,6 +495,7 @@ async def send_message(
477495
serviceName=service_name,
478496
procedureName=procedure_name,
479497
)
498+
logger.debug("SENDING MESSAGE: %r", msg)
480499

481500
if span:
482501
with use_span(span):
@@ -516,17 +535,17 @@ async def close(self) -> None:
516535
self._reset_session_close_countdown()
517536
await self._task_manager.cancel_all_tasks()
518537

519-
if self._ws_unwrapped:
520-
# The Session isn't guaranteed to live much longer than this close()
521-
# invocation, so let's await this close to avoid dropping the socket.
522-
await self._ws_unwrapped.close()
523-
524538
# TODO: unexpected_close should close stream differently here to
525539
# throw exception correctly.
526540
for stream in self._streams.values():
527541
stream.close()
528542
self._streams.clear()
529543

544+
if self._ws_unwrapped:
545+
# The Session isn't guaranteed to live much longer than this close()
546+
# invocation, so let's await this close to avoid dropping the socket.
547+
await self._ws_unwrapped.close()
548+
530549
self._state = SessionState.CLOSED
531550

532551
# Clear the session in transports

0 commit comments

Comments
 (0)