From 5ca7ac3738ded458511a03e3d3a29331e9c1b0bc Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Fri, 29 Aug 2025 16:23:29 -0400 Subject: [PATCH] fix: don't emit ToolStream events for non generator functions Our current implementation of AgentTool.stream() has a problem that we don't differentiate between intermediate streaming events and the final ToolResult events. Our only contract is that the last event *must be* be the tool result that is passed to the LLM. Our switch to Typed Events (#755) pushes us in the right direction but for backwards compatibility we can't update the signature of `AgentTool.stream()` (nor have we exposed externally TypedEvents yet). That means that if we implemented tool-streaming today, then callers would see non-generator functions yielding both a `ToolStreamEvent` and `ToolResultEvent` even though they're not actually streaming responses. To avoid the odd behavior noted above, we'll special-case SDK-defined functions by allowing them to emit `ToolStreamEvent` and `ToolResultEvent` types directly (bypassing our normal wrapping), since they have the knowledge of when tools are actually generators or not. There's no observable difference in behavior to callers (this is all internal behavior), but this means that when we switch the flip for Tool Streaming, non-generator tools will **not** emit ToolStreamEvents - at least for AgentTool implementations that are in the SDK. --- src/strands/tools/decorator.py | 80 ++++--- src/strands/tools/executors/_executor.py | 15 +- src/strands/tools/mcp/mcp_agent_tool.py | 3 +- src/strands/tools/mcp/mcp_types.py | 2 +- src/strands/tools/tools.py | 5 +- .../tools/executors/test_concurrent.py | 6 +- .../strands/tools/executors/test_executor.py | 73 +++++- .../tools/executors/test_sequential.py | 6 +- .../strands/tools/mcp/test_mcp_agent_tool.py | 3 +- tests/strands/tools/test_decorator.py | 217 ++++++++++-------- tests/strands/tools/test_tools.py | 3 +- 11 files changed, 271 insertions(+), 142 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 2ce6d946f..8b218dfa1 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -53,6 +53,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: Type, TypeVar, Union, + cast, get_type_hints, overload, ) @@ -61,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse +from ..types._events import ToolResultEvent, ToolStreamEvent +from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -454,43 +456,67 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw # Inject special framework-provided parameters self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state) - # "Too few arguments" expected, hence the type ignore - if inspect.iscoroutinefunction(self._tool_func): + # Note: "Too few arguments" expected for the _tool_func calls, hence the type ignore + + # Async-generators, yield streaming events and final tool result + if inspect.isasyncgenfunction(self._tool_func): + sub_events = self._tool_func(**validated_input) # type: ignore + async for sub_event in sub_events: + yield ToolStreamEvent(tool_use, sub_event) + + # The last event is the result + yield self._wrap_tool_result(tool_use_id, sub_event) + + # Async functions, yield only the result + elif inspect.iscoroutinefunction(self._tool_func): result = await self._tool_func(**validated_input) # type: ignore - else: - result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + yield self._wrap_tool_result(tool_use_id, result) - # FORMAT THE RESULT for Strands Agent - if isinstance(result, dict) and "status" in result and "content" in result: - # Result is already in the expected format, just add toolUseId - result["toolUseId"] = tool_use_id - yield result + # Other functions, yield only the result else: - # Wrap any other return value in the standard format - # Always include at least one content item for consistency - yield { - "toolUseId": tool_use_id, - "status": "success", - "content": [{"text": str(result)}], - } + result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + yield self._wrap_tool_result(tool_use_id, result) except ValueError as e: # Special handling for validation errors error_msg = str(e) - yield { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error_msg}"}], - } + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error_msg}"}], + }, + ) except Exception as e: # Return error result with exception details for any other error error_type = type(e).__name__ error_msg = str(e) - yield { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error_type} - {error_msg}"}], - } + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error_type} - {error_msg}"}], + }, + ) + + def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: + # FORMAT THE RESULT for Strands Agent + if isinstance(result, dict) and "status" in result and "content" in result: + # Result is already in the expected format, just add toolUseId + result["toolUseId"] = tool_use_d + return ToolResultEvent(cast(ToolResult, result)) + else: + # Wrap any other return value in the standard format + # Always include at least one content item for consistency + return ToolResultEvent( + { + "toolUseId": tool_use_d, + "status": "success", + "content": [{"text": str(result)}], + } + ) @property def supports_hot_reload(self) -> bool: diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 701a3bac0..5354991c3 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -119,7 +119,20 @@ async def _stream( return async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - yield ToolStreamEvent(tool_use, event) + # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() + # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. + # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent + # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in + # ToolStreamEvent and the last even is just the result + + if isinstance(event, ToolResultEvent): + # below the last "event" must point to the tool_result + event = event.tool_result + break + elif isinstance(event, ToolStreamEvent): + yield event + else: + yield ToolStreamEvent(tool_use, event) result = cast(ToolResult, event) diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index f9c8d6061..f15bb1718 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -11,6 +11,7 @@ from mcp.types import Tool as MCPTool from typing_extensions import override +from ...types._events import ToolResultEvent from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse if TYPE_CHECKING: @@ -96,4 +97,4 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw name=self.tool_name, arguments=tool_use["input"], ) - yield result + yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 5fafed5dc..66eda08ae 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -9,7 +9,7 @@ from mcp.shared.message import SessionMessage from typing_extensions import NotRequired -from strands.types.tools import ToolResult +from ...types.tools import ToolResult """ MCPTransport defines the interface for MCP transport implementations. This abstracts diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 465063095..9e1c0e608 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -12,6 +12,7 @@ from typing_extensions import override +from ..types._events import ToolResultEvent from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -211,7 +212,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw """ if inspect.iscoroutinefunction(self._tool_func): result = await self._tool_func(tool_use, **invocation_state) + yield ToolResultEvent(result) else: result = await asyncio.to_thread(self._tool_func, tool_use, **invocation_state) - - yield result + yield ToolResultEvent(result) diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index 140537add..f7fc64b25 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,7 +1,7 @@ import pytest from strands.tools.executors import ConcurrentToolExecutor -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolResultEvent from strands.types.tools import ToolUse @@ -22,13 +22,11 @@ async def test_concurrent_executor_execute( tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) exp_events = [ - ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), - ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) - exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] + exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 56caa950a..903a11e5a 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -1,4 +1,5 @@ import unittest.mock +from unittest.mock import MagicMock import pytest @@ -39,7 +40,6 @@ async def test_executor_stream_yields_result( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events @@ -67,6 +67,76 @@ async def test_executor_stream_yields_result( assert tru_hook_events == exp_hook_events +@pytest.mark.asyncio +async def test_executor_stream_wraps_results( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + weather_tool.stream.return_value = agenerator( + ["value 1", {"nested": True}, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}] + ) + + tru_events = await alist(stream) + exp_events = [ + ToolStreamEvent(tool_use, "value 1"), + ToolStreamEvent(tool_use, {"nested": True}), + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ] + assert tru_events == exp_events + + +@pytest.mark.asyncio +async def test_executor_stream_passes_through_typed_events( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + event_1 = ToolStreamEvent(tool_use, "value 1") + event_2 = ToolStreamEvent(tool_use, {"nested": True}) + event_3 = ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}) + weather_tool.stream.return_value = agenerator( + [ + event_1, + event_2, + event_3, + ] + ) + + tru_events = await alist(stream) + assert tru_events[0] is event_1 + assert tru_events[1] is event_2 + + # ToolResults are not passed through directly, they're unwrapped then wraped again + assert tru_events[2] == event_3 + + +@pytest.mark.asyncio +async def test_executor_stream_wraps_stream_events_if_no_result( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + last_event = ToolStreamEvent(tool_use, "value 1") + # Only ToolResultEvent can be the last value; all others are wrapped in ToolResultEvent + weather_tool.stream.return_value = agenerator( + [ + last_event, + ] + ) + + tru_events = await alist(stream) + exp_events = [last_event, ToolResultEvent(last_event)] + assert tru_events == exp_events + + @pytest.mark.asyncio async def test_executor_stream_yields_tool_error( executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist @@ -129,7 +199,6 @@ async def test_executor_stream_with_trace( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index d4e98223e..37e098142 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,7 +1,7 @@ import pytest from strands.tools.executors import SequentialToolExecutor -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolResultEvent @pytest.fixture @@ -21,13 +21,11 @@ async def test_sequential_executor_execute( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), - ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] + exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index 874006683..1c025f5f2 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -4,6 +4,7 @@ from mcp.types import Tool as MCPTool from strands.tools.mcp import MCPAgentTool, MCPClient +from strands.types._events import ToolResultEvent @pytest.fixture @@ -62,7 +63,7 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}} tru_events = await alist(mcp_agent_tool.stream(tool_use, {})) - exp_events = [mock_mcp_client.call_tool_async.return_value] + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] assert tru_events == exp_events mock_mcp_client.call_tool_async.assert_called_once_with( diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 02e7eb445..a13c2833e 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -10,6 +10,7 @@ import strands from strands import Agent +from strands.types._events import ToolResultEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -117,7 +118,7 @@ async def test_stream(identity_tool, alist): stream = identity_tool.stream({"toolUseId": "t1", "input": {"a": 2}}, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]}] + exp_events = [ToolResultEvent({"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]})] assert tru_events == exp_events @@ -131,7 +132,9 @@ def identity(a: int, agent: dict = None): stream = identity.stream({"input": {"a": 2}}, {"agent": {"state": 1}}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}) + ] assert tru_events == exp_events @@ -180,7 +183,9 @@ def test_tool(param1: str, param2: int) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) + ] assert tru_events == exp_events # Make sure these are set properly @@ -229,7 +234,9 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}) + ] assert tru_events == exp_events # Test with both params @@ -237,7 +244,9 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) + ] @pytest.mark.asyncio @@ -256,8 +265,8 @@ def test_tool(required: str) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "validation error for test_tooltool\nrequired\n" in result["content"][0]["text"].lower(), ( + assert result["tool_result"]["status"] == "error" + assert "validation error for test_tooltool\nrequired\n" in result["tool_result"]["content"][0]["text"].lower(), ( "Validation error should indicate which argument is missing" ) @@ -266,8 +275,8 @@ def test_tool(required: str) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "test error" in result["content"][0]["text"].lower(), ( + assert result["tool_result"]["status"] == "error" + assert "test error" in result["tool_result"]["content"][0]["text"].lower(), ( "Runtime error should contain the original error message" ) @@ -313,14 +322,14 @@ def test_tool(param: str, agent=None) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "Param: test" + assert result["tool_result"]["content"][0]["text"] == "Param: test" # Test with agent stream = test_tool.stream(tool_use, {"agent": mock_agent}) result = (await alist(stream))[-1] - assert "Agent:" in result["content"][0]["text"] - assert "test" in result["content"][0]["text"] + assert "Agent:" in result["tool_result"]["content"][0]["text"] + assert "test" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -350,23 +359,23 @@ def none_return_tool(param: str) -> None: stream = dict_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Result: test" - assert result["toolUseId"] == "test-id" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Result: test" + assert result["tool_result"]["toolUseId"] == "test-id" # Test the string return - should wrap in standard format stream = string_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Result: test" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Result: test" # Test None return - should still create valid ToolResult with "None" text stream = none_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" @pytest.mark.asyncio @@ -403,7 +412,7 @@ def test_method(self, param: str) -> str: stream = instance.test_method.stream(tool_use, {}) result = (await alist(stream))[-1] - assert "Test: tool-value" in result["content"][0]["text"] + assert "Test: tool-value" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -422,7 +431,9 @@ class MyThing: ... stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) result2 = (await alist(stream))[-1] - assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + assert result2 == ToolResultEvent( + {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + ) @pytest.mark.asyncio @@ -444,7 +455,9 @@ def test_method(param: str) -> str: stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) result2 = (await alist(stream))[-1] - assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + assert result2 == ToolResultEvent( + {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + ) @pytest.mark.asyncio @@ -474,14 +487,14 @@ def tool_with_defaults(required: str, optional: str = "default", number: int = 4 stream = tool_with_defaults.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "hello default 42" + assert result["tool_result"]["content"][0]["text"] == "hello default 42" # Call with some but not all optional parameters tool_use = {"toolUseId": "test-id", "input": {"required": "hello", "number": 100}} stream = tool_with_defaults.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "hello default 100" + assert result["tool_result"]["content"][0]["text"] == "hello default 100" @pytest.mark.asyncio @@ -496,14 +509,15 @@ def test_tool(required: str) -> str: # Test with completely empty tool use stream = test_tool.stream({}, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "unknown" in result["toolUseId"] + print(result) + assert result["tool_result"]["status"] == "error" + assert "unknown" in result["tool_result"]["toolUseId"] # Test with missing input stream = test_tool.stream({"toolUseId": "test-id"}, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "test-id" in result["toolUseId"] + assert result["tool_result"]["status"] == "error" + assert "test-id" in result["tool_result"]["toolUseId"] @pytest.mark.asyncio @@ -529,8 +543,8 @@ def add_numbers(a: int, b: int) -> int: stream = add_numbers.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "5" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "5" @pytest.mark.asyncio @@ -565,8 +579,8 @@ def multi_default_tool( stream = multi_default_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "hello, default_str, 42, True, 3.14" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "hello, default_str, 42, True, 3.14" in result["tool_result"]["content"][0]["text"] # Test calling with some optional parameters tool_use = { @@ -576,7 +590,7 @@ def multi_default_tool( stream = multi_default_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert "hello, default_str, 100, True, 2.718" in result["content"][0]["text"] + assert "hello, default_str, 100, True, 2.718" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -603,8 +617,8 @@ def int_return_tool(param: str) -> int: stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "42" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "42" # Test with return that doesn't match declared type # Note: This should still work because Python doesn't enforce return types at runtime @@ -613,16 +627,16 @@ def int_return_tool(param: str) -> int: stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "not an int" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "not an int" # Test with None return from a non-None return type tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" # Define tool with Union return type @strands.tool @@ -644,22 +658,25 @@ def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "{'key': 'value'}" in result["content"][0]["text"] or '{"key": "value"}' in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert ( + "{'key': 'value'}" in result["tool_result"]["content"][0]["text"] + or '{"key": "value"}' in result["tool_result"]["content"][0]["text"] + ) tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "string result" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "string result" tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" @pytest.mark.asyncio @@ -682,8 +699,8 @@ def no_params_tool() -> str: stream = no_params_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Success - no parameters needed" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Success - no parameters needed" # Test direct call direct_result = no_params_tool() @@ -711,8 +728,8 @@ def complex_type_tool(config: Dict[str, Any]) -> str: stream = complex_type_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "Got config with 3 keys" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "Got config with 3 keys" in result["tool_result"]["content"][0]["text"] # Direct call direct_result = complex_type_tool(nested_dict) @@ -742,12 +759,12 @@ def custom_result_tool(param: str) -> Dict[str, Any]: # The wrapper should preserve our format and just add the toolUseId result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["toolUseId"] == "custom-id" - assert len(result["content"]) == 2 - assert result["content"][0]["text"] == "First line: test" - assert result["content"][1]["text"] == "Second line" - assert result["content"][1]["type"] == "markdown" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["toolUseId"] == "custom-id" + assert len(result["tool_result"]["content"]) == 2 + assert result["tool_result"]["content"][0]["text"] == "First line: test" + assert result["tool_result"]["content"][1]["text"] == "Second line" + assert result["tool_result"]["content"][1]["type"] == "markdown" def test_docstring_parsing(): @@ -816,8 +833,8 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: stream = validation_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "int_param" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "error" + assert "int_param" in result["tool_result"]["content"][0]["text"] # Test missing required parameter tool_use = { @@ -831,8 +848,8 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: stream = validation_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "int_param" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "error" + assert "int_param" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -855,16 +872,16 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" # Test with empty dict tool_use = {"toolUseId": "test-id", "input": {"param": {}}} stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "{}" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "{}" # Test with a complex nested dictionary nested_dict = {"key1": {"nested": [1, 2, 3]}, "key2": None} @@ -872,9 +889,9 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "key1" in result["content"][0]["text"] - assert "nested" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "key1" in result["tool_result"]["content"][0]["text"] + assert "nested" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -922,8 +939,8 @@ def test_method(self): stream = instance.test_method.stream({"toolUseId": "test-id", "input": {"param": "direct"}}, {}) direct_result = (await alist(stream))[-1] - assert direct_result["status"] == "success" - assert direct_result["content"][0]["text"] == "Method Got: direct" + assert direct_result["tool_result"]["status"] == "success" + assert direct_result["tool_result"]["content"][0]["text"] == "Method Got: direct" # Create a standalone function to test regular function calls @strands.tool @@ -944,8 +961,8 @@ def standalone_tool(p1: str, p2: str = "default") -> str: stream = standalone_tool.stream({"toolUseId": "test-id", "input": {"p1": "value1"}}, {}) tool_use_result = (await alist(stream))[-1] - assert tool_use_result["status"] == "success" - assert tool_use_result["content"][0]["text"] == "Standalone: value1, default" + assert tool_use_result["tool_result"]["status"] == "success" + assert tool_use_result["tool_result"]["content"][0]["text"] == "Standalone: value1, default" @pytest.mark.asyncio @@ -976,9 +993,9 @@ def failing_tool(param: str) -> str: stream = failing_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" + assert result["tool_result"]["status"] == "error" - error_message = result["content"][0]["text"] + error_message = result["tool_result"]["content"][0]["text"] # Check that error type is included if error_type == "value_error": @@ -1011,33 +1028,33 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None] stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "list: [1, 2, 3]" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "list: [1, 2, 3]" in result["tool_result"]["content"][0]["text"] # Test with a dict tool_use = {"toolUseId": "test-id", "input": {"union_param": {"key": "value"}}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "dict:" in result["content"][0]["text"] - assert "key" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "dict:" in result["tool_result"]["content"][0]["text"] + assert "key" in result["tool_result"]["content"][0]["text"] # Test with a string tool_use = {"toolUseId": "test-id", "input": {"union_param": "test_string"}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "str: test_string" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "str: test_string" in result["tool_result"]["content"][0]["text"] # Test with None tool_use = {"toolUseId": "test-id", "input": {"union_param": None}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "NoneType: None" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "NoneType: None" in result["tool_result"]["content"][0]["text"] async def _run_context_injection_test(context_tool: AgentTool, additional_context=None): @@ -1061,15 +1078,17 @@ async def _run_context_injection_test(context_tool: AgentTool, additional_contex assert len(tool_results) == 1 tool_result = tool_results[0] - assert tool_result == { - "status": "success", - "content": [ - {"text": "Tool 'context_tool' (ID: test-id)"}, - {"text": "injected agent 'test_agent' processed: some_message"}, - {"text": "context agent 'test_agent'"}, - ], - "toolUseId": "test-id", - } + assert tool_result == ToolResultEvent( + { + "status": "success", + "content": [ + {"text": "Tool 'context_tool' (ID: test-id)"}, + {"text": "injected agent 'test_agent' processed: some_message"}, + {"text": "context agent 'test_agent'"}, + ], + "toolUseId": "test-id", + } + ) @pytest.mark.asyncio @@ -1164,9 +1183,9 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict: tool_result = tool_results[0] # Should get a validation error because tool_context is required but not provided - assert tool_result["status"] == "error" - assert "tool_context" in tool_result["content"][0]["text"].lower() - assert "validation" in tool_result["content"][0]["text"].lower() + assert tool_result["tool_result"]["status"] == "error" + assert "tool_context" in tool_result["tool_result"]["content"][0]["text"].lower() + assert "validation" in tool_result["tool_result"]["content"][0]["text"].lower() @pytest.mark.asyncio @@ -1196,8 +1215,10 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: tool_result = tool_results[0] # Should succeed with the string parameter - assert tool_result == { - "status": "success", - "content": [{"text": "success"}], - "toolUseId": "test-id-2", - } + assert tool_result == ToolResultEvent( + { + "status": "success", + "content": [{"text": "success"}], + "toolUseId": "test-id-2", + } + ) diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 240c24717..b305a1a90 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -9,6 +9,7 @@ validate_tool_use, validate_tool_use_name, ) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolUse @@ -506,5 +507,5 @@ async def test_stream(identity_tool, alist): stream = identity_tool.stream({"tool_use": 1}, {"a": 2}) tru_events = await alist(stream) - exp_events = [({"tool_use": 1}, 2)] + exp_events = [ToolResultEvent(({"tool_use": 1}, 2))] assert tru_events == exp_events