Skip to content

Commit 4bffed8

Browse files
Prevent _handle_messages_from_ws from terminating early
1 parent 9d33ded commit 4bffed8

File tree

1 file changed

+48
-27
lines changed

1 file changed

+48
-27
lines changed

src/replit_river/v2/session.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
2525
from websockets.frames import CloseCode
2626
from websockets.legacy.protocol import WebSocketCommonProtocol
27+
from websockets.protocol import CONNECTING
2728

2829
from replit_river.common_session import (
2930
SessionState,
@@ -557,40 +558,60 @@ async def close(self) -> None:
557558
await self._close_session_callback(self)
558559

559560
async def start_serve_responses(self) -> None:
560-
self._task_manager.create_task(self._serve())
561+
async def transition_closed() -> None:
562+
self._state = SessionState.CONNECTING
563+
if self._retry_connection_callback:
564+
self._task_manager.create_task(self._retry_connection_callback())
565+
566+
await self._begin_close_session_countdown()
567+
self._task_manager.create_task(self._serve(
568+
get_state=lambda: self._state,
569+
transition_closed=transition_closed,
570+
reset_session_close_countdown=self._reset_session_close_countdown,
571+
))
561572

562-
async def _serve(self) -> None:
573+
async def _serve(
574+
self,
575+
get_state: Callable[[], SessionState],
576+
transition_closed: Callable[[], Awaitable[None]],
577+
reset_session_close_countdown: Callable[[], None],
578+
) -> None:
563579
"""Serve messages from the websocket."""
564-
self._reset_session_close_countdown()
565-
try:
580+
reset_session_close_countdown()
581+
our_task = asyncio.current_task()
582+
idx = 0
583+
while our_task and not our_task.cancelling() and not our_task.cancelled():
584+
logging.debug(f"_serve loop count={idx}")
585+
idx += 1
566586
try:
567-
await self._handle_messages_from_ws()
568-
except ConnectionClosed:
569-
# Set ourselves to closed as soon as we get the signal
570-
self._state = SessionState.CONNECTING
571-
if self._retry_connection_callback:
572-
self._task_manager.create_task(self._retry_connection_callback())
573-
574-
await self._begin_close_session_countdown()
575-
logger.debug("ConnectionClosed while serving", exc_info=True)
576-
except FailedSendingMessageException:
577-
# Expected error if the connection is closed.
578-
logger.debug(
579-
"FailedSendingMessageException while serving", exc_info=True
580-
)
581-
except Exception:
582-
logger.exception("caught exception at message iterator")
583-
except ExceptionGroup as eg:
584-
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
585-
if unhandled:
586-
raise ExceptionGroup(
587-
"Unhandled exceptions on River server", unhandled.exceptions
588-
)
587+
try:
588+
await self._handle_messages_from_ws()
589+
except ConnectionClosed:
590+
# Set ourselves to closed as soon as we get the signal
591+
await transition_closed()
592+
logger.debug("ConnectionClosed while serving", exc_info=True)
593+
except FailedSendingMessageException:
594+
# Expected error if the connection is closed.
595+
logger.debug(
596+
"FailedSendingMessageException while serving", exc_info=True
597+
)
598+
except Exception:
599+
logger.exception("caught exception at message iterator")
600+
except ExceptionGroup as eg:
601+
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
602+
if unhandled:
603+
# We're in a task, there's not that much that can be done.
604+
unhandled = ExceptionGroup(
605+
"Unhandled exceptions on River server", unhandled.exceptions
606+
)
607+
logger.exception("caught exception at message iterator", exc_info=unhandled)
608+
raise unhandled
609+
logging.debug(f"_serve exiting normally after {idx} loops")
589610

590611
async def _handle_messages_from_ws(self) -> None:
591612
logging.debug("_handle_messages_from_ws started")
592613
while self._ws_unwrapped is None or self._state == SessionState.CONNECTING:
593-
logging.debug("_handle_messages_from_ws started")
614+
logging.debug("_handle_messages_from_ws spinning while connecting")
594615
await asyncio.sleep(1)
595616
logger.debug(
596617
"%s start handling messages from ws %s",

0 commit comments

Comments
 (0)