Skip to content

Commit 2983732

Browse files
deadlock patch (#148)
Why
===
deadlock in the server-side reconnect case:

1. establish handshake acquires session lock: https://github.com/replit/river-python/blob/6e2bc1ab1392f5a8407a08f6c8f049ec9e4b308d/src/replit_river/server_transport.py#L249
2. if there is a session mismatch, we close the old session: https://github.com/replit/river-python/blob/6e2bc1ab1392f5a8407a08f6c8f049ec9e4b308d/src/replit_river/server_transport.py#L290
3. session.close calls _close_session_callback: https://github.com/replit/river-python/blob/6e2bc1ab1392f5a8407a08f6c8f049ec9e4b308d/src/replit_river/session.py#L299
4. _delete_session also tries to acquire the session lock: https://github.com/replit/river-python/blob/6e2bc1ab1392f5a8407a08f6c8f049ec9e4b308d/src/replit_river/server_transport.py#L316

What changed
============
1. don't need to call _delete_session as .close will already do that
2. lift session close outside of the session lock
3. in the handshake case, let the final call to get_or_create_session replace the session and close the old one

Test plan
=========
added a test

Notes
=========
I have a draft of a more in-depth approach on [this branch](https://github.com/replit/river-python/tree/jackyzha0/fix-deadlock) which uses a lock-ownership-transfer based approach that should catch it more generically but ran into ownership problems lol

Seeing as we are planning on migrating chat service to Node anyways so we only have one River server implementation, we hopefully don't have to maintain this surface for much longer 🤞

---------

Co-authored-by: Devon Stewart <[email protected]>
1 parent 6e2bc1a commit 2983732

File tree

4 files changed

+39
-16
lines changed

4 files changed

+39
-16
lines changed

src/replit_river/client_transport.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,6 @@ async def _establish_handshake(
366366
# If the session status is mismatched, we should close the old session
367367
# and let the retry logic to create a new session.
368368
await old_session.close()
369-
await self._delete_session(old_session)
370369

371370
raise RiverException(
372371
ERROR_HANDSHAKE,

src/replit_river/server_transport.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,11 @@ async def _get_or_create_session(
117117
session_id: str,
118118
websocket: WebSocketCommonProtocol,
119119
) -> ServerSession:
120+
new_session: ServerSession | None = None
121+
old_session: ServerSession | None = None
120122
async with self._session_lock:
121-
session_to_close: Session | None = None
122-
new_session: ServerSession | None = None
123-
if to_id not in self._sessions:
123+
old_session = self._sessions.get(to_id)
124+
if not old_session:
124125
logger.info(
125126
'Creating new session with "%s" using ws: %s', to_id, websocket.id
126127
)
@@ -134,7 +135,6 @@ async def _get_or_create_session(
134135
close_session_callback=self._delete_session,
135136
)
136137
else:
137-
old_session = self._sessions[to_id]
138138
if old_session.session_id != session_id:
139139
logger.info(
140140
'Create new session with "%s" for session id %s'
@@ -143,7 +143,6 @@ async def _get_or_create_session(
143143
session_id,
144144
old_session.session_id,
145145
)
146-
session_to_close = old_session
147146
new_session = ServerSession(
148147
transport_id,
149148
to_id,
@@ -167,10 +166,12 @@ async def _get_or_create_session(
167166
except FailedSendingMessageException as e:
168167
raise e
169168

170-
if session_to_close:
171-
logger.info("Closing stale session %s", session_to_close.session_id)
172-
await session_to_close.close()
173169
self._sessions[new_session._to_id] = new_session
170+
171+
if old_session and new_session != old_session:
172+
logger.info("Closing stale session %s", old_session.session_id)
173+
await old_session.close()
174+
174175
return new_session
175176

176177
async def _send_handshake_response(
@@ -247,7 +248,7 @@ async def _establish_handshake(
247248
raise InvalidMessageException("handshake request to wrong server")
248249

249250
async with self._session_lock:
250-
old_session = self._sessions.get(request_message.from_, None)
251+
old_session = self._sessions.get(request_message.from_)
251252
client_next_expected_seq = (
252253
handshake_request.expectedSessionState.nextExpectedSeq
253254
)
@@ -285,10 +286,6 @@ async def _establish_handshake(
285286
)
286287
raise SessionStateMismatchException(message)
287288
elif old_session:
288-
# we have an old session but the session id is different
289-
# just delete the old session
290-
await old_session.close()
291-
await self._delete_session(old_session)
292289
old_session = None
293290

294291
if not old_session and (

tests/river_fixtures/clientserver.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import logging
32
from typing import AsyncGenerator, Literal
43

@@ -64,7 +63,6 @@ async def websocket_uri_factory() -> UriAndMetadata[None]:
6463
await client.close()
6564

6665
finally:
67-
await asyncio.sleep(1)
6866
logging.debug("Start closing test server")
6967
if binding:
7068
binding.close()

tests/test_communication.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,32 @@ async def test_ignore_flood_subscription(client: Client) -> None:
268268
timedelta(seconds=20),
269269
)
270270
assert response == "Hello, Alice!"
271+
272+
273+
@pytest.mark.asyncio
274+
@pytest.mark.parametrize("handlers", [{**basic_rpc_method}])
275+
async def test_rpc_method_reconnect(client: Client) -> None:
276+
response = await client.send_rpc(
277+
"test_service",
278+
"rpc_method",
279+
"Alice",
280+
serialize_request,
281+
deserialize_response,
282+
deserialize_error,
283+
timedelta(seconds=20),
284+
)
285+
assert response == "Hello, Alice!"
286+
287+
await client._transport._close_all_sessions()
288+
289+
response = await client.send_rpc(
290+
"test_service",
291+
"rpc_method",
292+
"Bob",
293+
serialize_request,
294+
deserialize_response,
295+
deserialize_error,
296+
timedelta(seconds=20),
297+
)
298+
299+
assert response == "Hello, Bob!"

0 commit comments

Comments
 (0)