Skip to content

Commit 5ca7ac3

Browse files
committed
fix: don't emit ToolStream events for non generator functions
Our current implementation of AgentTool.stream() has a problem that we don't differentiate between intermediate streaming events and the final ToolResult events. Our only contract is that the last event *must be* be the tool result that is passed to the LLM. Our switch to Typed Events (#755) pushes us in the right direction but for backwards compatibility we can't update the signature of `AgentTool.stream()` (nor have we exposed externally TypedEvents yet). That means that if we implemented tool-streaming today, then callers would see non-generator functions yielding both a `ToolStreamEvent` and `ToolResultEvent` even though they're not actually streaming responses. To avoid the odd behavior noted above, we'll special-case SDK-defined functions by allowing them to emit `ToolStreamEvent` and `ToolResultEvent` types directly (bypassing our normal wrapping), since they have the knowledge of when tools are actually generators or not. There's no observable difference in behavior to callers (this is all internal behavior), but this means that when we switch the flip for Tool Streaming, non-generator tools will **not** emit ToolStreamEvents - at least for AgentTool implementations that are in the SDK.
1 parent cb4b7fb commit 5ca7ac3

File tree

11 files changed

+271
-142
lines changed

11 files changed

+271
-142
lines changed

src/strands/tools/decorator.py

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
5353
Type,
5454
TypeVar,
5555
Union,
56+
cast,
5657
get_type_hints,
5758
overload,
5859
)
@@ -61,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6162
from pydantic import BaseModel, Field, create_model
6263
from typing_extensions import override
6364

64-
from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse
65+
from ..types._events import ToolResultEvent, ToolStreamEvent
66+
from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse
6567

6668
logger = logging.getLogger(__name__)
6769

@@ -454,43 +456,67 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
454456
# Inject special framework-provided parameters
455457
self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state)
456458

457-
# "Too few arguments" expected, hence the type ignore
458-
if inspect.iscoroutinefunction(self._tool_func):
459+
# Note: "Too few arguments" expected for the _tool_func calls, hence the type ignore
460+
461+
# Async-generators, yield streaming events and final tool result
462+
if inspect.isasyncgenfunction(self._tool_func):
463+
sub_events = self._tool_func(**validated_input) # type: ignore
464+
async for sub_event in sub_events:
465+
yield ToolStreamEvent(tool_use, sub_event)
466+
467+
# The last event is the result
468+
yield self._wrap_tool_result(tool_use_id, sub_event)
469+
470+
# Async functions, yield only the result
471+
elif inspect.iscoroutinefunction(self._tool_func):
459472
result = await self._tool_func(**validated_input) # type: ignore
460-
else:
461-
result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore
473+
yield self._wrap_tool_result(tool_use_id, result)
462474

463-
# FORMAT THE RESULT for Strands Agent
464-
if isinstance(result, dict) and "status" in result and "content" in result:
465-
# Result is already in the expected format, just add toolUseId
466-
result["toolUseId"] = tool_use_id
467-
yield result
475+
# Other functions, yield only the result
468476
else:
469-
# Wrap any other return value in the standard format
470-
# Always include at least one content item for consistency
471-
yield {
472-
"toolUseId": tool_use_id,
473-
"status": "success",
474-
"content": [{"text": str(result)}],
475-
}
477+
result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore
478+
yield self._wrap_tool_result(tool_use_id, result)
476479

477480
except ValueError as e:
478481
# Special handling for validation errors
479482
error_msg = str(e)
480-
yield {
481-
"toolUseId": tool_use_id,
482-
"status": "error",
483-
"content": [{"text": f"Error: {error_msg}"}],
484-
}
483+
yield self._wrap_tool_result(
484+
tool_use_id,
485+
{
486+
"toolUseId": tool_use_id,
487+
"status": "error",
488+
"content": [{"text": f"Error: {error_msg}"}],
489+
},
490+
)
485491
except Exception as e:
486492
# Return error result with exception details for any other error
487493
error_type = type(e).__name__
488494
error_msg = str(e)
489-
yield {
490-
"toolUseId": tool_use_id,
491-
"status": "error",
492-
"content": [{"text": f"Error: {error_type} - {error_msg}"}],
493-
}
495+
yield self._wrap_tool_result(
496+
tool_use_id,
497+
{
498+
"toolUseId": tool_use_id,
499+
"status": "error",
500+
"content": [{"text": f"Error: {error_type} - {error_msg}"}],
501+
},
502+
)
503+
504+
def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent:
505+
# FORMAT THE RESULT for Strands Agent
506+
if isinstance(result, dict) and "status" in result and "content" in result:
507+
# Result is already in the expected format, just add toolUseId
508+
result["toolUseId"] = tool_use_d
509+
return ToolResultEvent(cast(ToolResult, result))
510+
else:
511+
# Wrap any other return value in the standard format
512+
# Always include at least one content item for consistency
513+
return ToolResultEvent(
514+
{
515+
"toolUseId": tool_use_d,
516+
"status": "success",
517+
"content": [{"text": str(result)}],
518+
}
519+
)
494520

495521
@property
496522
def supports_hot_reload(self) -> bool:

src/strands/tools/executors/_executor.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,20 @@ async def _stream(
119119
return
120120

121121
async for event in selected_tool.stream(tool_use, invocation_state, **kwargs):
122-
yield ToolStreamEvent(tool_use, event)
122+
# Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream()
123+
# so that we don't needlessly yield ToolStreamEvents for non-generator callbacks.
124+
# In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent
125+
# we yield it directly; all other cases (non-sdk AgentTools), we wrap events in
126+
# ToolStreamEvent and the last even is just the result
127+
128+
if isinstance(event, ToolResultEvent):
129+
# below the last "event" must point to the tool_result
130+
event = event.tool_result
131+
break
132+
elif isinstance(event, ToolStreamEvent):
133+
yield event
134+
else:
135+
yield ToolStreamEvent(tool_use, event)
123136

124137
result = cast(ToolResult, event)
125138

src/strands/tools/mcp/mcp_agent_tool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mcp.types import Tool as MCPTool
1212
from typing_extensions import override
1313

14+
from ...types._events import ToolResultEvent
1415
from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse
1516

1617
if TYPE_CHECKING:
@@ -96,4 +97,4 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
9697
name=self.tool_name,
9798
arguments=tool_use["input"],
9899
)
99-
yield result
100+
yield ToolResultEvent(result)

src/strands/tools/mcp/mcp_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mcp.shared.message import SessionMessage
1010
from typing_extensions import NotRequired
1111

12-
from strands.types.tools import ToolResult
12+
from ...types.tools import ToolResult
1313

1414
"""
1515
MCPTransport defines the interface for MCP transport implementations. This abstracts

src/strands/tools/tools.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from typing_extensions import override
1414

15+
from ..types._events import ToolResultEvent
1516
from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse
1617

1718
logger = logging.getLogger(__name__)
@@ -211,7 +212,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
211212
"""
212213
if inspect.iscoroutinefunction(self._tool_func):
213214
result = await self._tool_func(tool_use, **invocation_state)
215+
yield ToolResultEvent(result)
214216
else:
215217
result = await asyncio.to_thread(self._tool_func, tool_use, **invocation_state)
216-
217-
yield result
218+
yield ToolResultEvent(result)

tests/strands/tools/executors/test_concurrent.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from strands.tools.executors import ConcurrentToolExecutor
4-
from strands.types._events import ToolResultEvent, ToolStreamEvent
4+
from strands.types._events import ToolResultEvent
55
from strands.types.tools import ToolUse
66

77

@@ -22,13 +22,11 @@ async def test_concurrent_executor_execute(
2222

2323
tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id)
2424
exp_events = [
25-
ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
2625
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
27-
ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
2826
ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
2927
]
3028
assert tru_events == exp_events
3129

3230
tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId"))
33-
exp_results = [exp_events[1].tool_result, exp_events[3].tool_result]
31+
exp_results = [exp_events[0].tool_result, exp_events[1].tool_result]
3432
assert tru_results == exp_results

tests/strands/tools/executors/test_executor.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest.mock
2+
from unittest.mock import MagicMock
23

34
import pytest
45

@@ -39,7 +40,6 @@ async def test_executor_stream_yields_result(
3940

4041
tru_events = await alist(stream)
4142
exp_events = [
42-
ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
4343
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
4444
]
4545
assert tru_events == exp_events
@@ -67,6 +67,76 @@ async def test_executor_stream_yields_result(
6767
assert tru_hook_events == exp_hook_events
6868

6969

70+
@pytest.mark.asyncio
71+
async def test_executor_stream_wraps_results(
72+
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator
73+
):
74+
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
75+
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
76+
77+
weather_tool.stream = MagicMock()
78+
weather_tool.stream.return_value = agenerator(
79+
["value 1", {"nested": True}, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}]
80+
)
81+
82+
tru_events = await alist(stream)
83+
exp_events = [
84+
ToolStreamEvent(tool_use, "value 1"),
85+
ToolStreamEvent(tool_use, {"nested": True}),
86+
ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
87+
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
88+
]
89+
assert tru_events == exp_events
90+
91+
92+
@pytest.mark.asyncio
93+
async def test_executor_stream_passes_through_typed_events(
94+
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator
95+
):
96+
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
97+
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
98+
99+
weather_tool.stream = MagicMock()
100+
event_1 = ToolStreamEvent(tool_use, "value 1")
101+
event_2 = ToolStreamEvent(tool_use, {"nested": True})
102+
event_3 = ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]})
103+
weather_tool.stream.return_value = agenerator(
104+
[
105+
event_1,
106+
event_2,
107+
event_3,
108+
]
109+
)
110+
111+
tru_events = await alist(stream)
112+
assert tru_events[0] is event_1
113+
assert tru_events[1] is event_2
114+
115+
# ToolResults are not passed through directly, they're unwrapped then wraped again
116+
assert tru_events[2] == event_3
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_executor_stream_wraps_stream_events_if_no_result(
121+
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator
122+
):
123+
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
124+
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
125+
126+
weather_tool.stream = MagicMock()
127+
last_event = ToolStreamEvent(tool_use, "value 1")
128+
# Only ToolResultEvent can be the last value; all others are wrapped in ToolResultEvent
129+
weather_tool.stream.return_value = agenerator(
130+
[
131+
last_event,
132+
]
133+
)
134+
135+
tru_events = await alist(stream)
136+
exp_events = [last_event, ToolResultEvent(last_event)]
137+
assert tru_events == exp_events
138+
139+
70140
@pytest.mark.asyncio
71141
async def test_executor_stream_yields_tool_error(
72142
executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist
@@ -129,7 +199,6 @@ async def test_executor_stream_with_trace(
129199

130200
tru_events = await alist(stream)
131201
exp_events = [
132-
ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
133202
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
134203
]
135204
assert tru_events == exp_events

tests/strands/tools/executors/test_sequential.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from strands.tools.executors import SequentialToolExecutor
4-
from strands.types._events import ToolResultEvent, ToolStreamEvent
4+
from strands.types._events import ToolResultEvent
55

66

77
@pytest.fixture
@@ -21,13 +21,11 @@ async def test_sequential_executor_execute(
2121

2222
tru_events = await alist(stream)
2323
exp_events = [
24-
ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
2524
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
26-
ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
2725
ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
2826
]
2927
assert tru_events == exp_events
3028

3129
tru_results = tool_results
32-
exp_results = [exp_events[1].tool_result, exp_events[3].tool_result]
30+
exp_results = [exp_events[0].tool_result, exp_events[1].tool_result]
3331
assert tru_results == exp_results

tests/strands/tools/mcp/test_mcp_agent_tool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from mcp.types import Tool as MCPTool
55

66
from strands.tools.mcp import MCPAgentTool, MCPClient
7+
from strands.types._events import ToolResultEvent
78

89

910
@pytest.fixture
@@ -62,7 +63,7 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
6263
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}
6364

6465
tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))
65-
exp_events = [mock_mcp_client.call_tool_async.return_value]
66+
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]
6667

6768
assert tru_events == exp_events
6869
mock_mcp_client.call_tool_async.assert_called_once_with(

0 commit comments

Comments
 (0)