Skip to content

Commit 7f114e0

Browse files
Prevent _handle_messages_from_ws from terminating early
1 parent 9d33ded commit 7f114e0

File tree

1 file changed

+98
-93
lines changed

1 file changed

+98
-93
lines changed

src/replit_river/v2/session.py

Lines changed: 98 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -583,116 +583,121 @@ async def _serve(self) -> None:
583583
except ExceptionGroup as eg:
584584
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
585585
if unhandled:
586-
raise ExceptionGroup(
586+
# We're in a task, there's not that much that can be done.
587+
unhandled = ExceptionGroup(
587588
"Unhandled exceptions on River server", unhandled.exceptions
588589
)
590+
logger.exception("caught exception at message iterator", exc_info=unhandled)
591+
raise unhandled
589592

590593
async def _handle_messages_from_ws(self) -> None:
591594
logging.debug("_handle_messages_from_ws started")
592-
while self._ws_unwrapped is None or self._state == SessionState.CONNECTING:
593-
logging.debug("_handle_messages_from_ws started")
594-
await asyncio.sleep(1)
595-
logger.debug(
596-
"%s start handling messages from ws %s",
597-
"client",
598-
self._ws_unwrapped.id,
599-
)
600-
try:
601-
# We should not process messages if the websocket is closed.
602-
while ws := self._ws_unwrapped:
603-
# decode=False: Avoiding an unnecessary round-trip through str
604-
# Ideally this should be type-ascripted to : bytes, but there is no
605-
# @overrides in `websockets` to hint this.
606-
message = await ws.recv(decode=False)
607-
try:
608-
msg = parse_transport_msg(message)
595+
our_task = asyncio.current_task()
596+
while our_task and not our_task.cancelling() and not our_task.cancelled():
597+
while self._ws_unwrapped is None or self._state == SessionState.CONNECTING:
598+
logging.debug("_handle_messages_from_ws spinning while connecting")
599+
await asyncio.sleep(1)
600+
logger.debug(
601+
"%s start handling messages from ws %s",
602+
"client",
603+
self._ws_unwrapped.id,
604+
)
605+
try:
606+
# We should not process messages if the websocket is closed.
607+
while ws := self._ws_unwrapped:
608+
# decode=False: Avoiding an unnecessary round-trip through str
609+
# Ideally this should be type-ascripted to : bytes, but there is no
610+
# @overrides in `websockets` to hint this.
611+
message = await ws.recv(decode=False)
612+
try:
613+
msg = parse_transport_msg(message)
609614

610-
logger.debug(f"{self._transport_id} got a message %r", msg)
615+
logger.debug(f"{self._transport_id} got a message %r", msg)
611616

612-
# Update bookkeeping
613-
if msg.seq < self.ack:
614-
raise IgnoreMessageException(
615-
f"{msg.from_} received duplicate msg, got {msg.seq}"
616-
f" expected {self.ack}"
617-
)
618-
elif msg.seq > self.ack:
619-
logger.warning(
620-
f"Out of order message received got {msg.seq} expected "
621-
f"{self.ack}"
622-
)
617+
# Update bookkeeping
618+
if msg.seq < self.ack:
619+
raise IgnoreMessageException(
620+
f"{msg.from_} received duplicate msg, got {msg.seq}"
621+
f" expected {self.ack}"
622+
)
623+
elif msg.seq > self.ack:
624+
logger.warning(
625+
f"Out of order message received got {msg.seq} expected "
626+
f"{self.ack}"
627+
)
623628

624-
raise OutOfOrderMessageException(
625-
f"Out of order message received got {msg.seq} expected "
626-
f"{self.ack}"
627-
)
629+
raise OutOfOrderMessageException(
630+
f"Out of order message received got {msg.seq} expected "
631+
f"{self.ack}"
632+
)
628633

629-
assert msg.seq == self.ack, "Safety net, redundant assertion"
634+
assert msg.seq == self.ack, "Safety net, redundant assertion"
630635

631-
# Set our next expected ack number
632-
self.ack = msg.seq + 1
636+
# Set our next expected ack number
637+
self.ack = msg.seq + 1
633638

634-
# Discard old server-ack'd messages from the ack buffer
635-
while self._ack_buffer and self._ack_buffer[0].seq < msg.ack:
636-
self._ack_buffer.popleft()
639+
# Discard old server-ack'd messages from the ack buffer
640+
while self._ack_buffer and self._ack_buffer[0].seq < msg.ack:
641+
self._ack_buffer.popleft()
637642

638-
self._reset_session_close_countdown()
643+
self._reset_session_close_countdown()
639644

640-
# Shortcut to avoid processing ack packets
641-
if msg.controlFlags & ACK_BIT != 0:
642-
continue
645+
# Shortcut to avoid processing ack packets
646+
if msg.controlFlags & ACK_BIT != 0:
647+
continue
643648

644-
stream = self._streams.get(msg.streamId, None)
645-
if msg.controlFlags & STREAM_OPEN_BIT != 0:
646-
raise InvalidMessageException(
647-
"Client should not receive stream open bit"
648-
)
649+
stream = self._streams.get(msg.streamId, None)
650+
if msg.controlFlags & STREAM_OPEN_BIT != 0:
651+
raise InvalidMessageException(
652+
"Client should not receive stream open bit"
653+
)
649654

650-
if not stream:
651-
logger.warning("no stream for %s", msg.streamId)
652-
raise IgnoreMessageException("no stream for message, ignoring")
653-
654-
if (
655-
msg.controlFlags & STREAM_CLOSED_BIT != 0
656-
and msg.payload.get("type", None) == "CLOSE"
657-
):
658-
# close message is not sent to the stream
659-
pass
660-
else:
661-
try:
662-
await stream.put(msg.payload)
663-
except ChannelClosed:
664-
# The client is no longer interested in this stream,
665-
# just drop the message.
655+
if not stream:
656+
logger.warning("no stream for %s", msg.streamId)
657+
raise IgnoreMessageException("no stream for message, ignoring")
658+
659+
if (
660+
msg.controlFlags & STREAM_CLOSED_BIT != 0
661+
and msg.payload.get("type", None) == "CLOSE"
662+
):
663+
# close message is not sent to the stream
666664
pass
667-
except RuntimeError as e:
668-
raise InvalidMessageException(e) from e
669-
670-
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
671-
if stream:
672-
stream.close()
673-
del self._streams[msg.streamId]
674-
except IgnoreMessageException:
675-
logger.debug("Ignoring transport message", exc_info=True)
676-
continue
677-
except OutOfOrderMessageException:
678-
logger.exception("Out of order message, closing connection")
679-
self._task_manager.create_task(
680-
self._ws_unwrapped.close(
681-
code=CloseCode.INVALID_DATA,
682-
reason="Out of order message",
665+
else:
666+
try:
667+
await stream.put(msg.payload)
668+
except ChannelClosed:
669+
# The client is no longer interested in this stream,
670+
# just drop the message.
671+
pass
672+
except RuntimeError as e:
673+
raise InvalidMessageException(e) from e
674+
675+
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
676+
if stream:
677+
stream.close()
678+
del self._streams[msg.streamId]
679+
except IgnoreMessageException:
680+
logger.debug("Ignoring transport message", exc_info=True)
681+
continue
682+
except OutOfOrderMessageException:
683+
logger.exception("Out of order message, closing connection")
684+
self._task_manager.create_task(
685+
self._ws_unwrapped.close(
686+
code=CloseCode.INVALID_DATA,
687+
reason="Out of order message",
688+
)
683689
)
684-
)
685-
return
686-
except InvalidMessageException:
687-
logger.exception("Got invalid transport message, closing session")
688-
await self.close()
689-
return
690-
except ConnectionClosedOK:
691-
# Exited normally
692-
self._state = SessionState.CONNECTING
693-
except ConnectionClosed as e:
694-
self._state = SessionState.CONNECTING
695-
raise e
690+
return
691+
except InvalidMessageException:
692+
logger.exception("Got invalid transport message, closing session")
693+
await self.close()
694+
return
695+
except ConnectionClosedOK:
696+
# Exited normally
697+
self._state = SessionState.CONNECTING
698+
except ConnectionClosed as e:
699+
self._state = SessionState.CONNECTING
700+
raise e
696701
logging.debug("_handle_messages_from_ws exiting") # When the network disconnects this Task exits and then we don't restart it.
697702

698703
async def send_rpc[R, A](

0 commit comments

Comments
 (0)