Skip to content

Commit 65f93f3

Browse files
committed
feat: Implement async generator tools
Enable decorated tools to be an async generator, enabling streaming of tool events back to to the caller.
1 parent 5ca7ac3 commit 65f93f3

File tree

3 files changed

+155
-19
lines changed

3 files changed

+155
-19
lines changed

src/strands/types/_events.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -275,25 +275,20 @@ def is_callback_event(self) -> bool:
275275
class ToolStreamEvent(TypedEvent):
276276
"""Event emitted when a tool yields sub-events as part of tool execution."""
277277

278-
def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None:
278+
def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None:
279279
"""Initialize with tool streaming data.
280280
281281
Args:
282282
tool_use: The tool invocation producing the stream
283-
tool_sub_event: The yielded event from the tool execution
283+
tool_stream_data: The yielded event from the tool execution
284284
"""
285-
super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event})
285+
super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_data": tool_stream_data})
286286

287287
@property
288288
def tool_use_id(self) -> str:
289289
"""The toolUseId associated with this stream."""
290290
return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId"))
291291

292-
@property
293-
@override
294-
def is_callback_event(self) -> bool:
295-
return False
296-
297292

298293
class ModelMessageEvent(TypedEvent):
299294
"""Event emitted when the model invocation has completed.

tests/strands/agent/hooks/test_agent_events.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -260,18 +260,18 @@ async def test_stream_e2e_success(alist):
260260
"role": "assistant",
261261
}
262262
},
263+
{
264+
"tool_stream_data": {"tool_streaming": True},
265+
"tool_stream_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"},
266+
},
267+
{
268+
"tool_stream_data": "Final result",
269+
"tool_stream_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"},
270+
},
263271
{
264272
"message": {
265273
"content": [
266-
{
267-
"toolResult": {
268-
# TODO update this text when we get tool streaming implemented; right now this
269-
# TODO is of the form '<async_generator object streaming_tool at 0x107d18a00>'
270-
"content": [{"text": ANY}],
271-
"status": "success",
272-
"toolUseId": "12345",
273-
}
274-
},
274+
{"toolResult": {"content": [{"text": "Final result"}], "status": "success", "toolUseId": "12345"}}
275275
],
276276
"role": "user",
277277
}

tests/strands/tools/test_decorator.py

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
"""
44

55
from asyncio import Queue
6-
from typing import Any, Dict, Optional, Union
6+
from typing import Any, AsyncGenerator, Dict, Optional, Union
77
from unittest.mock import MagicMock
88

99
import pytest
1010

1111
import strands
1212
from strands import Agent
13-
from strands.types._events import ToolResultEvent
13+
from strands.types._events import ToolResultEvent, ToolStreamEvent
1414
from strands.types.tools import AgentTool, ToolContext, ToolUse
1515

1616

@@ -1222,3 +1222,144 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str:
12221222
"toolUseId": "test-id-2",
12231223
}
12241224
)
1225+
1226+
1227+
@pytest.mark.asyncio
1228+
async def test_tool_async_generator():
1229+
"""Test that async generators yield results appropriately."""
1230+
1231+
@strands.tool(context=False)
1232+
async def async_generator() -> AsyncGenerator:
1233+
"""Tool that expects tool_context as a regular string parameter."""
1234+
yield 0
1235+
yield "Value 1"
1236+
yield {"nested": "value"}
1237+
yield {
1238+
"status": "success",
1239+
"content": [{"text": "Looks like tool result"}],
1240+
"toolUseId": "test-id-2",
1241+
}
1242+
yield "final result"
1243+
1244+
tool: AgentTool = async_generator
1245+
tool_use: ToolUse = {
1246+
"toolUseId": "test-id-2",
1247+
"name": "context_tool",
1248+
"input": {"message": "some_message", "tool_context": "my_custom_context_string"},
1249+
}
1250+
generator = tool.stream(
1251+
tool_use=tool_use,
1252+
invocation_state={
1253+
"agent": Agent(name="test_agent"),
1254+
},
1255+
)
1256+
act_results = [value async for value in generator]
1257+
exp_results = [
1258+
ToolStreamEvent(tool_use, 0),
1259+
ToolStreamEvent(tool_use, "Value 1"),
1260+
ToolStreamEvent(tool_use, {"nested": "value"}),
1261+
ToolStreamEvent(
1262+
tool_use,
1263+
{
1264+
"status": "success",
1265+
"content": [{"text": "Looks like tool result"}],
1266+
"toolUseId": "test-id-2",
1267+
},
1268+
),
1269+
ToolStreamEvent(tool_use, "final result"),
1270+
ToolResultEvent(
1271+
{
1272+
"status": "success",
1273+
"content": [{"text": "final result"}],
1274+
"toolUseId": "test-id-2",
1275+
}
1276+
),
1277+
]
1278+
1279+
assert act_results == exp_results
1280+
1281+
1282+
@pytest.mark.asyncio
1283+
async def test_tool_async_generator_exceptions_result_in_error():
1284+
"""Test that async generators handle exceptions."""
1285+
1286+
@strands.tool(context=False)
1287+
async def async_generator() -> AsyncGenerator:
1288+
"""Tool that expects tool_context as a regular string parameter."""
1289+
yield 13
1290+
raise ValueError("It's an error!")
1291+
1292+
tool: AgentTool = async_generator
1293+
tool_use: ToolUse = {
1294+
"toolUseId": "test-id-2",
1295+
"name": "context_tool",
1296+
"input": {"message": "some_message", "tool_context": "my_custom_context_string"},
1297+
}
1298+
generator = tool.stream(
1299+
tool_use=tool_use,
1300+
invocation_state={
1301+
"agent": Agent(name="test_agent"),
1302+
},
1303+
)
1304+
act_results = [value async for value in generator]
1305+
exp_results = [
1306+
ToolStreamEvent(tool_use, 13),
1307+
ToolResultEvent(
1308+
{
1309+
"status": "error",
1310+
"content": [{"text": "Error: It's an error!"}],
1311+
"toolUseId": "test-id-2",
1312+
}
1313+
),
1314+
]
1315+
1316+
assert act_results == exp_results
1317+
1318+
1319+
@pytest.mark.asyncio
1320+
async def test_tool_async_generator_yield_object_result():
1321+
"""Test that async generators handle exceptions."""
1322+
1323+
@strands.tool(context=False)
1324+
async def async_generator() -> AsyncGenerator:
1325+
"""Tool that expects tool_context as a regular string parameter."""
1326+
yield 13
1327+
yield {
1328+
"status": "success",
1329+
"content": [{"text": "final result"}],
1330+
"toolUseId": "test-id-2",
1331+
}
1332+
1333+
tool: AgentTool = async_generator
1334+
tool_use: ToolUse = {
1335+
"toolUseId": "test-id-2",
1336+
"name": "context_tool",
1337+
"input": {"message": "some_message", "tool_context": "my_custom_context_string"},
1338+
}
1339+
generator = tool.stream(
1340+
tool_use=tool_use,
1341+
invocation_state={
1342+
"agent": Agent(name="test_agent"),
1343+
},
1344+
)
1345+
act_results = [value async for value in generator]
1346+
exp_results = [
1347+
ToolStreamEvent(tool_use, 13),
1348+
ToolStreamEvent(
1349+
tool_use,
1350+
{
1351+
"status": "success",
1352+
"content": [{"text": "final result"}],
1353+
"toolUseId": "test-id-2",
1354+
},
1355+
),
1356+
ToolResultEvent(
1357+
{
1358+
"status": "success",
1359+
"content": [{"text": "final result"}],
1360+
"toolUseId": "test-id-2",
1361+
}
1362+
),
1363+
]
1364+
1365+
assert act_results == exp_results

0 commit comments

Comments
 (0)