|
1 | 1 | from __future__ import annotations as _annotations
|
2 | 2 |
|
3 | 3 | import datetime
|
| 4 | +import json |
4 | 5 | import os
|
5 |
| -from typing import Any |
| 6 | +from dataclasses import dataclass |
| 7 | +from typing import Any, cast |
| 8 | +from unittest.mock import AsyncMock, Mock, patch |
6 | 9 |
|
7 | 10 | import pytest
|
8 | 11 | from httpx import Timeout
|
|
35 | 38 | UserPromptPart,
|
36 | 39 | VideoUrl,
|
37 | 40 | )
|
38 |
| -from pydantic_ai.models.google import GoogleModel |
| 41 | +from pydantic_ai.models import ModelRequestParameters |
| 42 | +from pydantic_ai.models.google import GoogleModel, GoogleModelSettings |
39 | 43 | from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
|
40 | 44 | from pydantic_ai.result import Usage
|
41 | 45 |
|
@@ -1445,3 +1449,142 @@ async def test_google_model_function_call_without_text(google_provider: GooglePr
|
1445 | 1449 | # Second part should be the added minimal text
|
1446 | 1450 | assert 'text' in parts[1]
|
1447 | 1451 | assert parts[1]['text'] == 'I have completed the function calls above.'
|
| 1452 | + |
| 1453 | + |
| 1454 | +@dataclass |
| 1455 | +class JsonRetryTestCase: |
| 1456 | + """Test case for JSON retry scenarios.""" |
| 1457 | + |
| 1458 | + id: str |
| 1459 | + side_effect: list[json.JSONDecodeError | Mock] |
| 1460 | + settings: GoogleModelSettings |
| 1461 | + expected_exception: type[Exception] | None |
| 1462 | + expected_call_count: int |
| 1463 | + expected_sleep_count: int |
| 1464 | + expected_sleep_calls: list[float] |
| 1465 | + expected_result: str | None |
| 1466 | + |
| 1467 | + |
| 1468 | +@pytest.mark.parametrize( |
| 1469 | + 'test_case', |
| 1470 | + [ |
| 1471 | + JsonRetryTestCase( |
| 1472 | + id='success_after_retry', |
| 1473 | + side_effect=[json.JSONDecodeError('test error', 'doc', 0), Mock()], |
| 1474 | + settings=cast(GoogleModelSettings, {}), |
| 1475 | + expected_exception=None, |
| 1476 | + expected_call_count=2, |
| 1477 | + expected_sleep_count=1, |
| 1478 | + expected_sleep_calls=[1.0], |
| 1479 | + expected_result='success', |
| 1480 | + ), |
| 1481 | + JsonRetryTestCase( |
| 1482 | + id='max_attempts_reached', |
| 1483 | + side_effect=[json.JSONDecodeError('test error', 'doc', 0), json.JSONDecodeError('test error', 'doc', 0)], |
| 1484 | + settings=cast(GoogleModelSettings, {'google_json_retry_max_attempts': 2}), |
| 1485 | + expected_exception=json.JSONDecodeError, |
| 1486 | + expected_call_count=2, |
| 1487 | + expected_sleep_count=2, |
| 1488 | + expected_sleep_calls=[1.0, 2.0], |
| 1489 | + expected_result=None, |
| 1490 | + ), |
| 1491 | + JsonRetryTestCase( |
| 1492 | + id='custom_settings', |
| 1493 | + side_effect=[json.JSONDecodeError('test error', 'doc', 0), Mock()], |
| 1494 | + settings=cast( |
| 1495 | + GoogleModelSettings, {'google_json_retry_max_attempts': 5, 'google_json_retry_base_delay': 0.5} |
| 1496 | + ), |
| 1497 | + expected_exception=None, |
| 1498 | + expected_call_count=2, |
| 1499 | + expected_sleep_count=1, |
| 1500 | + expected_sleep_calls=[0.5], |
| 1501 | + expected_result='success', |
| 1502 | + ), |
| 1503 | + JsonRetryTestCase( |
| 1504 | + id='retry_disabled', |
| 1505 | + side_effect=[json.JSONDecodeError('test error', 'doc', 0)], |
| 1506 | + settings=cast(GoogleModelSettings, {'google_json_retry_max_attempts': 0}), |
| 1507 | + expected_exception=UnexpectedModelBehavior, |
| 1508 | + expected_call_count=0, |
| 1509 | + expected_sleep_count=0, |
| 1510 | + expected_sleep_calls=[], |
| 1511 | + expected_result=None, |
| 1512 | + ), |
| 1513 | + ], |
| 1514 | + ids=lambda test_case: test_case.id, |
| 1515 | +) |
| 1516 | +async def test_google_model_json_retry_scenarios( |
| 1517 | + google_provider: GoogleProvider, |
| 1518 | + test_case: JsonRetryTestCase, |
| 1519 | +): |
| 1520 | + """Test various JSON retry scenarios with parameterized inputs.""" |
| 1521 | + model = GoogleModel('gemini-1.5-flash', provider=google_provider) |
| 1522 | + |
| 1523 | + # Create mock response for success cases |
| 1524 | + expected_result = test_case.expected_result |
| 1525 | + side_effect = test_case.side_effect |
| 1526 | + if expected_result == 'success': |
| 1527 | + mock_response = Mock() |
| 1528 | + if isinstance(side_effect, list): |
| 1529 | + side_effect[1] = mock_response # Replace placeholder with actual mock |
| 1530 | + expected_result = mock_response |
| 1531 | + |
| 1532 | + mock_func = AsyncMock() |
| 1533 | + mock_func.side_effect = side_effect |
| 1534 | + |
| 1535 | + with patch.object(model.client.aio.models, 'generate_content', mock_func): |
| 1536 | + with patch('asyncio.sleep') as mock_sleep: |
| 1537 | + if test_case.expected_exception: |
| 1538 | + with pytest.raises(test_case.expected_exception): |
| 1539 | + await model._generate_content( # pyright: ignore[reportPrivateUsage] |
| 1540 | + messages=[ModelRequest(parts=[UserPromptPart(content='test')])], |
| 1541 | + stream=False, |
| 1542 | + model_settings=test_case.settings, |
| 1543 | + model_request_parameters=ModelRequestParameters( |
| 1544 | + function_tools=[], output_tools=[], allow_text_output=True |
| 1545 | + ), |
| 1546 | + ) |
| 1547 | + else: |
| 1548 | + result = await model._generate_content( # pyright: ignore[reportPrivateUsage] |
| 1549 | + messages=[ModelRequest(parts=[UserPromptPart(content='test')])], |
| 1550 | + stream=False, |
| 1551 | + model_settings=test_case.settings, |
| 1552 | + model_request_parameters=ModelRequestParameters( |
| 1553 | + function_tools=[], output_tools=[], allow_text_output=True |
| 1554 | + ), |
| 1555 | + ) |
| 1556 | + assert result == expected_result |
| 1557 | + |
| 1558 | + # Verify call counts |
| 1559 | + assert mock_func.call_count == test_case.expected_call_count |
| 1560 | + assert mock_sleep.call_count == test_case.expected_sleep_count |
| 1561 | + |
| 1562 | + # Verify sleep calls |
| 1563 | + for expected_delay in test_case.expected_sleep_calls: |
| 1564 | + mock_sleep.assert_any_call(expected_delay) |
| 1565 | + |
| 1566 | + |
| 1567 | +async def test_google_model_non_json_error_not_retried(google_provider: GoogleProvider): |
| 1568 | + """Test that non-JSON errors are not retried.""" |
| 1569 | + model = GoogleModel('gemini-1.5-flash', provider=google_provider) |
| 1570 | + |
| 1571 | + # Mock the Google client to raise a different type of error |
| 1572 | + mock_func = AsyncMock() |
| 1573 | + mock_func.side_effect = ValueError('some other error') |
| 1574 | + |
| 1575 | + with patch.object(model.client.aio.models, 'generate_content', mock_func): |
| 1576 | + with patch('asyncio.sleep') as mock_sleep: |
| 1577 | + with pytest.raises(ValueError): |
| 1578 | + await model._generate_content( # pyright: ignore[reportPrivateUsage] |
| 1579 | + messages=[ModelRequest(parts=[UserPromptPart(content='test')])], |
| 1580 | + stream=False, |
| 1581 | + model_settings={}, |
| 1582 | + model_request_parameters=ModelRequestParameters( |
| 1583 | + function_tools=[], output_tools=[], allow_text_output=True |
| 1584 | + ), |
| 1585 | + ) |
| 1586 | + |
| 1587 | + # Should have made only 1 call (no retries for non-JSON errors) |
| 1588 | + assert mock_func.call_count == 1 |
| 1589 | + # Should not have slept at all |
| 1590 | + assert mock_sleep.call_count == 0 |
0 commit comments