Skip to content

Commit 1255cbd

Browse files
Make ensure_connected callable based on our own internal state
1 parent 8ff4128 commit 1255cbd

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

src/replit_river/v2/client_transport.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,15 @@ async def get_or_create_session(self) -> Session:
5757
transport_options=self._transport_options,
5858
close_session_callback=self._delete_session,
5959
retry_connection_callback=self._retry_connection,
60+
uri_and_metadata_factory=self._uri_and_metadata_factory,
61+
rate_limiter=self._rate_limiter,
62+
client_id=self._client_id,
6063
)
6164

6265
self._session = new_session
6366
existing_session = new_session
6467

65-
await existing_session.ensure_connected(
66-
client_id=self._client_id,
67-
rate_limiter=self._rate_limiter,
68-
uri_and_metadata_factory=self._uri_and_metadata_factory,
69-
)
68+
await existing_session.ensure_connected()
7069
return existing_session
7170

7271
async def _retry_connection(self) -> Session:

src/replit_river/v2/session.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class _IgnoreMessage:
119119
pass
120120

121121

122-
class Session:
122+
class Session[HandshakeMetadata]:
123123
_transport_id: str
124124
_to_id: str
125125
session_id: str
@@ -132,6 +132,12 @@ class Session:
132132
_connecting_task: asyncio.Task[None] | None
133133
_wait_for_connected: asyncio.Event
134134

135+
_client_id: str
136+
_rate_limiter: LeakyBucketRateLimit
137+
_uri_and_metadata_factory: Callable[
138+
[], Awaitable[UriAndMetadata[HandshakeMetadata]]
139+
]
140+
135141
# ws state
136142
_ws: ClientConnection | None
137143
_heartbeat_misses: int
@@ -161,6 +167,11 @@ def __init__(
161167
session_id: str,
162168
transport_options: TransportOptions,
163169
close_session_callback: CloseSessionCallback,
170+
client_id: str,
171+
rate_limiter: LeakyBucketRateLimit,
172+
uri_and_metadata_factory: Callable[
173+
[], Awaitable[UriAndMetadata[HandshakeMetadata]]
174+
],
164175
retry_connection_callback: RetryConnectionCallback | None = None,
165176
) -> None:
166177
self._transport_id = transport_id
@@ -175,6 +186,10 @@ def __init__(
175186
self._connecting_task = None
176187
self._wait_for_connected = asyncio.Event()
177188

189+
self._client_id = client_id
190+
self._rate_limiter = rate_limiter
191+
self._uri_and_metadata_factory = uri_and_metadata_factory
192+
178193
# ws state
179194
self._ws = None
180195
self._heartbeat_misses = 0
@@ -204,14 +219,7 @@ def __init__(
204219
self._start_close_session_checker()
205220
self._start_buffered_message_sender()
206221

207-
async def ensure_connected[HandshakeMetadata](
208-
self,
209-
client_id: str,
210-
rate_limiter: LeakyBucketRateLimit,
211-
uri_and_metadata_factory: Callable[
212-
[], Awaitable[UriAndMetadata[HandshakeMetadata]]
213-
],
214-
) -> None:
222+
async def ensure_connected(self) -> None:
215223
"""
216224
Either return immediately or establish a websocket connection and return
217225
once we can accept messages.
@@ -279,12 +287,12 @@ def finalize_attempt() -> None:
279287
self._connecting_task = asyncio.create_task(
280288
_do_ensure_connected(
281289
transport_id=self._transport_id,
282-
client_id=client_id,
290+
client_id=self._client_id,
283291
to_id=self._to_id,
284292
session_id=self.session_id,
285293
max_retry=self._transport_options.connection_retry_options.max_retry,
286-
rate_limiter=rate_limiter,
287-
uri_and_metadata_factory=uri_and_metadata_factory,
294+
rate_limiter=self._rate_limiter,
295+
uri_and_metadata_factory=self._uri_and_metadata_factory,
288296
get_next_sent_seq=get_next_sent_seq,
289297
get_current_ack=lambda: self.ack,
290298
get_current_time=self._get_current_time,

0 commit comments

Comments
 (0)