@@ -566,7 +566,7 @@ async def start_serve_responses(self) -> None:
566566 async def transition_connecting () -> None :
567567 self ._state = SessionState .CONNECTING
568568
569- async def transition_closed () -> None :
569+ async def connection_interrupted () -> None :
570570 self ._state = SessionState .CONNECTING
571571 if self ._retry_connection_callback :
572572 self ._task_manager .create_task (self ._retry_connection_callback ())
@@ -614,7 +614,7 @@ def close_stream(stream_id: str) -> None:
614614 get_state = lambda : self ._state ,
615615 get_ws = lambda : self ._ws_unwrapped ,
616616 transition_connecting = transition_connecting ,
617- transition_closed = transition_closed ,
617+ connection_interrupted = connection_interrupted ,
618618 reset_session_close_countdown = self ._reset_session_close_countdown ,
619619 close_session = self .close ,
620620 assert_incoming_seq_bookkeeping = assert_incoming_seq_bookkeeping ,
@@ -935,7 +935,7 @@ async def _serve(
935935 get_state : Callable [[], SessionState ],
936936 get_ws : Callable [[], ClientConnection | None ],
937937 transition_connecting : Callable [[], Awaitable [None ]],
938- transition_closed : Callable [[], Awaitable [None ]],
938+ connection_interrupted : Callable [[], Awaitable [None ]],
939939 reset_session_close_countdown : Callable [[], None ],
940940 close_session : Callable [[], Awaitable [None ]],
941941 assert_incoming_seq_bookkeeping : Callable [
@@ -952,117 +952,115 @@ async def _serve(
952952 logging .debug (f"_serve loop count={ idx } " )
953953 idx += 1
954954 try :
955- try :
956- logging .debug ("_handle_messages_from_ws started" )
957- while (
958- ws := get_ws ()
959- ) is None or get_state () == SessionState .CONNECTING :
960- logging .debug ("_handle_messages_from_ws spinning while connecting" )
961- await asyncio .sleep (1 )
962- logger .debug (
963- "%s start handling messages from ws %s" ,
964- "client" ,
965- ws .id ,
966- )
955+ logging .debug ("_handle_messages_from_ws started" )
956+ while (ws := get_ws ()) is None or get_state () == SessionState .CONNECTING :
957+ logging .debug ("_handle_messages_from_ws spinning while connecting" )
958+ await asyncio .sleep (1 )
959+ logger .debug (
960+ "%s start handling messages from ws %s" ,
961+ "client" ,
962+ ws .id ,
963+ )
964+ # We should not process messages if the websocket is closed.
965+ while (ws := get_ws ()) and get_state () == SessionState .ACTIVE :
966+ # decode=False: Avoiding an unnecessary round-trip through str
967+ # Ideally this should be type-ascripted to : bytes, but there
968+ # is no @overrides in `websockets` to hint this.
969+ message = await ws .recv (decode = False )
967970 try :
968- # We should not process messages if the websocket is closed.
969- while ws := get_ws ():
970- # decode=False: Avoiding an unnecessary round-trip through str
971- # Ideally this should be type-ascripted to : bytes, but there
972- # is no @overrides in `websockets` to hint this.
973- message = await ws .recv (decode = False )
974- try :
975- msg = parse_transport_msg (message )
971+ msg = parse_transport_msg (message )
972+ logger .debug (
973+ "[%s] got a message %r" ,
974+ transport_id ,
975+ msg ,
976+ )
977+
978+ if msg .controlFlags & STREAM_OPEN_BIT != 0 :
979+ raise InvalidMessageException (
980+ "Client should not receive stream open bit"
981+ )
982+
983+ match assert_incoming_seq_bookkeeping (
984+ msg .from_ ,
985+ msg .seq ,
986+ msg .ack ,
987+ ):
988+ case _IgnoreMessage ():
976989 logger .debug (
977- "[%s] got a message %r" ,
978- transport_id ,
979- msg ,
990+ "Ignoring transport message" ,
991+ exc_info = True ,
980992 )
993+ continue
994+ case True :
995+ pass
996+ case other :
997+ assert_never (other )
981998
982- if msg .controlFlags & STREAM_OPEN_BIT != 0 :
983- raise InvalidMessageException (
984- "Client should not receive stream open bit"
985- )
986-
987- match assert_incoming_seq_bookkeeping (
988- msg .from_ ,
989- msg .seq ,
990- msg .ack ,
991- ):
992- case _IgnoreMessage ():
993- logger .debug (
994- "Ignoring transport message" ,
995- exc_info = True ,
996- )
997- continue
998- case True :
999- pass
1000- case other :
1001- assert_never (other )
1002-
1003- reset_session_close_countdown ()
1004-
1005- # Shortcut to avoid processing ack packets
1006- if msg .controlFlags & ACK_BIT != 0 :
1007- continue
1008-
1009- stream = get_stream (msg .streamId )
1010-
1011- if not stream :
1012- logger .warning (
1013- "no stream for %s, ignoring message" ,
1014- msg .streamId ,
1015- )
1016- continue
1017-
1018- if (
1019- msg .controlFlags & STREAM_CLOSED_BIT != 0
1020- and msg .payload .get ("type" , None ) == "CLOSE"
1021- ):
1022- # close message is not sent to the stream
1023- pass
1024- else :
1025- try :
1026- await stream .put (msg .payload )
1027- except ChannelClosed :
1028- # The client is no longer interested in this stream,
1029- # just drop the message.
1030- pass
1031- except RuntimeError as e :
1032- raise InvalidMessageException (e ) from e
1033-
1034- if msg .controlFlags & STREAM_CLOSED_BIT != 0 :
1035- if stream :
1036- stream .close ()
1037- close_stream (msg .streamId )
1038- except OutOfOrderMessageException :
1039- logger .exception ("Out of order message, closing connection" )
1040- await close_session ()
1041- return
1042- except InvalidMessageException :
1043- logger .exception (
1044- "Got invalid transport message, closing session" ,
1045- )
1046- await close_session ()
1047- return
999+ reset_session_close_countdown ()
1000+
1001+ # Shortcut to avoid processing ack packets
1002+ if msg .controlFlags & ACK_BIT != 0 :
1003+ continue
1004+
1005+ stream = get_stream (msg .streamId )
1006+
1007+ if not stream :
1008+ logger .warning (
1009+ "no stream for %s, ignoring message" ,
1010+ msg .streamId ,
1011+ )
1012+ continue
1013+
1014+ if (
1015+ msg .controlFlags & STREAM_CLOSED_BIT != 0
1016+ and msg .payload .get ("type" , None ) == "CLOSE"
1017+ ):
1018+ # close message is not sent to the stream
1019+ pass
1020+ else :
1021+ try :
1022+ await stream .put (msg .payload )
1023+ except ChannelClosed :
1024+ # The client is no longer interested in this stream,
1025+ # just drop the message.
1026+ pass
1027+ except RuntimeError as e :
1028+ raise InvalidMessageException (e ) from e
1029+
1030+ if msg .controlFlags & STREAM_CLOSED_BIT != 0 :
1031+ if stream :
1032+ stream .close ()
1033+ close_stream (msg .streamId )
1034+ except OutOfOrderMessageException :
1035+ logger .exception ("Out of order message, closing connection" )
1036+ await close_session ()
1037+ continue
1038+ except InvalidMessageException :
1039+ logger .exception (
1040+ "Got invalid transport message, closing session" ,
1041+ )
1042+ await close_session ()
1043+ continue
10481044 except ConnectionClosedOK :
10491045 # Exited normally
10501046 transition_connecting ()
1051- except ConnectionClosed as e :
1052- transition_connecting ()
1053- raise e
1054- logging .debug ("_handle_messages_from_ws exiting" )
1055- except ConnectionClosed :
1056- # Set ourselves to closed as soon as we get the signal
1057- await transition_closed ()
1058- logger .debug ("ConnectionClosed while serving" , exc_info = True )
1059- except FailedSendingMessageException :
1060- # Expected error if the connection is closed.
1061- logger .debug (
1062- "FailedSendingMessageException while serving" , exc_info = True
1063- )
1064- except Exception :
1065- logger .exception ("caught exception at message iterator" )
1047+ break
1048+ except ConnectionClosed :
1049+ # Set ourselves to closed as soon as we get the signal
1050+ await connection_interrupted ()
1051+ logger .debug ("ConnectionClosed while serving" , exc_info = True )
1052+ break
1053+ except FailedSendingMessageException :
1054+ # Expected error if the connection is closed.
1055+ await connection_interrupted ()
1056+ logger .debug (
1057+ "FailedSendingMessageException while serving" , exc_info = True
1058+ )
1059+ break
1060+ except Exception :
1061+ logger .exception ("caught exception at message iterator" )
1062+ break
1063+ logging .debug ("_handle_messages_from_ws exiting" )
10661064 except ExceptionGroup as eg :
10671065 _ , unhandled = eg .split (lambda e : isinstance (e , ConnectionClosed ))
10681066 if unhandled :
0 commit comments