Skip to content

Commit 39e8a10

Browse files
authored
retry on JSONDecoderError (#4)
1 parent 173c915 commit 39e8a10

File tree

2 files changed

+179
-3
lines changed

2 files changed

+179
-3
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations as _annotations
22

3+
import asyncio
34
import base64
5+
import json
46
from collections.abc import AsyncIterator, Awaitable
57
from contextlib import asynccontextmanager
68
from dataclasses import dataclass, field
@@ -119,6 +121,18 @@ class GoogleModelSettings(ModelSettings, total=False):
119121
See <https://ai.google.dev/api/generate-content#MediaResolution> for more information.
120122
"""
121123

124+
google_json_retry_max_attempts: int
125+
"""Maximum number of retry attempts for JSON decode errors during streaming.
126+
127+
Defaults to 3. Set to 0 to disable retries.
128+
"""
129+
130+
google_json_retry_base_delay: float
131+
"""Base delay in seconds for exponential backoff between JSON decode error retries.
132+
133+
Defaults to 1.0. The actual delay will be base_delay * (2 ** attempt_number).
134+
"""
135+
122136

123137
@dataclass(init=False)
124138
class GoogleModel(Model):
@@ -301,7 +315,26 @@ async def _generate_content(
301315
)
302316

303317
func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
304-
return await func(model=self._model_name, contents=contents, config=config) # type: ignore
318+
319+
# Get retry configuration with defaults
320+
max_attempts = model_settings.get('google_json_retry_max_attempts', 3)
321+
base_delay = model_settings.get('google_json_retry_base_delay', 1.0)
322+
323+
# Retry loop for JSON decode errors
324+
json_decode_error = None
325+
for attempt in range(max_attempts):
326+
try:
327+
return await func(model=self._model_name, contents=contents, config=config) # type: ignore
328+
except json.JSONDecodeError as error:
329+
json_decode_error = error
330+
delay = base_delay * (2**attempt)
331+
await asyncio.sleep(delay)
332+
333+
if json_decode_error is None:
334+
raise UnexpectedModelBehavior(
335+
'JSON retry loop completed without encountering any JSON decode errors. This should not happen if max_attempts > 0.'
336+
)
337+
raise json_decode_error
305338

306339
def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
307340
if not response.candidates or len(response.candidates) != 1:

tests/models/test_google.py

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations as _annotations
22

33
import datetime
4+
import json
45
import os
5-
from typing import Any
6+
from dataclasses import dataclass
7+
from typing import Any, cast
8+
from unittest.mock import AsyncMock, Mock, patch
69

710
import pytest
811
from httpx import Timeout
@@ -35,7 +38,8 @@
3538
UserPromptPart,
3639
VideoUrl,
3740
)
38-
from pydantic_ai.models.google import GoogleModel
41+
from pydantic_ai.models import ModelRequestParameters
42+
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
3943
from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
4044
from pydantic_ai.result import Usage
4145

@@ -1445,3 +1449,142 @@ async def test_google_model_function_call_without_text(google_provider: GooglePr
14451449
# Second part should be the added minimal text
14461450
assert 'text' in parts[1]
14471451
assert parts[1]['text'] == 'I have completed the function calls above.'
1452+
1453+
1454+
@dataclass
1455+
class JsonRetryTestCase:
1456+
"""Test case for JSON retry scenarios."""
1457+
1458+
id: str
1459+
side_effect: list[json.JSONDecodeError | Mock]
1460+
settings: GoogleModelSettings
1461+
expected_exception: type[Exception] | None
1462+
expected_call_count: int
1463+
expected_sleep_count: int
1464+
expected_sleep_calls: list[float]
1465+
expected_result: str | None
1466+
1467+
1468+
@pytest.mark.parametrize(
1469+
'test_case',
1470+
[
1471+
JsonRetryTestCase(
1472+
id='success_after_retry',
1473+
side_effect=[json.JSONDecodeError('test error', 'doc', 0), Mock()],
1474+
settings=cast(GoogleModelSettings, {}),
1475+
expected_exception=None,
1476+
expected_call_count=2,
1477+
expected_sleep_count=1,
1478+
expected_sleep_calls=[1.0],
1479+
expected_result='success',
1480+
),
1481+
JsonRetryTestCase(
1482+
id='max_attempts_reached',
1483+
side_effect=[json.JSONDecodeError('test error', 'doc', 0), json.JSONDecodeError('test error', 'doc', 0)],
1484+
settings=cast(GoogleModelSettings, {'google_json_retry_max_attempts': 2}),
1485+
expected_exception=json.JSONDecodeError,
1486+
expected_call_count=2,
1487+
expected_sleep_count=2,
1488+
expected_sleep_calls=[1.0, 2.0],
1489+
expected_result=None,
1490+
),
1491+
JsonRetryTestCase(
1492+
id='custom_settings',
1493+
side_effect=[json.JSONDecodeError('test error', 'doc', 0), Mock()],
1494+
settings=cast(
1495+
GoogleModelSettings, {'google_json_retry_max_attempts': 5, 'google_json_retry_base_delay': 0.5}
1496+
),
1497+
expected_exception=None,
1498+
expected_call_count=2,
1499+
expected_sleep_count=1,
1500+
expected_sleep_calls=[0.5],
1501+
expected_result='success',
1502+
),
1503+
JsonRetryTestCase(
1504+
id='retry_disabled',
1505+
side_effect=[json.JSONDecodeError('test error', 'doc', 0)],
1506+
settings=cast(GoogleModelSettings, {'google_json_retry_max_attempts': 0}),
1507+
expected_exception=UnexpectedModelBehavior,
1508+
expected_call_count=0,
1509+
expected_sleep_count=0,
1510+
expected_sleep_calls=[],
1511+
expected_result=None,
1512+
),
1513+
],
1514+
ids=lambda test_case: test_case.id,
1515+
)
1516+
async def test_google_model_json_retry_scenarios(
1517+
google_provider: GoogleProvider,
1518+
test_case: JsonRetryTestCase,
1519+
):
1520+
"""Test various JSON retry scenarios with parameterized inputs."""
1521+
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
1522+
1523+
# Create mock response for success cases
1524+
expected_result = test_case.expected_result
1525+
side_effect = test_case.side_effect
1526+
if expected_result == 'success':
1527+
mock_response = Mock()
1528+
if isinstance(side_effect, list):
1529+
side_effect[1] = mock_response # Replace placeholder with actual mock
1530+
expected_result = mock_response
1531+
1532+
mock_func = AsyncMock()
1533+
mock_func.side_effect = side_effect
1534+
1535+
with patch.object(model.client.aio.models, 'generate_content', mock_func):
1536+
with patch('asyncio.sleep') as mock_sleep:
1537+
if test_case.expected_exception:
1538+
with pytest.raises(test_case.expected_exception):
1539+
await model._generate_content( # pyright: ignore[reportPrivateUsage]
1540+
messages=[ModelRequest(parts=[UserPromptPart(content='test')])],
1541+
stream=False,
1542+
model_settings=test_case.settings,
1543+
model_request_parameters=ModelRequestParameters(
1544+
function_tools=[], output_tools=[], allow_text_output=True
1545+
),
1546+
)
1547+
else:
1548+
result = await model._generate_content( # pyright: ignore[reportPrivateUsage]
1549+
messages=[ModelRequest(parts=[UserPromptPart(content='test')])],
1550+
stream=False,
1551+
model_settings=test_case.settings,
1552+
model_request_parameters=ModelRequestParameters(
1553+
function_tools=[], output_tools=[], allow_text_output=True
1554+
),
1555+
)
1556+
assert result == expected_result
1557+
1558+
# Verify call counts
1559+
assert mock_func.call_count == test_case.expected_call_count
1560+
assert mock_sleep.call_count == test_case.expected_sleep_count
1561+
1562+
# Verify sleep calls
1563+
for expected_delay in test_case.expected_sleep_calls:
1564+
mock_sleep.assert_any_call(expected_delay)
1565+
1566+
1567+
async def test_google_model_non_json_error_not_retried(google_provider: GoogleProvider):
1568+
"""Test that non-JSON errors are not retried."""
1569+
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
1570+
1571+
# Mock the Google client to raise a different type of error
1572+
mock_func = AsyncMock()
1573+
mock_func.side_effect = ValueError('some other error')
1574+
1575+
with patch.object(model.client.aio.models, 'generate_content', mock_func):
1576+
with patch('asyncio.sleep') as mock_sleep:
1577+
with pytest.raises(ValueError):
1578+
await model._generate_content( # pyright: ignore[reportPrivateUsage]
1579+
messages=[ModelRequest(parts=[UserPromptPart(content='test')])],
1580+
stream=False,
1581+
model_settings={},
1582+
model_request_parameters=ModelRequestParameters(
1583+
function_tools=[], output_tools=[], allow_text_output=True
1584+
),
1585+
)
1586+
1587+
# Should have made only 1 call (no retries for non-JSON errors)
1588+
assert mock_func.call_count == 1
1589+
# Should not have slept at all
1590+
assert mock_sleep.call_count == 0

0 commit comments

Comments
 (0)