diff --git a/tests/test_fetcher_ng.py b/tests/test_fetcher_ng.py index d04b09f427..c5f842293a 100644 --- a/tests/test_fetcher_ng.py +++ b/tests/test_fetcher_ng.py @@ -137,6 +137,108 @@ def test_session_get_timeout(self, mock_session_get: Mock) -> None: self.fetcher.fetch(self.url) mock_session_get.assert_called_once() + # Test retry on ReadTimeoutError during streaming + @patch.object(urllib3.PoolManager, "request") + def test_download_bytes_retry_on_streaming_timeout( + self, mock_request: Mock + ) -> None: + """Test that download_bytes retries when ReadTimeoutError occurs during streaming.""" + mock_response_fail = Mock() + mock_response_fail.status = 200 + mock_response_fail.stream.side_effect = ( + urllib3.exceptions.ReadTimeoutError( + urllib3.connectionpool.ConnectionPool("localhost"), + "", + "Read timed out", + ) + ) + + mock_response_success = Mock() + mock_response_success.status = 200 + mock_response_success.stream.return_value = iter( + [self.file_contents[:4], self.file_contents[4:]] + ) + + mock_request.side_effect = [ + mock_response_fail, + mock_response_fail, + mock_response_success, + ] + + data = self.fetcher.download_bytes(self.url, self.file_length) + self.assertEqual(self.file_contents, data) + self.assertEqual(mock_request.call_count, 3) + + # Test retry exhaustion + @patch.object(urllib3.PoolManager, "request") + def test_download_bytes_retry_exhaustion(self, mock_request: Mock) -> None: + """Test that download_bytes fails after exhausting all retries.""" + # All attempts fail + mock_response = Mock() + mock_response.status = 200 + mock_response.stream.side_effect = urllib3.exceptions.ReadTimeoutError( + urllib3.connectionpool.ConnectionPool("localhost"), + "", + "Read timed out", + ) + mock_request.return_value = mock_response + + with self.assertRaises(exceptions.SlowRetrievalError): + self.fetcher.download_bytes(self.url, self.file_length) + # Should have been called 3 times (max_retries=3) + self.assertEqual(mock_request.call_count, 3) + 
+ # Test retry on ProtocolError during streaming + @patch.object(urllib3.PoolManager, "request") + def test_download_bytes_retry_on_protocol_error( + self, mock_request: Mock + ) -> None: + """Test that download_bytes retries when ProtocolError occurs during streaming.""" + # First attempt fails with protocol error, second succeeds + mock_response_fail = Mock() + mock_response_fail.status = 200 + mock_response_fail.stream.side_effect = ( + urllib3.exceptions.ProtocolError("Connection broken") + ) + + mock_response_success = Mock() + mock_response_success.status = 200 + mock_response_success.stream.return_value = iter( + [self.file_contents[:4], self.file_contents[4:]] + ) + + mock_request.side_effect = [ + mock_response_fail, + mock_response_success, + ] + + data = self.fetcher.download_bytes(self.url, self.file_length) + self.assertEqual(self.file_contents, data) + self.assertEqual(mock_request.call_count, 2) + + # Test that non-timeout errors are not retried + @patch.object(urllib3.PoolManager, "request") + def test_download_bytes_no_retry_on_http_error( + self, mock_request: Mock + ) -> None: + """Test that download_bytes does not retry on HTTP errors like 404.""" + mock_response = Mock() + mock_response.status = 404 + mock_request.return_value = mock_response + + with self.assertRaises(exceptions.DownloadHTTPError): + self.fetcher.download_bytes(self.url, self.file_length) + # Should only be called once, no retries + mock_request.assert_called_once() + + # Test that length mismatch errors are not retried + def test_download_bytes_no_retry_on_length_mismatch(self) -> None: + """Test that download_bytes does not retry on length mismatch errors.""" + # Request a max_length smaller than the actual file size + with self.assertRaises(exceptions.DownloadLengthMismatchError): + # File is self.file_length bytes, asking for less should fail + self.fetcher.download_bytes(self.url, self.file_length - 4) + # Simple bytes download def test_download_bytes(self) -> None: data = 
self.fetcher.download_bytes(self.url, self.file_length) diff --git a/tuf/ngclient/urllib3_fetcher.py b/tuf/ngclient/urllib3_fetcher.py index 88d447bd30..3641be8e2c 100644 --- a/tuf/ngclient/urllib3_fetcher.py +++ b/tuf/ngclient/urllib3_fetcher.py @@ -12,6 +12,7 @@ # Imports import urllib3 +from urllib3.util.retry import Retry import tuf from tuf.api import exceptions @@ -50,7 +51,21 @@ def __init__( if app_user_agent is not None: ua = f"{app_user_agent} {ua}" - self._proxy_env = ProxyEnvironment(headers={"User-Agent": ua}) + # Configure retry strategy for connection-level retries. + # Note: This only retries at the HTTP request level (before streaming + # begins). Streaming failures are handled by the retry loop in + # download_bytes(). + retry_strategy = Retry( + total=3, + read=3, + connect=3, + status_forcelist=[500, 502, 503, 504], + raise_on_status=False, + ) + + self._proxy_env = ProxyEnvironment( + headers={"User-Agent": ua}, retries=retry_strategy + ) def _fetch(self, url: str) -> Iterator[bytes]: """Fetch the contents of HTTP/HTTPS url from a remote server. @@ -82,6 +97,7 @@ def _fetch(self, url: str) -> Iterator[bytes]: except urllib3.exceptions.MaxRetryError as e: if isinstance(e.reason, urllib3.exceptions.TimeoutError): raise exceptions.SlowRetrievalError from e + raise if response.status >= 400: response.close() @@ -106,6 +122,59 @@ def _chunks( except urllib3.exceptions.MaxRetryError as e: if isinstance(e.reason, urllib3.exceptions.TimeoutError): raise exceptions.SlowRetrievalError from e + raise + except ( + urllib3.exceptions.ReadTimeoutError, + urllib3.exceptions.ProtocolError, + ) as e: + raise exceptions.SlowRetrievalError from e finally: response.release_conn() + + def download_bytes(self, url: str, max_length: int) -> bytes: + """Download bytes from given ``url`` with retry on streaming failures. + + This override adds retry logic for mid-stream timeout and connection + errors that are not automatically retried by urllib3. 
+ + Args: + url: URL string that represents the location of the file. + max_length: Upper bound of data size in bytes. + + Raises: + exceptions.DownloadError: An error occurred during download. + exceptions.DownloadLengthMismatchError: Downloaded bytes exceed + ``max_length``. + exceptions.DownloadHTTPError: An HTTP error code was received. + + Returns: + Content of the file in bytes. + """ + max_retries = 3 + last_exception: Exception | None = None + + for attempt in range(max_retries): + try: + return super().download_bytes(url, max_length) + except exceptions.SlowRetrievalError as e: + last_exception = e + if attempt < max_retries - 1: + logger.debug( + "Retrying download after streaming error " + "(attempt %d/%d): %s", + attempt + 1, + max_retries, + url, + ) + continue + raise + except ( + exceptions.DownloadHTTPError, + exceptions.DownloadLengthMismatchError, + ): + raise + + if last_exception: + raise last_exception + raise exceptions.DownloadError(f"Failed to download {url}")