|
3 | 3 | """
|
4 | 4 |
|
5 | 5 | from asyncio import Queue
|
6 |
| -from typing import Any, Dict, Optional, Union |
| 6 | +from typing import Any, AsyncGenerator, Dict, Optional, Union |
7 | 7 | from unittest.mock import MagicMock
|
8 | 8 |
|
9 | 9 | import pytest
|
10 | 10 |
|
11 | 11 | import strands
|
12 | 12 | from strands import Agent
|
13 |
| -from strands.types._events import ToolResultEvent |
| 13 | +from strands.types._events import ToolResultEvent, ToolStreamEvent |
14 | 14 | from strands.types.tools import AgentTool, ToolContext, ToolUse
|
15 | 15 |
|
16 | 16 |
|
@@ -1222,3 +1222,144 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str:
|
1222 | 1222 | "toolUseId": "test-id-2",
|
1223 | 1223 | }
|
1224 | 1224 | )
|
| 1225 | + |
| 1226 | + |
| 1227 | +@pytest.mark.asyncio |
| 1228 | +async def test_tool_async_generator(): |
| 1229 | + """Test that async generators yield results appropriately.""" |
| 1230 | + |
| 1231 | + @strands.tool(context=False) |
| 1232 | + async def async_generator() -> AsyncGenerator: |
| 1233 | + """Tool that expects tool_context as a regular string parameter.""" |
| 1234 | + yield 0 |
| 1235 | + yield "Value 1" |
| 1236 | + yield {"nested": "value"} |
| 1237 | + yield { |
| 1238 | + "status": "success", |
| 1239 | + "content": [{"text": "Looks like tool result"}], |
| 1240 | + "toolUseId": "test-id-2", |
| 1241 | + } |
| 1242 | + yield "final result" |
| 1243 | + |
| 1244 | + tool: AgentTool = async_generator |
| 1245 | + tool_use: ToolUse = { |
| 1246 | + "toolUseId": "test-id-2", |
| 1247 | + "name": "context_tool", |
| 1248 | + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, |
| 1249 | + } |
| 1250 | + generator = tool.stream( |
| 1251 | + tool_use=tool_use, |
| 1252 | + invocation_state={ |
| 1253 | + "agent": Agent(name="test_agent"), |
| 1254 | + }, |
| 1255 | + ) |
| 1256 | + act_results = [value async for value in generator] |
| 1257 | + exp_results = [ |
| 1258 | + ToolStreamEvent(tool_use, 0), |
| 1259 | + ToolStreamEvent(tool_use, "Value 1"), |
| 1260 | + ToolStreamEvent(tool_use, {"nested": "value"}), |
| 1261 | + ToolStreamEvent( |
| 1262 | + tool_use, |
| 1263 | + { |
| 1264 | + "status": "success", |
| 1265 | + "content": [{"text": "Looks like tool result"}], |
| 1266 | + "toolUseId": "test-id-2", |
| 1267 | + }, |
| 1268 | + ), |
| 1269 | + ToolStreamEvent(tool_use, "final result"), |
| 1270 | + ToolResultEvent( |
| 1271 | + { |
| 1272 | + "status": "success", |
| 1273 | + "content": [{"text": "final result"}], |
| 1274 | + "toolUseId": "test-id-2", |
| 1275 | + } |
| 1276 | + ), |
| 1277 | + ] |
| 1278 | + |
| 1279 | + assert act_results == exp_results |
| 1280 | + |
| 1281 | + |
| 1282 | +@pytest.mark.asyncio |
| 1283 | +async def test_tool_async_generator_exceptions_result_in_error(): |
| 1284 | + """Test that async generators handle exceptions.""" |
| 1285 | + |
| 1286 | + @strands.tool(context=False) |
| 1287 | + async def async_generator() -> AsyncGenerator: |
| 1288 | + """Tool that expects tool_context as a regular string parameter.""" |
| 1289 | + yield 13 |
| 1290 | + raise ValueError("It's an error!") |
| 1291 | + |
| 1292 | + tool: AgentTool = async_generator |
| 1293 | + tool_use: ToolUse = { |
| 1294 | + "toolUseId": "test-id-2", |
| 1295 | + "name": "context_tool", |
| 1296 | + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, |
| 1297 | + } |
| 1298 | + generator = tool.stream( |
| 1299 | + tool_use=tool_use, |
| 1300 | + invocation_state={ |
| 1301 | + "agent": Agent(name="test_agent"), |
| 1302 | + }, |
| 1303 | + ) |
| 1304 | + act_results = [value async for value in generator] |
| 1305 | + exp_results = [ |
| 1306 | + ToolStreamEvent(tool_use, 13), |
| 1307 | + ToolResultEvent( |
| 1308 | + { |
| 1309 | + "status": "error", |
| 1310 | + "content": [{"text": "Error: It's an error!"}], |
| 1311 | + "toolUseId": "test-id-2", |
| 1312 | + } |
| 1313 | + ), |
| 1314 | + ] |
| 1315 | + |
| 1316 | + assert act_results == exp_results |
| 1317 | + |
| 1318 | + |
| 1319 | +@pytest.mark.asyncio |
| 1320 | +async def test_tool_async_generator_yield_object_result(): |
| 1321 | + """Test that async generators handle exceptions.""" |
| 1322 | + |
| 1323 | + @strands.tool(context=False) |
| 1324 | + async def async_generator() -> AsyncGenerator: |
| 1325 | + """Tool that expects tool_context as a regular string parameter.""" |
| 1326 | + yield 13 |
| 1327 | + yield { |
| 1328 | + "status": "success", |
| 1329 | + "content": [{"text": "final result"}], |
| 1330 | + "toolUseId": "test-id-2", |
| 1331 | + } |
| 1332 | + |
| 1333 | + tool: AgentTool = async_generator |
| 1334 | + tool_use: ToolUse = { |
| 1335 | + "toolUseId": "test-id-2", |
| 1336 | + "name": "context_tool", |
| 1337 | + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, |
| 1338 | + } |
| 1339 | + generator = tool.stream( |
| 1340 | + tool_use=tool_use, |
| 1341 | + invocation_state={ |
| 1342 | + "agent": Agent(name="test_agent"), |
| 1343 | + }, |
| 1344 | + ) |
| 1345 | + act_results = [value async for value in generator] |
| 1346 | + exp_results = [ |
| 1347 | + ToolStreamEvent(tool_use, 13), |
| 1348 | + ToolStreamEvent( |
| 1349 | + tool_use, |
| 1350 | + { |
| 1351 | + "status": "success", |
| 1352 | + "content": [{"text": "final result"}], |
| 1353 | + "toolUseId": "test-id-2", |
| 1354 | + }, |
| 1355 | + ), |
| 1356 | + ToolResultEvent( |
| 1357 | + { |
| 1358 | + "status": "success", |
| 1359 | + "content": [{"text": "final result"}], |
| 1360 | + "toolUseId": "test-id-2", |
| 1361 | + } |
| 1362 | + ), |
| 1363 | + ] |
| 1364 | + |
| 1365 | + assert act_results == exp_results |
0 commit comments