Skip to content

Commit d6e57a6

Browse files
Apply Jacky's patch
1 parent 3e5ad7f commit d6e57a6

File tree

2 files changed

+92
-93
lines changed

2 files changed

+92
-93
lines changed

src/replit_river/client_transport.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ async def _establish_handshake(
369369
# If the session status is mismatched, we should close the old session
370370
# and let the retry logic to create a new session.
371371
await old_session.close()
372-
await self._delete_session(old_session)
373372

374373
raise RiverException(
375374
ERROR_HANDSHAKE,

src/replit_river/server_transport.py

Lines changed: 92 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,40 @@ async def handshake_to_get_session(
113113
async def close(self) -> None:
114114
await self._close_all_sessions(self._get_all_sessions)
115115

116+
async def _get_existing_session(self, to_id: str) -> ServerSession | None:
117+
async with self._session_lock:
118+
return self._sessions.get(to_id)
119+
116120
async def _get_or_create_session(
117121
self,
118122
transport_id: str,
119123
to_id: str,
120124
session_id: str,
121125
websocket: WebSocketCommonProtocol,
122126
) -> ServerSession:
123-
async with self._session_lock:
124-
session_to_close: Session | None = None
125-
new_session: ServerSession | None = None
126-
if to_id not in self._sessions:
127+
new_session: ServerSession | None = None
128+
old_session: ServerSession | None = await self._get_existing_session(to_id)
129+
if not old_session:
130+
logger.info(
131+
'Creating new session with "%s" using ws: %s', to_id, websocket.id
132+
)
133+
new_session = ServerSession(
134+
transport_id,
135+
to_id,
136+
session_id,
137+
websocket,
138+
self._transport_options,
139+
self._handlers,
140+
close_session_callback=self._delete_session,
141+
)
142+
else:
143+
if old_session.session_id != session_id:
127144
logger.info(
128-
'Creating new session with "%s" using ws: %s', to_id, websocket.id
145+
'Create new session with "%s" for session id %s'
146+
" and close old session %s",
147+
to_id,
148+
session_id,
149+
old_session.session_id,
129150
)
130151
new_session = ServerSession(
131152
transport_id,
@@ -137,43 +158,26 @@ async def _get_or_create_session(
137158
close_session_callback=self._delete_session,
138159
)
139160
else:
140-
old_session = self._sessions[to_id]
141-
if old_session.session_id != session_id:
142-
logger.info(
143-
'Create new session with "%s" for session id %s'
144-
" and close old session %s",
145-
to_id,
146-
session_id,
147-
old_session.session_id,
148-
)
149-
session_to_close = old_session
150-
new_session = ServerSession(
151-
transport_id,
152-
to_id,
153-
session_id,
154-
websocket,
155-
self._transport_options,
156-
self._handlers,
157-
close_session_callback=self._delete_session,
158-
)
159-
else:
160-
# If the instance id is the same, we reuse the session and assign
161-
# a new websocket to it.
162-
logger.debug(
163-
'Reuse old session with "%s" using new ws: %s',
164-
to_id,
165-
websocket.id,
166-
)
167-
try:
168-
await old_session.replace_with_new_websocket(websocket)
169-
new_session = old_session
170-
except FailedSendingMessageException as e:
171-
raise e
161+
# If the instance id is the same, we reuse the session and assign
162+
# a new websocket to it.
163+
logger.debug(
164+
'Reuse old session with "%s" using new ws: %s',
165+
to_id,
166+
websocket.id,
167+
)
168+
try:
169+
await old_session.replace_with_new_websocket(websocket)
170+
new_session = old_session
171+
except FailedSendingMessageException as e:
172+
raise e
172173

173-
if session_to_close:
174-
logger.info("Closing stale session %s", session_to_close.session_id)
175-
await session_to_close.close()
174+
if old_session and new_session != old_session:
175+
logger.info("Closing stale session %s", old_session.session_id)
176+
await old_session.close()
177+
178+
async with self._session_lock:
176179
self._sessions[new_session._to_id] = new_session
180+
177181
return new_session
178182

179183
async def _send_handshake_response(
@@ -249,71 +253,67 @@ async def _establish_handshake(
249253
)
250254
raise InvalidMessageException("handshake request to wrong server")
251255

252-
async with self._session_lock:
253-
old_session = self._sessions.get(request_message.from_, None)
254-
client_next_expected_seq = (
255-
handshake_request.expectedSessionState.nextExpectedSeq
256-
)
257-
client_next_sent_seq = (
258-
handshake_request.expectedSessionState.nextSentSeq or 0
259-
)
260-
if old_session and old_session.session_id == handshake_request.sessionId:
261-
# check invariants
262-
# ordering must be correct
263-
our_next_seq = await old_session.get_next_sent_seq()
264-
our_ack = await old_session.get_next_expected_seq()
265-
266-
if client_next_sent_seq > our_ack:
267-
message = (
268-
"client is in the future: "
269-
f"server wanted {our_ack} but client has {client_next_sent_seq}"
270-
)
271-
await self._send_handshake_response(
272-
request_message,
273-
HandShakeStatus(ok=False, reason=message),
274-
websocket,
275-
)
276-
raise SessionStateMismatchException(message)
256+
old_session = await self._get_existing_session(request_message.from_)
257+
client_next_expected_seq = (
258+
handshake_request.expectedSessionState.nextExpectedSeq
259+
)
260+
client_next_sent_seq = handshake_request.expectedSessionState.nextSentSeq or 0
261+
if old_session and old_session.session_id == handshake_request.sessionId:
262+
# check invariants
263+
# ordering must be correct
264+
our_next_seq = await old_session.get_next_sent_seq()
265+
our_ack = await old_session.get_next_expected_seq()
277266

278-
if our_next_seq > client_next_expected_seq:
279-
message = (
280-
"server is in the future: "
281-
f"client wanted {client_next_expected_seq} "
282-
f"but server has {our_next_seq}"
283-
)
284-
await self._send_handshake_response(
285-
request_message,
286-
HandShakeStatus(ok=False, reason=message),
287-
websocket,
288-
)
289-
raise SessionStateMismatchException(message)
290-
elif old_session:
291-
# we have an old session but the session id is different
292-
# just delete the old session
293-
await old_session.close()
294-
await self._delete_session(old_session)
295-
old_session = None
267+
if client_next_sent_seq > our_ack:
268+
message = (
269+
"client is in the future: "
270+
f"server wanted {our_ack} but client has {client_next_sent_seq}"
271+
)
272+
await self._send_handshake_response(
273+
request_message,
274+
HandShakeStatus(ok=False, reason=message),
275+
websocket,
276+
)
277+
raise SessionStateMismatchException(message)
296278

297-
if not old_session and (
298-
client_next_sent_seq > 0 or client_next_expected_seq > 0
299-
):
300-
message = "client is trying to resume a session but we don't have it"
279+
if our_next_seq > client_next_expected_seq:
280+
message = (
281+
"server is in the future: "
282+
f"client wanted {client_next_expected_seq} "
283+
f"but server has {our_next_seq}"
284+
)
301285
await self._send_handshake_response(
302286
request_message,
303287
HandShakeStatus(ok=False, reason=message),
304288
websocket,
305289
)
306290
raise SessionStateMismatchException(message)
291+
elif old_session:
292+
# we have an old session but the session id is different
293+
# just delete the old session
294+
await old_session.close()
295+
old_session = None
307296

308-
# from this point on, we're committed to connecting
309-
session_id = handshake_request.sessionId
310-
handshake_response = await self._send_handshake_response(
297+
if not old_session and (
298+
client_next_sent_seq > 0 or client_next_expected_seq > 0
299+
):
300+
message = "client is trying to resume a session but we don't have it"
301+
await self._send_handshake_response(
311302
request_message,
312-
HandShakeStatus(ok=True, sessionId=session_id),
303+
HandShakeStatus(ok=False, reason=message),
313304
websocket,
314305
)
306+
raise SessionStateMismatchException(message)
307+
308+
# from this point on, we're committed to connecting
309+
session_id = handshake_request.sessionId
310+
handshake_response = await self._send_handshake_response(
311+
request_message,
312+
HandShakeStatus(ok=True, sessionId=session_id),
313+
websocket,
314+
)
315315

316-
return handshake_request, handshake_response
316+
return handshake_request, handshake_response
317317

318318
def _get_all_sessions(self) -> Mapping[str, Session]:
319319
return self._sessions

0 commit comments

Comments
 (0)