Skip to content

Commit 86f3c76

Browse files
Break out "ServerSession" type
1 parent e27368f commit 86f3c76

File tree

6 files changed

+65
-29
lines changed

6 files changed

+65
-29
lines changed

src/replit_river/client_transport.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
from collections.abc import Awaitable, Callable
4-
from typing import Generic
4+
from typing import Generic, Mapping
55

66
import websockets
77
from pydantic import ValidationError
@@ -36,6 +36,7 @@
3636
IgnoreMessageException,
3737
InvalidMessageException,
3838
)
39+
from replit_river.session import Session
3940
from replit_river.transport import Transport
4041
from replit_river.transport_options import (
4142
HandshakeMetadataType,
@@ -47,6 +48,8 @@
4748

4849

4950
class ClientTransport(Transport, Generic[HandshakeMetadataType]):
51+
_sessions: dict[str, ClientSession]
52+
5053
def __init__(
5154
self,
5255
uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]],
@@ -59,6 +62,7 @@ def __init__(
5962
transport_options=transport_options,
6063
is_server=False,
6164
)
65+
self._sessions = {}
6266
self._uri_and_metadata_factory = uri_and_metadata_factory
6367
self._client_id = client_id
6468
self._server_id = server_id
@@ -70,7 +74,7 @@ def __init__(
7074

7175
async def close(self) -> None:
7276
self._rate_limiter.close()
73-
await self._close_all_sessions()
77+
await self._close_all_sessions(self._get_all_sessions)
7478

7579
async def get_or_create_session(self) -> ClientSession:
7680
async with self._create_session_lock:
@@ -207,13 +211,13 @@ async def _create_new_session(
207211
handlers={},
208212
)
209213

210-
self._set_session(new_session)
214+
self._sessions[new_session._to_id] = new_session
211215
await new_session.start_serve_responses()
212216
return new_session
213217

214218
async def _retry_connection(self) -> ClientSession:
215219
if not self._transport_options.transparent_reconnect:
216-
await self._close_all_sessions()
220+
await self._close_all_sessions(self._get_all_sessions)
217221
return await self.get_or_create_session()
218222

219223
async def _send_handshake_request(
@@ -352,3 +356,11 @@ async def _establish_handshake(
352356
+ f"{handshake_response.status.reason}",
353357
)
354358
return handshake_request, handshake_response
359+
360+
def _get_all_sessions(self) -> Mapping[str, Session]:
361+
return self._sessions
362+
363+
async def _delete_session(self, session: Session) -> None:
364+
async with self._session_lock:
365+
if session._to_id in self._sessions:
366+
del self._sessions[session._to_id]

src/replit_river/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from replit_river.messages import WebsocketClosedException
1010
from replit_river.seq_manager import SessionStateMismatchException
11+
from replit_river.server_session import ServerSession
1112
from replit_river.server_transport import ServerTransport
12-
from replit_river.session import Session
1313
from replit_river.transport import TransportOptions
1414

1515
from .rpc import (
@@ -41,7 +41,7 @@ def add_rpc_handlers(
4141

4242
async def _handshake_to_get_session(
4343
self, websocket: WebSocketServerProtocol
44-
) -> Session | None:
44+
) -> ServerSession | None:
4545
"""This is a wrapper to make sentry happy, sentry doesn't recognize the
4646
exception handling outside of a task or asyncio.wait_for. So we need to catch
4747
the errors specifically here.

src/replit_river/server_session.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import logging
2+
3+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
4+
5+
from replit_river.session import Session
6+
7+
from .rpc import (
8+
TransportMessageTracingSetter,
9+
)
10+
11+
logger = logging.getLogger(__name__)
12+
13+
trace_propagator = TraceContextTextMapPropagator()
14+
trace_setter = TransportMessageTracingSetter()
15+
16+
17+
class ServerSession(Session):
18+
"""A transport object that handles the websocket connection with a client."""
19+
20+
pass

src/replit_river/server_transport.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any
2+
from typing import Any, Mapping
33

44
import nanoid # type: ignore # type: ignore
55
from pydantic import ValidationError
@@ -27,6 +27,7 @@
2727
InvalidMessageException,
2828
SessionStateMismatchException,
2929
)
30+
from replit_river.server_session import ServerSession
3031
from replit_river.session import Session
3132
from replit_river.transport import Transport
3233
from replit_river.transport_options import TransportOptions
@@ -35,6 +36,8 @@
3536

3637

3738
class ServerTransport(Transport):
39+
_sessions: dict[str, ServerSession]
40+
3841
def __init__(
3942
self,
4043
transport_id: str,
@@ -45,11 +48,12 @@ def __init__(
4548
transport_options=transport_options,
4649
is_server=True,
4750
)
51+
self._sessions = {}
4852

4953
async def handshake_to_get_session(
5054
self,
5155
websocket: WebSocketServerProtocol,
52-
) -> Session:
56+
) -> ServerSession:
5357
async for message in websocket:
5458
try:
5559
msg = parse_transport_msg(message, self._transport_options)
@@ -88,23 +92,23 @@ async def handshake_to_get_session(
8892
raise WebsocketClosedException("No handshake message received")
8993

9094
async def close(self) -> None:
91-
await self._close_all_sessions()
95+
await self._close_all_sessions(self._get_all_sessions)
9296

9397
async def _get_or_create_session(
9498
self,
9599
transport_id: str,
96100
to_id: str,
97101
session_id: str,
98102
websocket: WebSocketCommonProtocol,
99-
) -> Session:
103+
) -> ServerSession:
100104
async with self._session_lock:
101105
session_to_close: Session | None = None
102-
new_session: Session | None = None
106+
new_session: ServerSession | None = None
103107
if to_id not in self._sessions:
104108
logger.info(
105109
'Creating new session with "%s" using ws: %s', to_id, websocket.id
106110
)
107-
new_session = Session(
111+
new_session = ServerSession(
108112
transport_id,
109113
to_id,
110114
session_id,
@@ -125,7 +129,7 @@ async def _get_or_create_session(
125129
old_session.session_id,
126130
)
127131
session_to_close = old_session
128-
new_session = Session(
132+
new_session = ServerSession(
129133
transport_id,
130134
to_id,
131135
session_id,
@@ -152,7 +156,7 @@ async def _get_or_create_session(
152156
if session_to_close:
153157
logger.info("Closing stale session %s", session_to_close.session_id)
154158
await session_to_close.close()
155-
self._set_session(new_session)
159+
self._sessions[new_session._to_id] = new_session
156160
return new_session
157161

158162
async def _send_handshake_response(
@@ -293,3 +297,11 @@ async def _establish_handshake(
293297
)
294298

295299
return handshake_request, handshake_response
300+
301+
def _get_all_sessions(self) -> Mapping[str, Session]:
302+
return self._sessions
303+
304+
async def _delete_session(self, session: Session) -> None:
305+
async with self._session_lock:
306+
if session._to_id in self._sessions:
307+
del self._sessions[session._to_id]

src/replit_river/session.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444

4545
class Session:
46-
"""A transport object that handles the websocket connection with a client."""
46+
"""Common functionality shared between client_session and server_session"""
4747

4848
def __init__(
4949
self,
@@ -253,9 +253,6 @@ async def replace_with_new_websocket(
253253
await old_wrapper.close()
254254
self._ws_wrapper = WebsocketWrapper(new_ws)
255255
await self._send_buffered_messages(new_ws)
256-
# Server will call serve itself.
257-
if not self._is_server:
258-
await self.start_serve_responses()
259256

260257
async def _get_current_time(self) -> float:
261258
return asyncio.get_event_loop().time()

src/replit_river/transport.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
from typing import Callable, Mapping
34

45
import nanoid # type: ignore
56

@@ -22,12 +23,14 @@ def __init__(
2223
self._transport_id = transport_id
2324
self._transport_options = transport_options
2425
self._is_server = is_server
25-
self._sessions: dict[str, Session] = {}
2626
self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] = {}
2727
self._session_lock = asyncio.Lock()
2828

29-
async def _close_all_sessions(self) -> None:
30-
sessions = self._sessions.values()
29+
async def _close_all_sessions(
30+
self,
31+
get_all_sessions: Callable[[], Mapping[str, Session]],
32+
) -> None:
33+
sessions = get_all_sessions().values()
3134
logger.info(
3235
f"start closing sessions {self._transport_id}, number sessions : "
3336
f"{len(sessions)}"
@@ -41,13 +44,5 @@ async def _close_all_sessions(self) -> None:
4144

4245
logger.info(f"Transport closed {self._transport_id}")
4346

44-
async def _delete_session(self, session: Session) -> None:
45-
async with self._session_lock:
46-
if session._to_id in self._sessions:
47-
del self._sessions[session._to_id]
48-
49-
def _set_session(self, session: Session) -> None:
50-
self._sessions[session._to_id] = session
51-
5247
def generate_nanoid(self) -> str:
5348
return str(nanoid.generate())

0 commit comments

Comments
 (0)