@@ -119,7 +119,7 @@ class _IgnoreMessage:
119119 pass
120120
121121
122- class Session :
122+ class Session [ HandshakeMetadata ] :
123123 _transport_id : str
124124 _to_id : str
125125 session_id : str
@@ -132,6 +132,12 @@ class Session:
132132 _connecting_task : asyncio .Task [None ] | None
133133 _wait_for_connected : asyncio .Event
134134
135+ _client_id : str
136+ _rate_limiter : LeakyBucketRateLimit
137+ _uri_and_metadata_factory : Callable [
138+ [], Awaitable [UriAndMetadata [HandshakeMetadata ]]
139+ ]
140+
135141 # ws state
136142 _ws : ClientConnection | None
137143 _heartbeat_misses : int
@@ -161,6 +167,11 @@ def __init__(
161167 session_id : str ,
162168 transport_options : TransportOptions ,
163169 close_session_callback : CloseSessionCallback ,
170+ client_id : str ,
171+ rate_limiter : LeakyBucketRateLimit ,
172+ uri_and_metadata_factory : Callable [
173+ [], Awaitable [UriAndMetadata [HandshakeMetadata ]]
174+ ],
164175 retry_connection_callback : RetryConnectionCallback | None = None ,
165176 ) -> None :
166177 self ._transport_id = transport_id
@@ -175,6 +186,10 @@ def __init__(
175186 self ._connecting_task = None
176187 self ._wait_for_connected = asyncio .Event ()
177188
189+ self ._client_id = client_id
190+ self ._rate_limiter = rate_limiter
191+ self ._uri_and_metadata_factory = uri_and_metadata_factory
192+
178193 # ws state
179194 self ._ws = None
180195 self ._heartbeat_misses = 0
@@ -204,14 +219,7 @@ def __init__(
204219 self ._start_close_session_checker ()
205220 self ._start_buffered_message_sender ()
206221
207- async def ensure_connected [HandshakeMetadata ](
208- self ,
209- client_id : str ,
210- rate_limiter : LeakyBucketRateLimit ,
211- uri_and_metadata_factory : Callable [
212- [], Awaitable [UriAndMetadata [HandshakeMetadata ]]
213- ],
214- ) -> None :
222+ async def ensure_connected (self ) -> None :
215223 """
216224 Either return immediately or establish a websocket connection and return
217225 once we can accept messages.
@@ -279,12 +287,12 @@ def finalize_attempt() -> None:
279287 self ._connecting_task = asyncio .create_task (
280288 _do_ensure_connected (
281289 transport_id = self ._transport_id ,
282- client_id = client_id ,
290+ client_id = self . _client_id ,
283291 to_id = self ._to_id ,
284292 session_id = self .session_id ,
285293 max_retry = self ._transport_options .connection_retry_options .max_retry ,
286- rate_limiter = rate_limiter ,
287- uri_and_metadata_factory = uri_and_metadata_factory ,
294+ rate_limiter = self . _rate_limiter ,
295+ uri_and_metadata_factory = self . _uri_and_metadata_factory ,
288296 get_next_sent_seq = get_next_sent_seq ,
289297 get_current_ack = lambda : self .ack ,
290298 get_current_time = self ._get_current_time ,
0 commit comments