@@ -583,116 +583,121 @@ async def _serve(self) -> None:
583583 except ExceptionGroup as eg :
584584 _ , unhandled = eg .split (lambda e : isinstance (e , ConnectionClosed ))
585585 if unhandled :
586- raise ExceptionGroup (
586+ # We're in a task, there's not that much that can be done.
587+ unhandled = ExceptionGroup (
587588 "Unhandled exceptions on River server" , unhandled .exceptions
588589 )
590+ logger .exception ("caught exception at message iterator" , exc_info = unhandled )
591+ raise unhandled
589592
590593 async def _handle_messages_from_ws (self ) -> None :
591594 logging .debug ("_handle_messages_from_ws started" )
592- while self ._ws_unwrapped is None or self ._state == SessionState .CONNECTING :
593- logging .debug ("_handle_messages_from_ws started" )
594- await asyncio .sleep (1 )
595- logger .debug (
596- "%s start handling messages from ws %s" ,
597- "client" ,
598- self ._ws_unwrapped .id ,
599- )
600- try :
601- # We should not process messages if the websocket is closed.
602- while ws := self ._ws_unwrapped :
603- # decode=False: Avoiding an unnecessary round-trip through str
604- # Ideally this should be type-ascripted to : bytes, but there is no
605- # @overrides in `websockets` to hint this.
606- message = await ws .recv (decode = False )
607- try :
608- msg = parse_transport_msg (message )
595+ our_task = asyncio .current_task ()
596+ while our_task and not our_task .cancelling () and not our_task .cancelled ():
597+ while self ._ws_unwrapped is None or self ._state == SessionState .CONNECTING :
598+ logging .debug ("_handle_messages_from_ws spinning while connecting" )
599+ await asyncio .sleep (1 )
600+ logger .debug (
601+ "%s start handling messages from ws %s" ,
602+ "client" ,
603+ self ._ws_unwrapped .id ,
604+ )
605+ try :
606+ # We should not process messages if the websocket is closed.
607+ while ws := self ._ws_unwrapped :
608+ # decode=False: Avoiding an unnecessary round-trip through str
609+ # Ideally this should be type-ascripted to : bytes, but there is no
610+ # @overrides in `websockets` to hint this.
611+ message = await ws .recv (decode = False )
612+ try :
613+ msg = parse_transport_msg (message )
609614
610- logger .debug (f"{ self ._transport_id } got a message %r" , msg )
615+ logger .debug (f"{ self ._transport_id } got a message %r" , msg )
611616
612- # Update bookkeeping
613- if msg .seq < self .ack :
614- raise IgnoreMessageException (
615- f"{ msg .from_ } received duplicate msg, got { msg .seq } "
616- f" expected { self .ack } "
617- )
618- elif msg .seq > self .ack :
619- logger .warning (
620- f"Out of order message received got { msg .seq } expected "
621- f"{ self .ack } "
622- )
617+ # Update bookkeeping
618+ if msg .seq < self .ack :
619+ raise IgnoreMessageException (
620+ f"{ msg .from_ } received duplicate msg, got { msg .seq } "
621+ f" expected { self .ack } "
622+ )
623+ elif msg .seq > self .ack :
624+ logger .warning (
625+ f"Out of order message received got { msg .seq } expected "
626+ f"{ self .ack } "
627+ )
623628
624- raise OutOfOrderMessageException (
625- f"Out of order message received got { msg .seq } expected "
626- f"{ self .ack } "
627- )
629+ raise OutOfOrderMessageException (
630+ f"Out of order message received got { msg .seq } expected "
631+ f"{ self .ack } "
632+ )
628633
629- assert msg .seq == self .ack , "Safety net, redundant assertion"
634+ assert msg .seq == self .ack , "Safety net, redundant assertion"
630635
631- # Set our next expected ack number
632- self .ack = msg .seq + 1
636+ # Set our next expected ack number
637+ self .ack = msg .seq + 1
633638
634- # Discard old server-ack'd messages from the ack buffer
635- while self ._ack_buffer and self ._ack_buffer [0 ].seq < msg .ack :
636- self ._ack_buffer .popleft ()
639+ # Discard old server-ack'd messages from the ack buffer
640+ while self ._ack_buffer and self ._ack_buffer [0 ].seq < msg .ack :
641+ self ._ack_buffer .popleft ()
637642
638- self ._reset_session_close_countdown ()
643+ self ._reset_session_close_countdown ()
639644
640- # Shortcut to avoid processing ack packets
641- if msg .controlFlags & ACK_BIT != 0 :
642- continue
645+ # Shortcut to avoid processing ack packets
646+ if msg .controlFlags & ACK_BIT != 0 :
647+ continue
643648
644- stream = self ._streams .get (msg .streamId , None )
645- if msg .controlFlags & STREAM_OPEN_BIT != 0 :
646- raise InvalidMessageException (
647- "Client should not receive stream open bit"
648- )
649+ stream = self ._streams .get (msg .streamId , None )
650+ if msg .controlFlags & STREAM_OPEN_BIT != 0 :
651+ raise InvalidMessageException (
652+ "Client should not receive stream open bit"
653+ )
649654
650- if not stream :
651- logger .warning ("no stream for %s" , msg .streamId )
652- raise IgnoreMessageException ("no stream for message, ignoring" )
653-
654- if (
655- msg .controlFlags & STREAM_CLOSED_BIT != 0
656- and msg .payload .get ("type" , None ) == "CLOSE"
657- ):
658- # close message is not sent to the stream
659- pass
660- else :
661- try :
662- await stream .put (msg .payload )
663- except ChannelClosed :
664- # The client is no longer interested in this stream,
665- # just drop the message.
655+ if not stream :
656+ logger .warning ("no stream for %s" , msg .streamId )
657+ raise IgnoreMessageException ("no stream for message, ignoring" )
658+
659+ if (
660+ msg .controlFlags & STREAM_CLOSED_BIT != 0
661+ and msg .payload .get ("type" , None ) == "CLOSE"
662+ ):
663+ # close message is not sent to the stream
666664 pass
667- except RuntimeError as e :
668- raise InvalidMessageException (e ) from e
669-
670- if msg .controlFlags & STREAM_CLOSED_BIT != 0 :
671- if stream :
672- stream .close ()
673- del self ._streams [msg .streamId ]
674- except IgnoreMessageException :
675- logger .debug ("Ignoring transport message" , exc_info = True )
676- continue
677- except OutOfOrderMessageException :
678- logger .exception ("Out of order message, closing connection" )
679- self ._task_manager .create_task (
680- self ._ws_unwrapped .close (
681- code = CloseCode .INVALID_DATA ,
682- reason = "Out of order message" ,
665+ else :
666+ try :
667+ await stream .put (msg .payload )
668+ except ChannelClosed :
669+ # The client is no longer interested in this stream,
670+ # just drop the message.
671+ pass
672+ except RuntimeError as e :
673+ raise InvalidMessageException (e ) from e
674+
675+ if msg .controlFlags & STREAM_CLOSED_BIT != 0 :
676+ if stream :
677+ stream .close ()
678+ del self ._streams [msg .streamId ]
679+ except IgnoreMessageException :
680+ logger .debug ("Ignoring transport message" , exc_info = True )
681+ continue
682+ except OutOfOrderMessageException :
683+ logger .exception ("Out of order message, closing connection" )
684+ self ._task_manager .create_task (
685+ self ._ws_unwrapped .close (
686+ code = CloseCode .INVALID_DATA ,
687+ reason = "Out of order message" ,
688+ )
683689 )
684- )
685- return
686- except InvalidMessageException :
687- logger .exception ("Got invalid transport message, closing session" )
688- await self .close ()
689- return
690- except ConnectionClosedOK :
691- # Exited normally
692- self ._state = SessionState .CONNECTING
693- except ConnectionClosed as e :
694- self ._state = SessionState .CONNECTING
695- raise e
690+ return
691+ except InvalidMessageException :
692+ logger .exception ("Got invalid transport message, closing session" )
693+ await self .close ()
694+ return
695+ except ConnectionClosedOK :
696+ # Exited normally
697+ self ._state = SessionState .CONNECTING
698+ except ConnectionClosed as e :
699+ self ._state = SessionState .CONNECTING
700+ raise e
696701 logging .debug ("_handle_messages_from_ws exiting" ) # When the network disconnects this Task exits and then we don't restart it.
697702
698703 async def send_rpc [R , A ](
0 commit comments