Skip to content

Commit a980aa0

Browse files
authored
Fix Vercel AI tool input/output always showing up as a JSON string rather than object (pydantic#3399)
1 parent 53962e5 commit a980aa0

File tree

4 files changed

+48
-49
lines changed

4 files changed

+48
-49
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,10 +776,11 @@ def model_response_str(self) -> str:
776776
def model_response_object(self) -> dict[str, Any]:
777777
"""Return a dictionary representation of the content, wrapping non-dict types appropriately."""
778778
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
779-
if isinstance(self.content, dict):
780-
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
779+
json_content = tool_return_ta.dump_python(self.content, mode='json')
780+
if isinstance(json_content, dict):
781+
return json_content # type: ignore[reportUnknownReturn]
781782
else:
782-
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
783+
return {'return_value': json_content}
783784

784785
def otel_event(self, settings: InstrumentationSettings) -> Event:
785786
return Event(

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic_core import to_json
1010

1111
from ...messages import (
12+
BaseToolReturnPart,
1213
BuiltinToolCallPart,
1314
BuiltinToolReturnPart,
1415
FilePart,
@@ -155,21 +156,23 @@ async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterato
155156
)
156157

157158
async def handle_tool_call_end(self, part: ToolCallPart) -> AsyncIterator[BaseChunk]:
158-
yield ToolInputAvailableChunk(tool_call_id=part.tool_call_id, tool_name=part.tool_name, input=part.args)
159+
yield ToolInputAvailableChunk(
160+
tool_call_id=part.tool_call_id, tool_name=part.tool_name, input=part.args_as_dict()
161+
)
159162

160163
async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseChunk]:
161164
yield ToolInputAvailableChunk(
162165
tool_call_id=part.tool_call_id,
163166
tool_name=part.tool_name,
164-
input=part.args,
167+
input=part.args_as_dict(),
165168
provider_executed=True,
166169
provider_metadata={'pydantic_ai': {'provider_name': part.provider_name}},
167170
)
168171

169172
async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> AsyncIterator[BaseChunk]:
170173
yield ToolOutputAvailableChunk(
171174
tool_call_id=part.tool_call_id,
172-
output=part.content,
175+
output=self._tool_return_output(part),
173176
provider_executed=True,
174177
)
175178

@@ -178,10 +181,15 @@ async def handle_file(self, part: FilePart) -> AsyncIterator[BaseChunk]:
178181
yield FileChunk(url=file.data_uri, media_type=file.media_type)
179182

180183
async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> AsyncIterator[BaseChunk]:
181-
result = event.result
182-
if isinstance(result, RetryPromptPart):
183-
yield ToolOutputErrorChunk(tool_call_id=result.tool_call_id, error_text=result.model_response())
184+
part = event.result
185+
if isinstance(part, RetryPromptPart):
186+
yield ToolOutputErrorChunk(tool_call_id=part.tool_call_id, error_text=part.model_response())
184187
else:
185-
yield ToolOutputAvailableChunk(tool_call_id=result.tool_call_id, output=result.content)
188+
yield ToolOutputAvailableChunk(tool_call_id=part.tool_call_id, output=self._tool_return_output(part))
186189

187190
# ToolCallResultEvent.content may hold user parts (e.g. text, images) that Vercel AI does not currently have events for
191+
192+
def _tool_return_output(self, part: BaseToolReturnPart) -> Any:
193+
output = part.model_response_object()
194+
# Unwrap the return value from the output dictionary if it exists
195+
return output.get('return_value', output)

tests/test_agent.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3724,13 +3724,11 @@ def test_tool_return_part_binary_content_serialization():
37243724

37253725
assert tool_return.model_response_object() == snapshot(
37263726
{
3727-
'return_value': {
3728-
'data': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzgAAAAASUVORK5CYII=',
3729-
'media_type': 'image/png',
3730-
'vendor_metadata': None,
3731-
'_identifier': None,
3732-
'kind': 'binary',
3733-
}
3727+
'data': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzgAAAAASUVORK5CYII=',
3728+
'media_type': 'image/png',
3729+
'vendor_metadata': None,
3730+
'_identifier': None,
3731+
'kind': 'binary',
37343732
}
37353733
)
37363734

tests/test_vercel_ai.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,14 +1167,12 @@ async def stream_function(
11671167
yield {
11681168
1: BuiltinToolReturnPart(
11691169
tool_name=WebSearchTool.kind,
1170-
content={
1171-
'results': [
1172-
{
1173-
'title': '"Hello, World!" program',
1174-
'url': 'https://en.wikipedia.org/wiki/%22Hello,_World!%22_program',
1175-
}
1176-
]
1177-
},
1170+
content=[
1171+
{
1172+
'title': '"Hello, World!" program',
1173+
'url': 'https://en.wikipedia.org/wiki/%22Hello,_World!%22_program',
1174+
}
1175+
],
11781176
tool_call_id='search_1',
11791177
provider_name='function',
11801178
)
@@ -1210,21 +1208,19 @@ async def stream_function(
12101208
'type': 'tool-input-available',
12111209
'toolCallId': 'search_1',
12121210
'toolName': 'web_search',
1213-
'input': '{"query":"Hello world"}',
1211+
'input': {'query': 'Hello world'},
12141212
'providerExecuted': True,
12151213
'providerMetadata': {'pydantic_ai': {'provider_name': 'function'}},
12161214
},
12171215
{
12181216
'type': 'tool-output-available',
12191217
'toolCallId': 'search_1',
1220-
'output': {
1221-
'results': [
1222-
{
1223-
'title': '"Hello, World!" program',
1224-
'url': 'https://en.wikipedia.org/wiki/%22Hello,_World!%22_program',
1225-
}
1226-
]
1227-
},
1218+
'output': [
1219+
{
1220+
'title': '"Hello, World!" program',
1221+
'url': 'https://en.wikipedia.org/wiki/%22Hello,_World!%22_program',
1222+
}
1223+
],
12281224
'providerExecuted': True,
12291225
},
12301226
{'type': 'text-start', 'id': IsStr()},
@@ -1302,7 +1298,7 @@ async def web_search(query: str) -> dict[str, list[dict[str, str]]]:
13021298
'type': 'tool-input-available',
13031299
'toolCallId': 'search_1',
13041300
'toolName': 'web_search',
1305-
'input': '{"query":"Hello world"}',
1301+
'input': {'query': 'Hello world'},
13061302
},
13071303
{
13081304
'type': 'tool-output-available',
@@ -1421,9 +1417,13 @@ def web_search(query: str) -> dict[str, list[dict[str, str]]]:
14211417
'type': 'tool-input-available',
14221418
'toolCallId': 'search_1',
14231419
'toolName': 'final_result',
1424-
'input': '{"query":"Hello world"}',
1420+
'input': {'query': 'Hello world'},
1421+
},
1422+
{
1423+
'type': 'tool-output-available',
1424+
'toolCallId': 'search_1',
1425+
'output': 'Final result processed.',
14251426
},
1426-
{'type': 'tool-output-available', 'toolCallId': 'search_1', 'output': 'Final result processed.'},
14271427
{'type': 'finish-step'},
14281428
{'type': 'finish'},
14291429
'[DONE]',
@@ -1468,11 +1468,7 @@ async def stream_function(
14681468
'toolCallId': IsStr(),
14691469
'toolName': 'unknown_tool',
14701470
},
1471-
{
1472-
'type': 'tool-input-available',
1473-
'toolCallId': IsStr(),
1474-
'toolName': 'unknown_tool',
1475-
},
1471+
{'type': 'tool-input-available', 'toolCallId': IsStr(), 'toolName': 'unknown_tool', 'input': {}},
14761472
{
14771473
'type': 'tool-output-error',
14781474
'toolCallId': IsStr(),
@@ -1489,11 +1485,7 @@ async def stream_function(
14891485
'toolCallId': IsStr(),
14901486
'toolName': 'unknown_tool',
14911487
},
1492-
{
1493-
'type': 'tool-input-available',
1494-
'toolCallId': IsStr(),
1495-
'toolName': 'unknown_tool',
1496-
},
1488+
{'type': 'tool-input-available', 'toolCallId': IsStr(), 'toolName': 'unknown_tool', 'input': {}},
14971489
{'type': 'error', 'errorText': 'Exceeded maximum retries (1) for output validation'},
14981490
{'type': 'finish-step'},
14991491
{'type': 'finish'},
@@ -1832,7 +1824,7 @@ async def test_adapter_load_messages():
18321824
UserPromptPart(
18331825
content=[
18341826
'Here are some files:',
1835-
BinaryImage(data=b'fake', media_type='image/png'),
1827+
BinaryImage(data=b'fake', media_type='image/png', _identifier='c053ec'),
18361828
ImageUrl(url='https://example.com/image.png', _media_type='image/png'),
18371829
VideoUrl(url='https://example.com/video.mp4', _media_type='video/mp4'),
18381830
AudioUrl(url='https://example.com/audio.mp3', _media_type='audio/mpeg'),
@@ -1846,7 +1838,7 @@ async def test_adapter_load_messages():
18461838
parts=[
18471839
ThinkingPart(content='I should tell the user how nice those files are and share another one'),
18481840
TextPart(content='Nice files, here is another one:'),
1849-
FilePart(content=BinaryImage(data=b'fake', media_type='image/png')),
1841+
FilePart(content=BinaryImage(data=b'fake', media_type='image/png', _identifier='c053ec')),
18501842
],
18511843
timestamp=IsDatetime(),
18521844
),

0 commit comments

Comments
 (0)