Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions httpx_retries/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Retry:
HTTPStatus.GATEWAY_TIMEOUT,
]
)
RETRYABLE_EXCEPTIONS: Final[Tuple[Type[httpx.HTTPError], ...]] = (
RETRYABLE_EXCEPTIONS: Final[Tuple[Type[Exception], ...]] = (
httpx.TimeoutException,
httpx.NetworkError,
httpx.RemoteProtocolError,
Expand All @@ -80,7 +80,7 @@ def __init__(
total: int = 10,
allowed_methods: Optional[Iterable[Union[HTTPMethod, str]]] = None,
status_forcelist: Optional[Iterable[Union[HTTPStatus, int]]] = None,
retry_on_exceptions: Optional[Iterable[Type[httpx.HTTPError]]] = None,
retry_on_exceptions: Optional[Iterable[Type[Exception]]] = None,
backoff_factor: float = 0.0,
respect_retry_after_header: bool = True,
max_backoff_wait: float = 120.0,
Expand Down Expand Up @@ -122,7 +122,7 @@ def is_retryable_status_code(self, status_code: int) -> bool:
"""Check if a status code is retryable."""
return status_code in self.status_forcelist

def is_retryable_exception(self, exception: httpx.HTTPError) -> bool:
def is_retryable_exception(self, exception: Exception) -> bool:
"""Check if an exception is retryable."""
return isinstance(exception, self.retryable_exceptions)

Expand Down Expand Up @@ -216,7 +216,7 @@ def _calculate_sleep(self, headers: Union[httpx.Headers, Mapping[str, str]]) ->
# Fall back to backoff strategy
return self.backoff_strategy() if self.attempts_made > 0 else 0.0

def sleep(self, response: Union[httpx.Response, httpx.HTTPError]) -> None:
def sleep(self, response: Union[httpx.Response, Exception]) -> None:
"""
Sleep between retry attempts using the calculated duration.

Expand All @@ -228,7 +228,7 @@ def sleep(self, response: Union[httpx.Response, httpx.HTTPError]) -> None:
logger.debug("sleep seconds=%s", time_to_sleep)
time.sleep(time_to_sleep)

async def asleep(self, response: Union[httpx.Response, httpx.HTTPError]) -> None:
async def asleep(self, response: Union[httpx.Response, Exception]) -> None:
"""
Sleep between retry attempts asynchronously using the calculated duration.

Expand Down
8 changes: 4 additions & 4 deletions httpx_retries/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _retry_operation(
send_method: Callable[..., httpx.Response],
) -> httpx.Response:
retry = self.retry
response: Union[httpx.Response, httpx.HTTPError, None] = None
response: Union[httpx.Response, Exception, None] = None

while True:
if response is not None:
Expand All @@ -122,7 +122,7 @@ def _retry_operation(
retry.sleep(response)
try:
response = send_method(request)
except httpx.HTTPError as e:
except Exception as e:
if retry.is_exhausted() or not retry.is_retryable_exception(e):
raise

Expand All @@ -138,7 +138,7 @@ async def _retry_operation_async(
send_method: Callable[..., Coroutine[Any, Any, httpx.Response]],
) -> httpx.Response:
retry = self.retry
response: Union[httpx.Response, httpx.HTTPError, None] = None
response: Union[httpx.Response, Exception, None] = None

while True:
if response is not None:
Expand All @@ -149,7 +149,7 @@ async def _retry_operation_async(
await retry.asleep(response)
try:
response = await send_method(request)
except httpx.HTTPError as e:
except Exception as e:
if retry.is_exhausted() or not retry.is_retryable_exception(e):
raise

Expand Down
9 changes: 9 additions & 0 deletions tests/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def test_is_retryable_exception() -> None:
assert retry.is_retryable_exception(httpx.LocalProtocolError("")) is False


def test_is_retryable_exception_custom_exception() -> None:
class MyExc(Exception):
pass

retry = Retry(retry_on_exceptions=(MyExc,))
assert retry.is_retryable_exception(httpx.NetworkError("")) is False
assert retry.is_retryable_exception(MyExc()) is True


def test_custom_retryable_methods_str() -> None:
retry = Retry(allowed_methods=["POST"])
assert retry.is_retryable_method("POST") is True
Expand Down
29 changes: 27 additions & 2 deletions tests/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,39 @@ def test_retryable_exception(mock_responses: MockResponse) -> None:
assert mock_sleep.call_count == 10


def test_retryable_exception_custom_exception(mock_responses: MockResponse) -> None:
mock_sleep, _ = mock_responses
transport = RetryTransport(retry=Retry(retry_on_exceptions=[ValueError]))

with patch("httpx.HTTPTransport.handle_request", side_effect=ValueError("oops")):
with httpx.Client(transport=transport) as client:
with pytest.raises(ValueError, match="oops"):
client.get("https://example.com")

assert mock_sleep.call_count == 10


@pytest.mark.asyncio
async def test_async_retryable_exception(mock_async_responses: AsyncMockResponse) -> None:
mock_asleep, _ = mock_async_responses
transport = RetryTransport()

with patch("httpx.AsyncHTTPTransport.handle_async_request", side_effect=httpx.ReadTimeout("Timeout!")):
with patch("httpx.AsyncHTTPTransport.handle_async_request", side_effect=httpx.ReadTimeout("oops")):
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(httpx.ReadTimeout, match="Timeout!"):
with pytest.raises(httpx.ReadTimeout, match="oops"):
await client.get("https://example.com")

assert mock_asleep.call_count == 10


@pytest.mark.asyncio
async def test_async_retryable_exception_custom_exception(mock_async_responses: AsyncMockResponse) -> None:
mock_asleep, _ = mock_async_responses
transport = RetryTransport(retry=Retry(retry_on_exceptions=[ValueError]))

with patch("httpx.AsyncHTTPTransport.handle_async_request", side_effect=ValueError("Timeout!")):
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(ValueError, match="Timeout!"):
await client.get("https://example.com")

assert mock_asleep.call_count == 10
Expand Down
Loading