Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ lint:
# Format the src and tests trees with ruff, then run ruff's linter over the
# same trees with --fix so auto-fixable findings are applied in place.
format:
uv run ruff format src tests
uv run ruff check src tests --fix

# Run pytest against the tests directory inside the uv-managed environment.
test:
uv run pytest tests
1 change: 0 additions & 1 deletion src/replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ async def _establish_handshake(
# If the session status is mismatched, we should close the old session
# and let the retry logic to create a new session.
await old_session.close()
await self._delete_session(old_session)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol dupes


raise RiverException(
ERROR_HANDSHAKE,
Expand Down
186 changes: 93 additions & 93 deletions src/replit_river/server_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,41 @@ async def handshake_to_get_session(
# Close this transport. Awaits self._close_all_sessions() and returns None.
# NOTE(review): presumably that helper closes every session tracked in
# self._sessions -- its body is not visible here, confirm.
async def close(self) -> None:
await self._close_all_sessions()

# Look up the session currently registered under `to_id`.
#
# The read of self._sessions happens while holding self._session_lock.
# Returns whatever self._sessions.get(to_id) yields: the stored Session, or
# None when there is no entry for `to_id`.
#
# NOTE(review): the lock is released as soon as this coroutine returns, so the
# result is only a snapshot; another task may replace or remove the entry
# before the caller acts on it.
# NOTE(review): if _session_lock is a non-reentrant lock (e.g. asyncio.Lock),
# callers must not already hold it when awaiting this method -- confirm.
async def _get_existing_session(self, to_id: str) -> Optional[Session]:
async with self._session_lock:
return self._sessions.get(to_id)

async def _get_or_create_session(
self,
transport_id: str,
to_id: str,
session_id: str,
websocket: WebSocketCommonProtocol,
) -> Session:
async with self._session_lock:
session_to_close: Optional[Session] = None
new_session: Optional[Session] = None
if to_id not in self._sessions:
new_session: Optional[Session] = None
old_session: Optional[Session] = await self._get_existing_session(to_id)
if not old_session:
logger.info(
'Creating new session with "%s" using ws: %s', to_id, websocket.id
)
new_session = Session(
transport_id,
to_id,
session_id,
websocket,
self._transport_options,
self._is_server,
self._handlers,
close_session_callback=self._delete_session,
)
else:
if old_session.session_id != session_id:
logger.info(
'Creating new session with "%s" using ws: %s', to_id, websocket.id
'Create new session with "%s" for session id %s'
" and close old session %s",
to_id,
session_id,
old_session.session_id,
)
new_session = Session(
transport_id,
Expand All @@ -115,44 +137,26 @@ async def _get_or_create_session(
close_session_callback=self._delete_session,
)
else:
old_session = self._sessions[to_id]
if old_session.session_id != session_id:
logger.info(
'Create new session with "%s" for session id %s'
" and close old session %s",
to_id,
session_id,
old_session.session_id,
)
session_to_close = old_session
new_session = Session(
transport_id,
to_id,
session_id,
websocket,
self._transport_options,
self._is_server,
self._handlers,
close_session_callback=self._delete_session,
)
else:
# If the instance id is the same, we reuse the session and assign
# a new websocket to it.
logger.debug(
'Reuse old session with "%s" using new ws: %s',
to_id,
websocket.id,
)
try:
await old_session.replace_with_new_websocket(websocket)
new_session = old_session
except FailedSendingMessageException as e:
raise e
# If the instance id is the same, we reuse the session and assign
# a new websocket to it.
logger.debug(
'Reuse old session with "%s" using new ws: %s',
to_id,
websocket.id,
)
try:
await old_session.replace_with_new_websocket(websocket)
new_session = old_session
except FailedSendingMessageException as e:
raise e

if session_to_close:
logger.info("Closing stale session %s", session_to_close.session_id)
await session_to_close.close()
if old_session and new_session != old_session:
logger.info("Closing stale session %s", old_session.session_id)
await old_session.close()

async with self._session_lock:
self._set_session(new_session)
Comment on lines +157 to 158
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For safety: one property the old code had that this one doesn't is protection against two concurrent calls to this function (with the same id?). With the new code, two such calls could each close the old session (so it gets closed twice) and one of the newly created sessions would be leaked!

To prevent that, we can rename _set_session to _set_session_locked (and add a docstring noting that it requires the session_lock to be held), and have _set_session_locked detect this mid-air collision and ... do something with the preexisting session (close it???).

with that out of the way, this relaxing of the locks' critical sections seems to be safe, because the Session by itself is safe to create twice (as long as we close it, otherwise we leak tasks!!!)


return new_session

async def _send_handshake_response(
Expand Down Expand Up @@ -228,68 +232,64 @@ async def _establish_handshake(
)
raise InvalidMessageException("handshake request to wrong server")

async with self._session_lock:
old_session = self._sessions.get(request_message.from_, None)
client_next_expected_seq = (
handshake_request.expectedSessionState.nextExpectedSeq
)
client_next_sent_seq = (
handshake_request.expectedSessionState.nextSentSeq or 0
)
if old_session and old_session.session_id == handshake_request.sessionId:
# check invariants
# ordering must be correct
our_next_seq = await old_session.get_next_sent_seq()
our_ack = await old_session.get_next_expected_seq()

if client_next_sent_seq > our_ack:
message = (
"client is in the future: "
f"server wanted {our_ack} but client has {client_next_sent_seq}"
)
await self._send_handshake_response(
request_message,
HandShakeStatus(ok=False, reason=message),
websocket,
)
raise SessionStateMismatchException(message)
old_session = await self._get_existing_session(request_message.from_)
client_next_expected_seq = (
handshake_request.expectedSessionState.nextExpectedSeq
)
client_next_sent_seq = handshake_request.expectedSessionState.nextSentSeq or 0
if old_session and old_session.session_id == handshake_request.sessionId:
# check invariants
# ordering must be correct
our_next_seq = await old_session.get_next_sent_seq()
our_ack = await old_session.get_next_expected_seq()

if our_next_seq > client_next_expected_seq:
message = (
"server is in the future: "
f"client wanted {client_next_expected_seq} "
f"but server has {our_next_seq}"
)
await self._send_handshake_response(
request_message,
HandShakeStatus(ok=False, reason=message),
websocket,
)
raise SessionStateMismatchException(message)
elif old_session:
# we have an old session but the session id is different
# just delete the old session
await old_session.close()
await self._delete_session(old_session)
old_session = None
if client_next_sent_seq > our_ack:
message = (
"client is in the future: "
f"server wanted {our_ack} but client has {client_next_sent_seq}"
)
await self._send_handshake_response(
request_message,
HandShakeStatus(ok=False, reason=message),
websocket,
)
raise SessionStateMismatchException(message)

if not old_session and (
client_next_sent_seq > 0 or client_next_expected_seq > 0
):
message = "client is trying to resume a session but we don't have it"
if our_next_seq > client_next_expected_seq:
message = (
"server is in the future: "
f"client wanted {client_next_expected_seq} "
f"but server has {our_next_seq}"
)
await self._send_handshake_response(
request_message,
HandShakeStatus(ok=False, reason=message),
websocket,
)
raise SessionStateMismatchException(message)
elif old_session:
# we have an old session but the session id is different
# just delete the old session
await old_session.close()
old_session = None

# from this point on, we're committed to connecting
session_id = handshake_request.sessionId
handshake_response = await self._send_handshake_response(
if not old_session and (
client_next_sent_seq > 0 or client_next_expected_seq > 0
):
message = "client is trying to resume a session but we don't have it"
await self._send_handshake_response(
request_message,
HandShakeStatus(ok=True, sessionId=session_id),
HandShakeStatus(ok=False, reason=message),
websocket,
)
raise SessionStateMismatchException(message)

# from this point on, we're committed to connecting
session_id = handshake_request.sessionId
handshake_response = await self._send_handshake_response(
request_message,
HandShakeStatus(ok=True, sessionId=session_id),
websocket,
)

return handshake_request, handshake_response
return handshake_request, handshake_response
2 changes: 0 additions & 2 deletions tests/river_fixtures/clientserver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import logging
from typing import AsyncGenerator, Literal

Expand Down Expand Up @@ -62,7 +61,6 @@ async def websocket_uri_factory() -> UriAndMetadata[None]:
logging.debug("Start closing test client : %s", "test_client")
await client.close()
finally:
await asyncio.sleep(1)
logging.debug("Start closing test server")
await server.close()
# Server should close normally
Expand Down
28 changes: 28 additions & 0 deletions tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,31 @@ async def test_ignore_flood_subscription(client: Client) -> None:
timedelta(seconds=20),
)
assert response == "Hello, Alice!"


# Reconnect test: an RPC issued after the client's sessions have been closed
# must still succeed and return the expected greeting.
# NOTE(review): presumably the second call forces the client transport to
# establish a fresh session/handshake with the server -- confirm against the
# transport implementation.
@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_rpc_method}])
async def test_rpc_method_reconnect(client: Client) -> None:
# Baseline: a first RPC to test_service.rpc_method works before any disruption.
response = await client.send_rpc(
"test_service",
"rpc_method",
"Alice",
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)
assert response == "Hello, Alice!"

# Tear down the sessions through the transport's private helper, then send a
# second RPC with a different payload so a stale response cannot satisfy the
# final assertion.
await client._transport._close_all_sessions()
response = await client.send_rpc(
"test_service",
"rpc_method",
"Bob",
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)

assert response == "Hello, Bob!"