Skip to content

Commit 6dcd3e9

Browse files
committed
Add standard retries and unit tests
1 parent 11872ef commit 6dcd3e9

File tree

4 files changed

+362
-23
lines changed

4 files changed

+362
-23
lines changed

packages/smithy-core/src/smithy_core/aio/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
330330
return await self._handle_attempt(call, request_context, request_future)
331331

332332
retry_strategy = call.retry_strategy
333-
retry_token = retry_strategy.acquire_initial_retry_token(
333+
retry_token = await retry_strategy.acquire_initial_retry_token(
334334
token_scope=call.retry_scope
335335
)
336336

@@ -349,7 +349,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
349349

350350
if isinstance(output_context.response, Exception):
351351
try:
352-
retry_strategy.refresh_retry_token_for_retry(
352+
retry_token = await retry_strategy.refresh_retry_token_for_retry(
353353
token_to_renew=retry_token,
354354
error=output_context.response,
355355
)
@@ -364,7 +364,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
364364

365365
await seek(request_context.transport_request.body, 0)
366366
else:
367-
retry_strategy.record_success(token=retry_token)
367+
await retry_strategy.record_success(token=retry_token)
368368
return output_context
369369

370370
async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](

packages/smithy-core/src/smithy_core/interfaces/retries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class RetryStrategy(Protocol):
6161
max_attempts: int
6262
"""Upper limit on total attempt count (initial attempt plus retries)."""
6363

64-
def acquire_initial_retry_token(
64+
async def acquire_initial_retry_token(
6565
self, *, token_scope: str | None = None
6666
) -> RetryToken:
6767
"""Called before any retries (for the first attempt at the operation).
@@ -74,7 +74,7 @@ def acquire_initial_retry_token(
7474
"""
7575
...
7676

77-
def refresh_retry_token_for_retry(
77+
async def refresh_retry_token_for_retry(
7878
self, *, token_to_renew: RetryToken, error: Exception
7979
) -> RetryToken:
8080
"""Replace an existing retry token from a failed attempt with a new token.
@@ -91,7 +91,7 @@ def refresh_retry_token_for_retry(
9191
"""
9292
...
9393

94-
def record_success(self, *, token: RetryToken) -> None:
94+
async def record_success(self, *, token: RetryToken) -> None:
9595
"""Return token after successful completion of an operation.
9696
9797
Upon successful completion of the operation, a user calls this function to

packages/smithy-core/src/smithy_core/retries.py

Lines changed: 157 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
import asyncio
34
import random
45
from collections.abc import Callable
56
from dataclasses import dataclass
@@ -204,7 +205,7 @@ def __init__(
204205
self.backoff_strategy = backoff_strategy or ExponentialRetryBackoffStrategy()
205206
self.max_attempts = max_attempts
206207

207-
def acquire_initial_retry_token(
208+
async def acquire_initial_retry_token(
208209
self, *, token_scope: str | None = None
209210
) -> SimpleRetryToken:
210211
"""Called before any retries (for the first attempt at the operation).
@@ -214,7 +215,7 @@ def acquire_initial_retry_token(
214215
retry_delay = self.backoff_strategy.compute_next_backoff_delay(0)
215216
return SimpleRetryToken(retry_count=0, retry_delay=retry_delay)
216217

217-
def refresh_retry_token_for_retry(
218+
async def refresh_retry_token_for_retry(
218219
self,
219220
*,
220221
token_to_renew: retries_interface.RetryToken,
@@ -240,5 +241,158 @@ def refresh_retry_token_for_retry(
240241
else:
241242
raise RetryError(f"Error is not retryable: {error}") from error
242243

243-
def record_success(self, *, token: retries_interface.RetryToken) -> None:
244+
async def record_success(self, *, token: retries_interface.RetryToken) -> None:
244245
"""Not used by this retry strategy."""
246+
247+
248+
@dataclass(kw_only=True)
249+
class StandardRetryToken:
250+
retry_count: int
251+
"""Retry count is the total number of attempts minus the initial attempt."""
252+
253+
retry_delay: float
254+
"""Delay in seconds to wait before the retry attempt."""
255+
256+
quota_consumed: int = 0
257+
"""The total amount of quota consumed."""
258+
259+
last_quota_acquired: int = 0
260+
"""The amount of last quota acquired."""
261+
262+
263+
class StandardRetryStrategy(retries_interface.RetryStrategy):
264+
def __init__(self, *, max_attempts: int = 3):
265+
"""Standard retry strategy using truncated binary exponential backoff with full
266+
jitter.
267+
268+
:param max_attempts: Upper limit on total number of attempts made, including
269+
initial attempt and retries.
270+
"""
271+
self.backoff_strategy = ExponentialRetryBackoffStrategy(
272+
backoff_scale_value=1,
273+
jitter_type=ExponentialBackoffJitterType.FULL,
274+
)
275+
self.max_attempts = max_attempts
276+
self._retry_quota = StandardRetryQuota()
277+
278+
async def acquire_initial_retry_token(
279+
self, *, token_scope: str | None = None
280+
) -> StandardRetryToken:
281+
"""Called before any retries (for the first attempt at the operation).
282+
283+
:param token_scope: This argument is ignored by this retry strategy.
284+
"""
285+
retry_delay = self.backoff_strategy.compute_next_backoff_delay(0)
286+
return StandardRetryToken(retry_count=0, retry_delay=retry_delay)
287+
288+
async def refresh_retry_token_for_retry(
289+
self,
290+
*,
291+
token_to_renew: StandardRetryToken,
292+
error: Exception,
293+
) -> StandardRetryToken:
294+
"""Replace an existing retry token from a failed attempt with a new token.
295+
296+
This retry strategy always returns a token until the attempt count stored in
297+
the new token exceeds the ``max_attempts`` value.
298+
299+
:param token_to_renew: The token used for the previous failed attempt.
300+
:param error: The error that triggered the need for a retry.
301+
:raises RetryError: If no further retry attempts are allowed.
302+
"""
303+
if isinstance(error, retries_interface.ErrorRetryInfo) and error.is_retry_safe:
304+
retry_count = token_to_renew.retry_count + 1
305+
if retry_count >= self.max_attempts:
306+
raise RetryError(
307+
f"Reached maximum number of allowed attempts: {self.max_attempts}"
308+
) from error
309+
310+
# Acquire additional quota for this retry attempt
311+
# (may raise a RetryError if none is available)
312+
quota_acquired = await self._retry_quota.acquire(error=error)
313+
total_quota = token_to_renew.quota_consumed + quota_acquired
314+
315+
if error.retry_after is not None:
316+
retry_delay = error.retry_after
317+
else:
318+
retry_delay = self.backoff_strategy.compute_next_backoff_delay(
319+
retry_count
320+
)
321+
322+
return StandardRetryToken(
323+
retry_count=retry_count,
324+
retry_delay=retry_delay,
325+
quota_consumed=total_quota,
326+
last_quota_acquired=quota_acquired,
327+
)
328+
else:
329+
raise RetryError(f"Error is not retryable: {error}") from error
330+
331+
async def record_success(self, *, token: StandardRetryToken) -> None:
332+
"""Return token after successful completion of an operation.
333+
334+
Releases retry tokens back to the retry quota based on the previous amount
335+
consumed.
336+
337+
:param token: The token used for the previous successful attempt.
338+
"""
339+
await self._retry_quota.release(release_amount=token.last_quota_acquired)
340+
341+
342+
class StandardRetryQuota:
343+
"""Retry quota used by :py:class:`StandardRetryStrategy`."""
344+
345+
INITIAL_RETRY_TOKENS = 500
346+
RETRY_COST = 5
347+
NO_RETRY_INCREMENT = 1
348+
TIMEOUT_RETRY_COST = 10
349+
350+
def __init__(self):
351+
self._max_capacity = self.INITIAL_RETRY_TOKENS
352+
self._available_capacity = self.INITIAL_RETRY_TOKENS
353+
self._lock = asyncio.Lock()
354+
355+
async def acquire(self, *, error: Exception) -> int:
356+
"""Attempt to acquire a certain amount of capacity.
357+
358+
If there's no sufficient amount of capacity available, raise an exception.
359+
Otherwise, we return the amount of capacity successfully allocated.
360+
"""
361+
# TODO: update `is_timeout` when `is_timeout_error` is implemented
362+
is_timeout = False
363+
capacity_amount = self.TIMEOUT_RETRY_COST if is_timeout else self.RETRY_COST
364+
365+
async with self._lock:
366+
if capacity_amount > self._available_capacity:
367+
raise RetryError("Retry quota exceeded")
368+
self._available_capacity -= capacity_amount
369+
return capacity_amount
370+
371+
async def release(self, *, release_amount: int) -> None:
372+
"""Release capacity back to the retry quota.
373+
374+
The capacity being released will be truncated if necessary to ensure the max
375+
capacity is never exceeded.
376+
"""
377+
increment = self.NO_RETRY_INCREMENT if release_amount == 0 else release_amount
378+
379+
if self._available_capacity == self._max_capacity:
380+
return
381+
382+
async with self._lock:
383+
self._available_capacity = min(
384+
self._available_capacity + increment, self._max_capacity
385+
)
386+
387+
388+
class RetryStrategyMode(Enum):
389+
"""Enumeration of available retry strategies."""
390+
391+
SIMPLE = "simple"
392+
STANDARD = "standard"
393+
394+
395+
RETRY_MODE_MAP = {
396+
RetryStrategyMode.SIMPLE: SimpleRetryStrategy,
397+
RetryStrategyMode.STANDARD: StandardRetryStrategy,
398+
}

0 commit comments

Comments
 (0)