Skip to content

Commit b536b64

Browse files
blast-hardcheesejackyzha0
authored andcommitted
Inlining no-longer-invariant _sessions access
1 parent d6e57a6 commit b536b64

File tree

2 files changed

+9
-21
lines changed

2 files changed

+9
-21
lines changed

src/replit_river/client_transport.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
from collections.abc import Awaitable, Callable
4-
from typing import Generic, Mapping, assert_never
4+
from typing import Generic, assert_never
55

66
import nanoid
77
import websockets
@@ -71,11 +71,8 @@ def __init__(
7171
# We want to make sure there's only one session creation at a time
7272
self._create_session_lock = asyncio.Lock()
7373

74-
async def _close_all_sessions(
75-
self,
76-
get_all_sessions: Callable[[], Mapping[str, Session]],
77-
) -> None:
78-
sessions = get_all_sessions().values()
74+
async def _close_all_sessions(self) -> None:
75+
sessions = self._sessions.values()
7976
logger.info(
8077
f"start closing sessions {self._transport_id}, number sessions : "
8178
f"{len(sessions)}"
@@ -94,7 +91,7 @@ def generate_nanoid(self) -> str:
9491

9592
async def close(self) -> None:
9693
self._rate_limiter.close()
97-
await self._close_all_sessions(self._get_all_sessions)
94+
await self._close_all_sessions()
9895

9996
async def get_or_create_session(self) -> ClientSession:
10097
async with self._create_session_lock:
@@ -235,7 +232,7 @@ async def _create_new_session(
235232

236233
async def _retry_connection(self) -> ClientSession:
237234
if not self._transport_options.transparent_reconnect:
238-
await self._close_all_sessions(self._get_all_sessions)
235+
await self._close_all_sessions()
239236
return await self.get_or_create_session()
240237

241238
async def _send_handshake_request(
@@ -377,9 +374,6 @@ async def _establish_handshake(
377374
)
378375
return handshake_request, handshake_response
379376

380-
def _get_all_sessions(self) -> Mapping[str, Session]:
381-
return self._sessions
382-
383377
async def _delete_session(self, session: Session) -> None:
384378
async with self._session_lock:
385379
if session._to_id in self._sessions:

src/replit_river/server_transport.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Any, Callable, Mapping
3+
from typing import Any
44

55
import nanoid # type: ignore # type: ignore
66
from pydantic import ValidationError
@@ -51,11 +51,8 @@ def __init__(
5151
self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]] = {}
5252
self._session_lock = asyncio.Lock()
5353

54-
async def _close_all_sessions(
55-
self,
56-
get_all_sessions: Callable[[], Mapping[str, Session]],
57-
) -> None:
58-
sessions = get_all_sessions().values()
54+
async def _close_all_sessions(self) -> None:
55+
sessions = self._sessions.values()
5956
logger.info(
6057
f"start closing sessions {self._transport_id}, number sessions : "
6158
f"{len(sessions)}"
@@ -111,7 +108,7 @@ async def handshake_to_get_session(
111108
raise WebsocketClosedException("No handshake message received")
112109

113110
async def close(self) -> None:
114-
await self._close_all_sessions(self._get_all_sessions)
111+
await self._close_all_sessions()
115112

116113
async def _get_existing_session(self, to_id: str) -> ServerSession | None:
117114
async with self._session_lock:
@@ -315,9 +312,6 @@ async def _establish_handshake(
315312

316313
return handshake_request, handshake_response
317314

318-
def _get_all_sessions(self) -> Mapping[str, Session]:
319-
return self._sessions
320-
321315
async def _delete_session(self, session: Session) -> None:
322316
async with self._session_lock:
323317
if session._to_id in self._sessions:

0 commit comments

Comments
 (0)