@@ -113,19 +113,41 @@ async def handshake_to_get_session(
113113 async def close (self ) -> None :
114114 await self ._close_all_sessions (self ._get_all_sessions )
115115
116+ async def _get_existing_session (self , to_id : str ) -> ServerSession | None :
117+ async with self ._session_lock :
118+ return self ._sessions .get (to_id )
119+
116120 async def _get_or_create_session (
117121 self ,
118122 transport_id : str ,
119123 to_id : str ,
120124 session_id : str ,
121125 websocket : WebSocketCommonProtocol ,
122126 ) -> ServerSession :
123- async with self ._session_lock :
124- session_to_close : Session | None = None
125- new_session : ServerSession | None = None
126- if to_id not in self ._sessions :
127+ session_to_close : Session | None = None
128+ new_session : ServerSession | None = None
129+ old_session : ServerSession | None = await self ._get_existing_session (to_id )
130+ if not old_session :
131+ logger .info (
132+ 'Creating new session with "%s" using ws: %s' , to_id , websocket .id
133+ )
134+ new_session = ServerSession (
135+ transport_id ,
136+ to_id ,
137+ session_id ,
138+ websocket ,
139+ self ._transport_options ,
140+ self ._handlers ,
141+ close_session_callback = self ._delete_session ,
142+ )
143+ else :
144+ if old_session .session_id != session_id :
127145 logger .info (
128- 'Creating new session with "%s" using ws: %s' , to_id , websocket .id
146+ 'Create new session with "%s" for session id %s'
147+ " and close old session %s" ,
148+ to_id ,
149+ session_id ,
150+ old_session .session_id ,
129151 )
130152 new_session = ServerSession (
131153 transport_id ,
@@ -137,43 +159,26 @@ async def _get_or_create_session(
137159 close_session_callback = self ._delete_session ,
138160 )
139161 else :
140- old_session = self ._sessions [to_id ]
141- if old_session .session_id != session_id :
142- logger .info (
143- 'Create new session with "%s" for session id %s'
144- " and close old session %s" ,
145- to_id ,
146- session_id ,
147- old_session .session_id ,
148- )
149- session_to_close = old_session
150- new_session = ServerSession (
151- transport_id ,
152- to_id ,
153- session_id ,
154- websocket ,
155- self ._transport_options ,
156- self ._handlers ,
157- close_session_callback = self ._delete_session ,
158- )
159- else :
160- # If the instance id is the same, we reuse the session and assign
161- # a new websocket to it.
162- logger .debug (
163- 'Reuse old session with "%s" using new ws: %s' ,
164- to_id ,
165- websocket .id ,
166- )
167- try :
168- await old_session .replace_with_new_websocket (websocket )
169- new_session = old_session
170- except FailedSendingMessageException as e :
171- raise e
162+ # If the instance id is the same, we reuse the session and assign
163+ # a new websocket to it.
164+ logger .debug (
165+ 'Reuse old session with "%s" using new ws: %s' ,
166+ to_id ,
167+ websocket .id ,
168+ )
169+ try :
170+ await old_session .replace_with_new_websocket (websocket )
171+ new_session = old_session
172+ except FailedSendingMessageException as e :
173+ raise e
172174
173- if session_to_close :
174- logger .info ("Closing stale session %s" , session_to_close .session_id )
175- await session_to_close .close ()
175+ if old_session and new_session != old_session :
176+ logger .info ("Closing stale session %s" , old_session .session_id )
177+ await old_session .close ()
178+
179+ async with self ._session_lock :
176180 self ._sessions [new_session ._to_id ] = new_session
181+
177182 return new_session
178183
179184 async def _send_handshake_response (
@@ -249,71 +254,67 @@ async def _establish_handshake(
249254 )
250255 raise InvalidMessageException ("handshake request to wrong server" )
251256
252- async with self ._session_lock :
253- old_session = self ._sessions .get (request_message .from_ , None )
254- client_next_expected_seq = (
255- handshake_request .expectedSessionState .nextExpectedSeq
256- )
257- client_next_sent_seq = (
258- handshake_request .expectedSessionState .nextSentSeq or 0
259- )
260- if old_session and old_session .session_id == handshake_request .sessionId :
261- # check invariants
262- # ordering must be correct
263- our_next_seq = await old_session .get_next_sent_seq ()
264- our_ack = await old_session .get_next_expected_seq ()
265-
266- if client_next_sent_seq > our_ack :
267- message = (
268- "client is in the future: "
269- f"server wanted { our_ack } but client has { client_next_sent_seq } "
270- )
271- await self ._send_handshake_response (
272- request_message ,
273- HandShakeStatus (ok = False , reason = message ),
274- websocket ,
275- )
276- raise SessionStateMismatchException (message )
257+ old_session = await self ._get_existing_session (request_message .from_ )
258+ client_next_expected_seq = (
259+ handshake_request .expectedSessionState .nextExpectedSeq
260+ )
261+ client_next_sent_seq = handshake_request .expectedSessionState .nextSentSeq or 0
262+ if old_session and old_session .session_id == handshake_request .sessionId :
263+ # check invariants
264+ # ordering must be correct
265+ our_next_seq = await old_session .get_next_sent_seq ()
266+ our_ack = await old_session .get_next_expected_seq ()
277267
278- if our_next_seq > client_next_expected_seq :
279- message = (
280- "server is in the future: "
281- f"client wanted { client_next_expected_seq } "
282- f"but server has { our_next_seq } "
283- )
284- await self ._send_handshake_response (
285- request_message ,
286- HandShakeStatus (ok = False , reason = message ),
287- websocket ,
288- )
289- raise SessionStateMismatchException (message )
290- elif old_session :
291- # we have an old session but the session id is different
292- # just delete the old session
293- await old_session .close ()
294- await self ._delete_session (old_session )
295- old_session = None
268+ if client_next_sent_seq > our_ack :
269+ message = (
270+ "client is in the future: "
271+ f"server wanted { our_ack } but client has { client_next_sent_seq } "
272+ )
273+ await self ._send_handshake_response (
274+ request_message ,
275+ HandShakeStatus (ok = False , reason = message ),
276+ websocket ,
277+ )
278+ raise SessionStateMismatchException (message )
296279
297- if not old_session and (
298- client_next_sent_seq > 0 or client_next_expected_seq > 0
299- ):
300- message = "client is trying to resume a session but we don't have it"
280+ if our_next_seq > client_next_expected_seq :
281+ message = (
282+ "server is in the future: "
283+ f"client wanted { client_next_expected_seq } "
284+ f"but server has { our_next_seq } "
285+ )
301286 await self ._send_handshake_response (
302287 request_message ,
303288 HandShakeStatus (ok = False , reason = message ),
304289 websocket ,
305290 )
306291 raise SessionStateMismatchException (message )
292+ elif old_session :
293+ # we have an old session but the session id is different
294+ # just delete the old session
295+ await old_session .close ()
296+ old_session = None
307297
308- # from this point on, we're committed to connecting
309- session_id = handshake_request .sessionId
310- handshake_response = await self ._send_handshake_response (
298+ if not old_session and (
299+ client_next_sent_seq > 0 or client_next_expected_seq > 0
300+ ):
301+ message = "client is trying to resume a session but we don't have it"
302+ await self ._send_handshake_response (
311303 request_message ,
312- HandShakeStatus (ok = True , sessionId = session_id ),
304+ HandShakeStatus (ok = False , reason = message ),
313305 websocket ,
314306 )
307+ raise SessionStateMismatchException (message )
308+
309+ # from this point on, we're committed to connecting
310+ session_id = handshake_request .sessionId
311+ handshake_response = await self ._send_handshake_response (
312+ request_message ,
313+ HandShakeStatus (ok = True , sessionId = session_id ),
314+ websocket ,
315+ )
315316
316- return handshake_request , handshake_response
317+ return handshake_request , handshake_response
317318
318319 def _get_all_sessions (self ) -> Mapping [str , Session ]:
319320 return self ._sessions
0 commit comments