Skip to content

Commit 3cd44a2

Browse files
Moving throw over into rate_limiter
1 parent e3b83b0 commit 3cd44a2

File tree

3 files changed

+96
-79
lines changed

3 files changed

+96
-79
lines changed

src/replit_river/rate_limiter.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import asyncio
2+
import logging
23
import random
34
from contextvars import Context
5+
from typing import Literal
46

57
from replit_river.transport_options import ConnectionRetryOptions
8+
from replit_river.v2.client_transport import BudgetExhaustedException
9+
10+
logger = logging.getLogger(__name__)
611

712

813
class LeakyBucketRateLimit:
@@ -64,6 +69,31 @@ def has_budget(self, user: str) -> bool:
6469
"""
6570
return self.get_budget_consumed(user) < self.options.attempt_budget_capacity
6671

72+
def has_budget_or_throw(
73+
self,
74+
user: str,
75+
error_code: str,
76+
last_error: Exception | None,
77+
) -> Literal[True]:
78+
"""
79+
Check if the user has remaining budget to make a retry.
80+
If they do not, explode.
81+
82+
Args:
83+
user (str): The identifier for the user.
84+
85+
Returns:
86+
bool: True if budget is available, False otherwise.
87+
"""
88+
if self.get_budget_consumed(user) < self.options.attempt_budget_capacity:
89+
logger.debug("No retry budget for %s.", user)
90+
raise BudgetExhaustedException(
91+
error_code,
92+
"No retry budget",
93+
client_id=user,
94+
) from last_error
95+
return True
96+
6797
def consume_budget(self, user: str) -> None:
6898
"""Increment the budget consumed for the user by 1, indicating a retry attempt.
6999

src/replit_river/v2/client_transport.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ def __init__(self, code: str, message: str, client_id: str) -> None:
5151
self.client_id = client_id
5252

5353

54+
class BudgetExhaustedException(RiverException):
55+
def __init__(self, code: str, message: str, client_id: str) -> None:
56+
super().__init__(code, message)
57+
self.client_id = client_id
58+
59+
5460
class ClientTransport(Generic[HandshakeMetadataType]):
5561
_session: Session | None
5662

src/replit_river/v2/session.py

Lines changed: 60 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
)
7272
from replit_river.v2.client_transport import (
7373
PROTOCOL_VERSION,
74-
HandshakeBudgetExhaustedException,
7574
)
7675

7776
STREAM_CANCEL_BIT_TYPE = Literal[0b00100]
@@ -202,11 +201,13 @@ def get_next_pending() -> TransportMessage | None:
202201
)
203202
)
204203

205-
async def ensure_connected(
204+
async def ensure_connected[HandshakeMetadata](
206205
self,
207206
client_id: str,
208207
rate_limiter: LeakyBucketRateLimit,
209-
uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]],
208+
uri_and_metadata_factory: Callable[
209+
[], Awaitable[UriAndMetadata[HandshakeMetadata]]
210+
], # noqa: E501
210211
) -> None:
211212
"""
212213
Either return immediately or establish a websocket connection and return
@@ -218,16 +219,11 @@ async def ensure_connected(
218219
logger.info("Attempting to establish new ws connection")
219220

220221
last_error: Exception | None = None
221-
for i in range(max_retry):
222+
i = 0
223+
while rate_limiter.has_budget_or_throw(client_id, ERROR_HANDSHAKE, last_error):
222224
if i > 0:
223225
logger.info(f"Retrying build handshake number {i} times")
224-
if not rate_limiter.has_budget(client_id):
225-
logger.debug("No retry budget for %s.", client_id)
226-
raise HandshakeBudgetExhaustedException(
227-
ERROR_HANDSHAKE,
228-
"No retry budget",
229-
client_id=client_id,
230-
) from last_error
226+
i += 1
231227

232228
rate_limiter.consume_budget(client_id)
233229

@@ -238,10 +234,12 @@ async def ensure_connected(
238234
try:
239235
try:
240236
expectedSessionState = ExpectedSessionState(
241-
nextExpectedSeq=0,
242-
nextSentSeq=0,
237+
nextExpectedSeq=self.ack,
238+
nextSentSeq=self.seq,
243239
)
244-
handshake_request = ControlMessageHandshakeRequest[Any](
240+
handshake_request = ControlMessageHandshakeRequest[
241+
HandshakeMetadata
242+
]( # noqa: E501
245243
type="HANDSHAKE_REQ",
246244
protocolVersion=PROTOCOL_VERSION,
247245
sessionId=self.session_id,
@@ -253,85 +251,68 @@ async def ensure_connected(
253251
async def websocket_closed_callback() -> None:
254252
logger.error("websocket closed before handshake response")
255253

254+
await send_transport_message(
255+
TransportMessage(
256+
from_=self._transport_id,
257+
to=self._to_id,
258+
streamId=stream_id,
259+
controlFlags=0,
260+
id=nanoid.generate(),
261+
seq=0,
262+
ack=0,
263+
payload=handshake_request.model_dump(),
264+
),
265+
ws=ws,
266+
websocket_closed_callback=websocket_closed_callback,
267+
)
268+
except (
269+
WebsocketClosedException,
270+
FailedSendingMessageException,
271+
) as e: # noqa: E501
272+
raise RiverException(
273+
ERROR_HANDSHAKE,
274+
"Handshake failed, conn closed while sending response", # noqa: E501
275+
) from e
276+
277+
startup_grace_deadline_ms = await self._get_current_time() + 60_000
278+
while True:
279+
if await self._get_current_time() >= startup_grace_deadline_ms: # noqa: E501
280+
raise RiverException(
281+
ERROR_HANDSHAKE,
282+
"Handshake response timeout, closing connection", # noqa: E501
283+
)
256284
try:
257-
payload = handshake_request.model_dump()
258-
await send_transport_message(
259-
TransportMessage(
260-
from_=self._transport_id,
261-
to=self._to_id,
262-
streamId=stream_id,
263-
controlFlags=0,
264-
id=nanoid.generate(),
265-
seq=0,
266-
ack=0,
267-
payload=payload,
268-
),
269-
ws=ws,
270-
websocket_closed_callback=websocket_closed_callback,
285+
data = await ws.recv()
286+
except ConnectionClosed as e:
287+
logger.debug(
288+
"Connection closed during waiting for handshake response", # noqa: E501
289+
exc_info=True,
271290
)
272-
except (
273-
WebsocketClosedException,
274-
FailedSendingMessageException,
275-
) as e: # noqa: E501
276291
raise RiverException(
277292
ERROR_HANDSHAKE,
278-
"Handshake failed, conn closed while sending response", # noqa: E501
293+
"Handshake failed, conn closed while waiting for response", # noqa: E501
294+
) from e
295+
try:
296+
response_msg = parse_transport_msg(data)
297+
break
298+
except IgnoreMessageException:
299+
logger.debug("Ignoring transport message", exc_info=True) # noqa: E501
300+
continue
301+
except InvalidMessageException as e:
302+
raise RiverException(
303+
ERROR_HANDSHAKE,
304+
"Got invalid transport message, closing connection",
279305
) from e
280-
except FailedSendingMessageException as e:
281-
raise RiverException(
282-
ERROR_CODE_STREAM_CLOSED,
283-
"Stream closed before response, closing connection",
284-
) from e
285306

286-
startup_grace_deadline_ms = await self._get_current_time() + 60_000
287307
try:
288-
while True:
289-
if (
290-
await self._get_current_time()
291-
>= startup_grace_deadline_ms
292-
): # noqa: E501
293-
raise RiverException(
294-
ERROR_HANDSHAKE,
295-
"Handshake response timeout, closing connection", # noqa: E501
296-
)
297-
try:
298-
data = await ws.recv()
299-
except ConnectionClosed as e:
300-
logger.debug(
301-
"Connection closed during waiting for handshake response", # noqa: E501
302-
exc_info=True,
303-
)
304-
raise RiverException(
305-
ERROR_HANDSHAKE,
306-
"Handshake failed, conn closed while waiting for response", # noqa: E501
307-
) from e
308-
try:
309-
response_msg = parse_transport_msg(data)
310-
break
311-
except IgnoreMessageException:
312-
logger.debug(
313-
"Ignoring transport message", exc_info=True
314-
) # noqa: E501
315-
continue
316-
except InvalidMessageException as e:
317-
raise RiverException(
318-
ERROR_HANDSHAKE,
319-
"Got invalid transport message, closing connection",
320-
) from e
321-
322308
handshake_response = ControlMessageHandshakeResponse(
323309
**response_msg.payload
324-
) # noqa: E501
310+
)
325311
logger.debug("river client waiting for handshake response")
326312
except ValidationError as e:
327313
raise RiverException(
328314
ERROR_HANDSHAKE, "Failed to parse handshake response"
329315
) from e
330-
except asyncio.TimeoutError as e:
331-
raise RiverException(
332-
ERROR_HANDSHAKE,
333-
"Handshake response timeout, closing connection", # noqa: E501
334-
) from e
335316

336317
logger.debug(
337318
"river client get handshake response : %r", handshake_response

0 commit comments

Comments
 (0)