Skip to content

Commit fd8387c

Browse files
authored
Wrap BedrockConverseModel errors in ModelHTTPError to work well with FallbackModel (pydantic#3377)
1 parent b8d2904 commit fd8387c

File tree

4 files changed

+106
-4
lines changed

4 files changed

+106
-4
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import anyio
1313
import anyio.to_thread
14+
from botocore.exceptions import ClientError
1415
from typing_extensions import ParamSpec, assert_never
1516

1617
from pydantic_ai import (
@@ -21,6 +22,7 @@
2122
DocumentUrl,
2223
FinishReason,
2324
ImageUrl,
25+
ModelHTTPError,
2426
ModelMessage,
2527
ModelProfileSpec,
2628
ModelRequest,
@@ -408,10 +410,16 @@ async def _messages_create(
408410
if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
409411
params['promptVariables'] = prompt_variables
410412

411-
if stream:
412-
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
413-
else:
414-
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
413+
try:
414+
if stream:
415+
model_response = await anyio.to_thread.run_sync(
416+
functools.partial(self.client.converse_stream, **params)
417+
)
418+
else:
419+
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
420+
except ClientError as e:
421+
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
422+
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
415423
return model_response
416424

417425
@staticmethod
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
interactions:
2+
- request:
3+
body: '{"messages": [{"role": "user", "content": [{"text": "hello"}]}], "system": [], "inferenceConfig": {}}'
4+
headers:
5+
amz-sdk-invocation-id:
6+
- !!binary |
7+
ZWIzNTA0MWYtOTNhZi00YTFmLTk3YjEtMzE0MTFiNjA4ZjU5
8+
amz-sdk-request:
9+
- !!binary |
10+
YXR0ZW1wdD0x
11+
content-length:
12+
- '101'
13+
content-type:
14+
- !!binary |
15+
YXBwbGljYXRpb24vanNvbg==
16+
method: POST
17+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.does-not-exist-model-v1%3A0/converse
18+
response:
19+
headers:
20+
connection:
21+
- keep-alive
22+
content-length:
23+
- '55'
24+
content-type:
25+
- application/json
26+
parsed_body:
27+
message: The provided model identifier is invalid.
28+
status:
29+
code: 400
30+
message: Bad Request
31+
version: 1
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
interactions:
2+
- request:
3+
body: '{"messages": [{"role": "user", "content": [{"text": "hello"}]}], "system": [], "inferenceConfig": {}}'
4+
headers:
5+
amz-sdk-invocation-id:
6+
- !!binary |
7+
ZGQ5YWNhZjAtNjM4Mi00NjI5LTkwMWMtOGY4MWY1Yjc5OGYz
8+
amz-sdk-request:
9+
- !!binary |
10+
YXR0ZW1wdD0x
11+
content-length:
12+
- '101'
13+
content-type:
14+
- !!binary |
15+
YXBwbGljYXRpb24vanNvbg==
16+
method: POST
17+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.does-not-exist-model-v1%3A0/converse-stream
18+
response:
19+
headers:
20+
connection:
21+
- keep-alive
22+
content-length:
23+
- '55'
24+
content-type:
25+
- application/json
26+
parsed_body:
27+
message: The provided model identifier is invalid.
28+
status:
29+
code: 400
30+
message: Bad Request
31+
version: 1

tests/models/test_bedrock.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
FunctionToolCallEvent,
1515
FunctionToolResultEvent,
1616
ImageUrl,
17+
ModelHTTPError,
1718
ModelRequest,
1819
ModelResponse,
1920
PartDeltaEvent,
@@ -1376,3 +1377,34 @@ async def test_bedrock_model_stream_empty_text_delta(allow_model_requests: None,
13761377
PartEndEvent(index=1, part=TextPart(content='Hello! How can I help you today?')),
13771378
]
13781379
)
1380+
1381+
1382+
@pytest.mark.vcr()
1383+
async def test_bedrock_error(allow_model_requests: None, bedrock_provider: BedrockProvider):
1384+
"""Test that errors convert to ModelHTTPError."""
1385+
model_id = 'us.does-not-exist-model-v1:0'
1386+
model = BedrockConverseModel(model_id, provider=bedrock_provider)
1387+
agent = Agent(model)
1388+
1389+
with pytest.raises(ModelHTTPError) as exc_info:
1390+
await agent.run('hello')
1391+
1392+
assert exc_info.value.status_code == 400
1393+
assert exc_info.value.model_name == model_id
1394+
assert exc_info.value.body.get('Error', {}).get('Message') == 'The provided model identifier is invalid.' # type: ignore[union-attr]
1395+
1396+
1397+
@pytest.mark.vcr()
1398+
async def test_bedrock_streaming_error(allow_model_requests: None, bedrock_provider: BedrockProvider):
1399+
"""Test that errors during streaming convert to ModelHTTPError."""
1400+
model_id = 'us.does-not-exist-model-v1:0'
1401+
model = BedrockConverseModel(model_id, provider=bedrock_provider)
1402+
agent = Agent(model)
1403+
1404+
with pytest.raises(ModelHTTPError) as exc_info:
1405+
async with agent.run_stream('hello'):
1406+
pass
1407+
1408+
assert exc_info.value.status_code == 400
1409+
assert exc_info.value.model_name == model_id
1410+
assert exc_info.value.body.get('Error', {}).get('Message') == 'The provided model identifier is invalid.' # type: ignore[union-attr]

0 commit comments

Comments
 (0)