Skip to content

Commit b792c9a

Browse files
Apply Jacky's patch
1 parent 782855f commit b792c9a

File tree

2 files changed

+93
-93
lines changed

2 files changed

+93
-93
lines changed

src/replit_river/client_transport.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ async def _establish_handshake(
369369
# If the session status is mismatched, we should close the old session
370370
# and let the retry logic to create a new session.
371371
await old_session.close()
372-
await self._delete_session(old_session)
373372

374373
raise RiverException(
375374
ERROR_HANDSHAKE,

src/replit_river/server_transport.py

Lines changed: 93 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,41 @@ async def handshake_to_get_session(
113113
async def close(self) -> None:
114114
await self._close_all_sessions(self._get_all_sessions)
115115

116+
async def _get_existing_session(self, to_id: str) -> ServerSession | None:
117+
async with self._session_lock:
118+
return self._sessions.get(to_id)
119+
116120
async def _get_or_create_session(
117121
self,
118122
transport_id: str,
119123
to_id: str,
120124
session_id: str,
121125
websocket: WebSocketCommonProtocol,
122126
) -> ServerSession:
123-
async with self._session_lock:
124-
session_to_close: Session | None = None
125-
new_session: ServerSession | None = None
126-
if to_id not in self._sessions:
127+
session_to_close: Session | None = None
128+
new_session: ServerSession | None = None
129+
old_session: ServerSession | None = await self._get_existing_session(to_id)
130+
if not old_session:
131+
logger.info(
132+
'Creating new session with "%s" using ws: %s', to_id, websocket.id
133+
)
134+
new_session = ServerSession(
135+
transport_id,
136+
to_id,
137+
session_id,
138+
websocket,
139+
self._transport_options,
140+
self._handlers,
141+
close_session_callback=self._delete_session,
142+
)
143+
else:
144+
if old_session.session_id != session_id:
127145
logger.info(
128-
'Creating new session with "%s" using ws: %s', to_id, websocket.id
146+
'Create new session with "%s" for session id %s'
147+
" and close old session %s",
148+
to_id,
149+
session_id,
150+
old_session.session_id,
129151
)
130152
new_session = ServerSession(
131153
transport_id,
@@ -137,43 +159,26 @@ async def _get_or_create_session(
137159
close_session_callback=self._delete_session,
138160
)
139161
else:
140-
old_session = self._sessions[to_id]
141-
if old_session.session_id != session_id:
142-
logger.info(
143-
'Create new session with "%s" for session id %s'
144-
" and close old session %s",
145-
to_id,
146-
session_id,
147-
old_session.session_id,
148-
)
149-
session_to_close = old_session
150-
new_session = ServerSession(
151-
transport_id,
152-
to_id,
153-
session_id,
154-
websocket,
155-
self._transport_options,
156-
self._handlers,
157-
close_session_callback=self._delete_session,
158-
)
159-
else:
160-
# If the instance id is the same, we reuse the session and assign
161-
# a new websocket to it.
162-
logger.debug(
163-
'Reuse old session with "%s" using new ws: %s',
164-
to_id,
165-
websocket.id,
166-
)
167-
try:
168-
await old_session.replace_with_new_websocket(websocket)
169-
new_session = old_session
170-
except FailedSendingMessageException as e:
171-
raise e
162+
# If the instance id is the same, we reuse the session and assign
163+
# a new websocket to it.
164+
logger.debug(
165+
'Reuse old session with "%s" using new ws: %s',
166+
to_id,
167+
websocket.id,
168+
)
169+
try:
170+
await old_session.replace_with_new_websocket(websocket)
171+
new_session = old_session
172+
except FailedSendingMessageException as e:
173+
raise e
172174

173-
if session_to_close:
174-
logger.info("Closing stale session %s", session_to_close.session_id)
175-
await session_to_close.close()
175+
if old_session and new_session != old_session:
176+
logger.info("Closing stale session %s", old_session.session_id)
177+
await old_session.close()
178+
179+
async with self._session_lock:
176180
self._sessions[new_session._to_id] = new_session
181+
177182
return new_session
178183

179184
async def _send_handshake_response(
@@ -249,71 +254,67 @@ async def _establish_handshake(
249254
)
250255
raise InvalidMessageException("handshake request to wrong server")
251256

252-
async with self._session_lock:
253-
old_session = self._sessions.get(request_message.from_, None)
254-
client_next_expected_seq = (
255-
handshake_request.expectedSessionState.nextExpectedSeq
256-
)
257-
client_next_sent_seq = (
258-
handshake_request.expectedSessionState.nextSentSeq or 0
259-
)
260-
if old_session and old_session.session_id == handshake_request.sessionId:
261-
# check invariants
262-
# ordering must be correct
263-
our_next_seq = await old_session.get_next_sent_seq()
264-
our_ack = await old_session.get_next_expected_seq()
265-
266-
if client_next_sent_seq > our_ack:
267-
message = (
268-
"client is in the future: "
269-
f"server wanted {our_ack} but client has {client_next_sent_seq}"
270-
)
271-
await self._send_handshake_response(
272-
request_message,
273-
HandShakeStatus(ok=False, reason=message),
274-
websocket,
275-
)
276-
raise SessionStateMismatchException(message)
257+
old_session = await self._get_existing_session(request_message.from_)
258+
client_next_expected_seq = (
259+
handshake_request.expectedSessionState.nextExpectedSeq
260+
)
261+
client_next_sent_seq = handshake_request.expectedSessionState.nextSentSeq or 0
262+
if old_session and old_session.session_id == handshake_request.sessionId:
263+
# check invariants
264+
# ordering must be correct
265+
our_next_seq = await old_session.get_next_sent_seq()
266+
our_ack = await old_session.get_next_expected_seq()
277267

278-
if our_next_seq > client_next_expected_seq:
279-
message = (
280-
"server is in the future: "
281-
f"client wanted {client_next_expected_seq} "
282-
f"but server has {our_next_seq}"
283-
)
284-
await self._send_handshake_response(
285-
request_message,
286-
HandShakeStatus(ok=False, reason=message),
287-
websocket,
288-
)
289-
raise SessionStateMismatchException(message)
290-
elif old_session:
291-
# we have an old session but the session id is different
292-
# just delete the old session
293-
await old_session.close()
294-
await self._delete_session(old_session)
295-
old_session = None
268+
if client_next_sent_seq > our_ack:
269+
message = (
270+
"client is in the future: "
271+
f"server wanted {our_ack} but client has {client_next_sent_seq}"
272+
)
273+
await self._send_handshake_response(
274+
request_message,
275+
HandShakeStatus(ok=False, reason=message),
276+
websocket,
277+
)
278+
raise SessionStateMismatchException(message)
296279

297-
if not old_session and (
298-
client_next_sent_seq > 0 or client_next_expected_seq > 0
299-
):
300-
message = "client is trying to resume a session but we don't have it"
280+
if our_next_seq > client_next_expected_seq:
281+
message = (
282+
"server is in the future: "
283+
f"client wanted {client_next_expected_seq} "
284+
f"but server has {our_next_seq}"
285+
)
301286
await self._send_handshake_response(
302287
request_message,
303288
HandShakeStatus(ok=False, reason=message),
304289
websocket,
305290
)
306291
raise SessionStateMismatchException(message)
292+
elif old_session:
293+
# we have an old session but the session id is different
294+
# just delete the old session
295+
await old_session.close()
296+
old_session = None
307297

308-
# from this point on, we're committed to connecting
309-
session_id = handshake_request.sessionId
310-
handshake_response = await self._send_handshake_response(
298+
if not old_session and (
299+
client_next_sent_seq > 0 or client_next_expected_seq > 0
300+
):
301+
message = "client is trying to resume a session but we don't have it"
302+
await self._send_handshake_response(
311303
request_message,
312-
HandShakeStatus(ok=True, sessionId=session_id),
304+
HandShakeStatus(ok=False, reason=message),
313305
websocket,
314306
)
307+
raise SessionStateMismatchException(message)
308+
309+
# from this point on, we're committed to connecting
310+
session_id = handshake_request.sessionId
311+
handshake_response = await self._send_handshake_response(
312+
request_message,
313+
HandShakeStatus(ok=True, sessionId=session_id),
314+
websocket,
315+
)
315316

316-
return handshake_request, handshake_response
317+
return handshake_request, handshake_response
317318

318319
def _get_all_sessions(self) -> Mapping[str, Session]:
319320
return self._sessions

0 commit comments

Comments
 (0)