|
24 | 24 | from websockets.exceptions import ConnectionClosed, ConnectionClosedOK |
25 | 25 | from websockets.frames import CloseCode |
26 | 26 | from websockets.legacy.protocol import WebSocketCommonProtocol |
| 27 | +from websockets.protocol import CONNECTING |
27 | 28 |
|
28 | 29 | from replit_river.common_session import ( |
29 | 30 | SessionState, |
@@ -557,40 +558,60 @@ async def close(self) -> None: |
557 | 558 | await self._close_session_callback(self) |
558 | 559 |
|
559 | 560 | async def start_serve_responses(self) -> None: |
560 | | - self._task_manager.create_task(self._serve()) |
| 561 | + async def transition_closed() -> None: |
| 562 | + self._state = SessionState.CONNECTING |
| 563 | + if self._retry_connection_callback: |
| 564 | + self._task_manager.create_task(self._retry_connection_callback()) |
| 565 | + |
| 566 | + await self._begin_close_session_countdown() |
| 567 | + self._task_manager.create_task(self._serve( |
| 568 | + get_state=lambda: self._state, |
| 569 | + transition_closed=transition_closed, |
| 570 | + reset_session_close_countdown=self._reset_session_close_countdown, |
| 571 | + )) |
561 | 572 |
|
562 | | - async def _serve(self) -> None: |
| 573 | + async def _serve( |
| 574 | + self, |
| 575 | + get_state: Callable[[], SessionState], |
| 576 | + transition_closed: Callable[[], Awaitable[None]], |
| 577 | + reset_session_close_countdown: Callable[[], None], |
| 578 | + ) -> None: |
563 | 579 | """Serve messages from the websocket.""" |
564 | | - self._reset_session_close_countdown() |
565 | | - try: |
| 580 | + reset_session_close_countdown() |
| 581 | + our_task = asyncio.current_task() |
| 582 | + idx = 0 |
| 583 | + while our_task and not our_task.cancelling() and not our_task.cancelled(): |
| 584 | + logging.debug(f"_serve loop count={idx}") |
| 585 | + idx += 1 |
566 | 586 | try: |
567 | | - await self._handle_messages_from_ws() |
568 | | - except ConnectionClosed: |
569 | | - # Set ourselves to closed as soon as we get the signal |
570 | | - self._state = SessionState.CONNECTING |
571 | | - if self._retry_connection_callback: |
572 | | - self._task_manager.create_task(self._retry_connection_callback()) |
573 | | - |
574 | | - await self._begin_close_session_countdown() |
575 | | - logger.debug("ConnectionClosed while serving", exc_info=True) |
576 | | - except FailedSendingMessageException: |
577 | | - # Expected error if the connection is closed. |
578 | | - logger.debug( |
579 | | - "FailedSendingMessageException while serving", exc_info=True |
580 | | - ) |
581 | | - except Exception: |
582 | | - logger.exception("caught exception at message iterator") |
583 | | - except ExceptionGroup as eg: |
584 | | - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) |
585 | | - if unhandled: |
586 | | - raise ExceptionGroup( |
587 | | - "Unhandled exceptions on River server", unhandled.exceptions |
588 | | - ) |
| 587 | + try: |
| 588 | + await self._handle_messages_from_ws() |
| 589 | + except ConnectionClosed: |
| 590 | + # Set ourselves to closed as soon as we get the signal |
| 591 | + await transition_closed() |
| 592 | + logger.debug("ConnectionClosed while serving", exc_info=True) |
| 593 | + except FailedSendingMessageException: |
| 594 | + # Expected error if the connection is closed. |
| 595 | + logger.debug( |
| 596 | + "FailedSendingMessageException while serving", exc_info=True |
| 597 | + ) |
| 598 | + except Exception: |
| 599 | + logger.exception("caught exception at message iterator") |
| 600 | + except ExceptionGroup as eg: |
| 601 | + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) |
| 602 | + if unhandled: |
| 603 | + # We're in a task, there's not that much that can be done. |
| 604 | + unhandled = ExceptionGroup( |
| 605 | + "Unhandled exceptions on River server", unhandled.exceptions |
| 606 | + ) |
| 607 | + logger.exception("caught exception at message iterator", exc_info=unhandled) |
| 608 | + raise unhandled |
| 609 | + logging.debug(f"_serve exiting normally after {idx} loops") |
589 | 610 |
|
590 | 611 | async def _handle_messages_from_ws(self) -> None: |
591 | 612 | logging.debug("_handle_messages_from_ws started") |
592 | 613 | while self._ws_unwrapped is None or self._state == SessionState.CONNECTING: |
593 | | - logging.debug("_handle_messages_from_ws started") |
| 614 | + logging.debug("_handle_messages_from_ws spinning while connecting") |
594 | 615 | await asyncio.sleep(1) |
595 | 616 | logger.debug( |
596 | 617 | "%s start handling messages from ws %s", |
|
0 commit comments