Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 53 additions & 27 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
Type,
TypeVar,
Union,
cast,
get_type_hints,
overload,
)
Expand All @@ -61,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic import BaseModel, Field, create_model
from typing_extensions import override

from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse
from ..types._events import ToolResultEvent, ToolStreamEvent
from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -454,43 +456,67 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
# Inject special framework-provided parameters
self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state)

# "Too few arguments" expected, hence the type ignore
if inspect.iscoroutinefunction(self._tool_func):
# Note: "Too few arguments" expected for the _tool_func calls, hence the type ignore

# Async-generators, yield streaming events and final tool result
if inspect.isasyncgenfunction(self._tool_func):
sub_events = self._tool_func(**validated_input) # type: ignore
async for sub_event in sub_events:
yield ToolStreamEvent(tool_use, sub_event)

# The last event is the result
yield self._wrap_tool_result(tool_use_id, sub_event)

# Async functions, yield only the result
elif inspect.iscoroutinefunction(self._tool_func):
result = await self._tool_func(**validated_input) # type: ignore
else:
result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore
yield self._wrap_tool_result(tool_use_id, result)

# FORMAT THE RESULT for Strands Agent
if isinstance(result, dict) and "status" in result and "content" in result:
# Result is already in the expected format, just add toolUseId
result["toolUseId"] = tool_use_id
yield result
# Other functions, yield only the result
else:
# Wrap any other return value in the standard format
# Always include at least one content item for consistency
yield {
"toolUseId": tool_use_id,
"status": "success",
"content": [{"text": str(result)}],
}
result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore
yield self._wrap_tool_result(tool_use_id, result)

except ValueError as e:
# Special handling for validation errors
error_msg = str(e)
yield {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {error_msg}"}],
}
yield self._wrap_tool_result(
tool_use_id,
{
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {error_msg}"}],
},
)
except Exception as e:
# Return error result with exception details for any other error
error_type = type(e).__name__
error_msg = str(e)
yield {
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {error_type} - {error_msg}"}],
}
yield self._wrap_tool_result(
tool_use_id,
{
"toolUseId": tool_use_id,
"status": "error",
"content": [{"text": f"Error: {error_type} - {error_msg}"}],
},
)

def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent:
# FORMAT THE RESULT for Strands Agent
if isinstance(result, dict) and "status" in result and "content" in result:
# Result is already in the expected format, just add toolUseId
result["toolUseId"] = tool_use_d
return ToolResultEvent(cast(ToolResult, result))
else:
# Wrap any other return value in the standard format
# Always include at least one content item for consistency
return ToolResultEvent(
{
"toolUseId": tool_use_d,
"status": "success",
"content": [{"text": str(result)}],
}
)

@property
def supports_hot_reload(self) -> bool:
Expand Down
15 changes: 14 additions & 1 deletion src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,20 @@ async def _stream(
return

async for event in selected_tool.stream(tool_use, invocation_state, **kwargs):
yield ToolStreamEvent(tool_use, event)
# Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream()
# so that we don't needlessly yield ToolStreamEvents for non-generator callbacks.
# In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent
# we yield it directly; all other cases (non-sdk AgentTools), we wrap events in
# ToolStreamEvent and the last even is just the result

if isinstance(event, ToolResultEvent):
# below the last "event" must point to the tool_result
event = event.tool_result
break
elif isinstance(event, ToolStreamEvent):
yield event
else:
yield ToolStreamEvent(tool_use, event)

result = cast(ToolResult, event)

Expand Down
3 changes: 2 additions & 1 deletion src/strands/tools/mcp/mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mcp.types import Tool as MCPTool
from typing_extensions import override

from ...types._events import ToolResultEvent
from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse

if TYPE_CHECKING:
Expand Down Expand Up @@ -96,4 +97,4 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
name=self.tool_name,
arguments=tool_use["input"],
)
yield result
yield ToolResultEvent(result)
2 changes: 1 addition & 1 deletion src/strands/tools/mcp/mcp_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mcp.shared.message import SessionMessage
from typing_extensions import NotRequired

from strands.types.tools import ToolResult
from ...types.tools import ToolResult

"""
MCPTransport defines the interface for MCP transport implementations. This abstracts
Expand Down
5 changes: 3 additions & 2 deletions src/strands/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from typing_extensions import override

from ..types._events import ToolResultEvent
from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -211,7 +212,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
"""
if inspect.iscoroutinefunction(self._tool_func):
result = await self._tool_func(tool_use, **invocation_state)
yield ToolResultEvent(result)
else:
result = await asyncio.to_thread(self._tool_func, tool_use, **invocation_state)

yield result
yield ToolResultEvent(result)
6 changes: 2 additions & 4 deletions tests/strands/tools/executors/test_concurrent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from strands.tools.executors import ConcurrentToolExecutor
from strands.types._events import ToolResultEvent, ToolStreamEvent
from strands.types._events import ToolResultEvent
from strands.types.tools import ToolUse


Expand All @@ -22,13 +22,11 @@ async def test_concurrent_executor_execute(

tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id)
exp_events = [
ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
]
assert tru_events == exp_events

tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId"))
exp_results = [exp_events[1].tool_result, exp_events[3].tool_result]
exp_results = [exp_events[0].tool_result, exp_events[1].tool_result]
assert tru_results == exp_results
73 changes: 71 additions & 2 deletions tests/strands/tools/executors/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest.mock
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -39,7 +40,6 @@ async def test_executor_stream_yields_result(

tru_events = await alist(stream)
exp_events = [
ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
]
assert tru_events == exp_events
Expand Down Expand Up @@ -67,6 +67,76 @@ async def test_executor_stream_yields_result(
assert tru_hook_events == exp_hook_events


@pytest.mark.asyncio
async def test_executor_stream_wraps_results(
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator
):
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
stream = executor._stream(agent, tool_use, tool_results, invocation_state)

weather_tool.stream = MagicMock()
weather_tool.stream.return_value = agenerator(
["value 1", {"nested": True}, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}]
)

tru_events = await alist(stream)
exp_events = [
ToolStreamEvent(tool_use, "value 1"),
ToolStreamEvent(tool_use, {"nested": True}),
ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_executor_stream_passes_through_typed_events(
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator
):
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
stream = executor._stream(agent, tool_use, tool_results, invocation_state)

weather_tool.stream = MagicMock()
event_1 = ToolStreamEvent(tool_use, "value 1")
event_2 = ToolStreamEvent(tool_use, {"nested": True})
event_3 = ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]})
weather_tool.stream.return_value = agenerator(
[
event_1,
event_2,
event_3,
]
)

tru_events = await alist(stream)
assert tru_events[0] is event_1
assert tru_events[1] is event_2

# ToolResults are not passed through directly, they're unwrapped then wraped again
assert tru_events[2] == event_3


@pytest.mark.asyncio
async def test_executor_stream_wraps_stream_events_if_no_result(
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator
):
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
stream = executor._stream(agent, tool_use, tool_results, invocation_state)

weather_tool.stream = MagicMock()
last_event = ToolStreamEvent(tool_use, "value 1")
# Only ToolResultEvent can be the last value; all others are wrapped in ToolResultEvent
weather_tool.stream.return_value = agenerator(
[
last_event,
]
)

tru_events = await alist(stream)
exp_events = [last_event, ToolResultEvent(last_event)]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_executor_stream_yields_tool_error(
executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist
Expand Down Expand Up @@ -129,7 +199,6 @@ async def test_executor_stream_with_trace(

tru_events = await alist(stream)
exp_events = [
ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
]
assert tru_events == exp_events
Expand Down
6 changes: 2 additions & 4 deletions tests/strands/tools/executors/test_sequential.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from strands.tools.executors import SequentialToolExecutor
from strands.types._events import ToolResultEvent, ToolStreamEvent
from strands.types._events import ToolResultEvent


@pytest.fixture
Expand All @@ -21,13 +21,11 @@ async def test_sequential_executor_execute(

tru_events = await alist(stream)
exp_events = [
ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}),
ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}),
]
assert tru_events == exp_events

tru_results = tool_results
exp_results = [exp_events[1].tool_result, exp_events[3].tool_result]
exp_results = [exp_events[0].tool_result, exp_events[1].tool_result]
assert tru_results == exp_results
3 changes: 2 additions & 1 deletion tests/strands/tools/mcp/test_mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mcp.types import Tool as MCPTool

from strands.tools.mcp import MCPAgentTool, MCPClient
from strands.types._events import ToolResultEvent


@pytest.fixture
Expand Down Expand Up @@ -62,7 +63,7 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}

tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))
exp_events = [mock_mcp_client.call_tool_async.return_value]
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]

assert tru_events == exp_events
mock_mcp_client.call_tool_async.assert_called_once_with(
Expand Down
Loading
Loading