@@ -113,19 +113,40 @@ async def handshake_to_get_session(
113113 async def close (self ) -> None :
114114 await self ._close_all_sessions (self ._get_all_sessions )
115115
116+ async def _get_existing_session (self , to_id : str ) -> ServerSession | None :
117+ async with self ._session_lock :
118+ return self ._sessions .get (to_id )
119+
116120 async def _get_or_create_session (
117121 self ,
118122 transport_id : str ,
119123 to_id : str ,
120124 session_id : str ,
121125 websocket : WebSocketCommonProtocol ,
122126 ) -> ServerSession :
123- async with self ._session_lock :
124- session_to_close : Session | None = None
125- new_session : ServerSession | None = None
126- if to_id not in self ._sessions :
127+ new_session : ServerSession | None = None
128+ old_session : ServerSession | None = await self ._get_existing_session (to_id )
129+ if not old_session :
130+ logger .info (
131+ 'Creating new session with "%s" using ws: %s' , to_id , websocket .id
132+ )
133+ new_session = ServerSession (
134+ transport_id ,
135+ to_id ,
136+ session_id ,
137+ websocket ,
138+ self ._transport_options ,
139+ self ._handlers ,
140+ close_session_callback = self ._delete_session ,
141+ )
142+ else :
143+ if old_session .session_id != session_id :
127144 logger .info (
128- 'Creating new session with "%s" using ws: %s' , to_id , websocket .id
145+ 'Create new session with "%s" for session id %s'
146+ " and close old session %s" ,
147+ to_id ,
148+ session_id ,
149+ old_session .session_id ,
129150 )
130151 new_session = ServerSession (
131152 transport_id ,
@@ -137,43 +158,26 @@ async def _get_or_create_session(
137158 close_session_callback = self ._delete_session ,
138159 )
139160 else :
140- old_session = self ._sessions [to_id ]
141- if old_session .session_id != session_id :
142- logger .info (
143- 'Create new session with "%s" for session id %s'
144- " and close old session %s" ,
145- to_id ,
146- session_id ,
147- old_session .session_id ,
148- )
149- session_to_close = old_session
150- new_session = ServerSession (
151- transport_id ,
152- to_id ,
153- session_id ,
154- websocket ,
155- self ._transport_options ,
156- self ._handlers ,
157- close_session_callback = self ._delete_session ,
158- )
159- else :
160- # If the instance id is the same, we reuse the session and assign
161- # a new websocket to it.
162- logger .debug (
163- 'Reuse old session with "%s" using new ws: %s' ,
164- to_id ,
165- websocket .id ,
166- )
167- try :
168- await old_session .replace_with_new_websocket (websocket )
169- new_session = old_session
170- except FailedSendingMessageException as e :
171- raise e
161+ # If the instance id is the same, we reuse the session and assign
162+ # a new websocket to it.
163+ logger .debug (
164+ 'Reuse old session with "%s" using new ws: %s' ,
165+ to_id ,
166+ websocket .id ,
167+ )
168+ try :
169+ await old_session .replace_with_new_websocket (websocket )
170+ new_session = old_session
171+ except FailedSendingMessageException as e :
172+ raise e
172173
173- if session_to_close :
174- logger .info ("Closing stale session %s" , session_to_close .session_id )
175- await session_to_close .close ()
174+ if old_session and new_session != old_session :
175+ logger .info ("Closing stale session %s" , old_session .session_id )
176+ await old_session .close ()
177+
178+ async with self ._session_lock :
176179 self ._sessions [new_session ._to_id ] = new_session
180+
177181 return new_session
178182
179183 async def _send_handshake_response (
@@ -249,71 +253,67 @@ async def _establish_handshake(
249253 )
250254 raise InvalidMessageException ("handshake request to wrong server" )
251255
252- async with self ._session_lock :
253- old_session = self ._sessions .get (request_message .from_ , None )
254- client_next_expected_seq = (
255- handshake_request .expectedSessionState .nextExpectedSeq
256- )
257- client_next_sent_seq = (
258- handshake_request .expectedSessionState .nextSentSeq or 0
259- )
260- if old_session and old_session .session_id == handshake_request .sessionId :
261- # check invariants
262- # ordering must be correct
263- our_next_seq = await old_session .get_next_sent_seq ()
264- our_ack = await old_session .get_next_expected_seq ()
265-
266- if client_next_sent_seq > our_ack :
267- message = (
268- "client is in the future: "
269- f"server wanted { our_ack } but client has { client_next_sent_seq } "
270- )
271- await self ._send_handshake_response (
272- request_message ,
273- HandShakeStatus (ok = False , reason = message ),
274- websocket ,
275- )
276- raise SessionStateMismatchException (message )
256+ old_session = await self ._get_existing_session (request_message .from_ )
257+ client_next_expected_seq = (
258+ handshake_request .expectedSessionState .nextExpectedSeq
259+ )
260+ client_next_sent_seq = handshake_request .expectedSessionState .nextSentSeq or 0
261+ if old_session and old_session .session_id == handshake_request .sessionId :
262+ # check invariants
263+ # ordering must be correct
264+ our_next_seq = await old_session .get_next_sent_seq ()
265+ our_ack = await old_session .get_next_expected_seq ()
277266
278- if our_next_seq > client_next_expected_seq :
279- message = (
280- "server is in the future: "
281- f"client wanted { client_next_expected_seq } "
282- f"but server has { our_next_seq } "
283- )
284- await self ._send_handshake_response (
285- request_message ,
286- HandShakeStatus (ok = False , reason = message ),
287- websocket ,
288- )
289- raise SessionStateMismatchException (message )
290- elif old_session :
291- # we have an old session but the session id is different
292- # just delete the old session
293- await old_session .close ()
294- await self ._delete_session (old_session )
295- old_session = None
267+ if client_next_sent_seq > our_ack :
268+ message = (
269+ "client is in the future: "
270+ f"server wanted { our_ack } but client has { client_next_sent_seq } "
271+ )
272+ await self ._send_handshake_response (
273+ request_message ,
274+ HandShakeStatus (ok = False , reason = message ),
275+ websocket ,
276+ )
277+ raise SessionStateMismatchException (message )
296278
297- if not old_session and (
298- client_next_sent_seq > 0 or client_next_expected_seq > 0
299- ):
300- message = "client is trying to resume a session but we don't have it"
279+ if our_next_seq > client_next_expected_seq :
280+ message = (
281+ "server is in the future: "
282+ f"client wanted { client_next_expected_seq } "
283+ f"but server has { our_next_seq } "
284+ )
301285 await self ._send_handshake_response (
302286 request_message ,
303287 HandShakeStatus (ok = False , reason = message ),
304288 websocket ,
305289 )
306290 raise SessionStateMismatchException (message )
291+ elif old_session :
292+ # we have an old session but the session id is different
293+ # just delete the old session
294+ await old_session .close ()
295+ old_session = None
307296
308- # from this point on, we're committed to connecting
309- session_id = handshake_request .sessionId
310- handshake_response = await self ._send_handshake_response (
297+ if not old_session and (
298+ client_next_sent_seq > 0 or client_next_expected_seq > 0
299+ ):
300+ message = "client is trying to resume a session but we don't have it"
301+ await self ._send_handshake_response (
311302 request_message ,
312- HandShakeStatus (ok = True , sessionId = session_id ),
303+ HandShakeStatus (ok = False , reason = message ),
313304 websocket ,
314305 )
306+ raise SessionStateMismatchException (message )
307+
308+ # from this point on, we're committed to connecting
309+ session_id = handshake_request .sessionId
310+ handshake_response = await self ._send_handshake_response (
311+ request_message ,
312+ HandShakeStatus (ok = True , sessionId = session_id ),
313+ websocket ,
314+ )
315315
316- return handshake_request , handshake_response
316+ return handshake_request , handshake_response
317317
318318 def _get_all_sessions (self ) -> Mapping [str , Session ]:
319319 return self ._sessions
0 commit comments