diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 2ce6d946f..8b218dfa1 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -53,6 +53,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: Type, TypeVar, Union, + cast, get_type_hints, overload, ) @@ -61,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolSpec, ToolUse +from ..types._events import ToolResultEvent, ToolStreamEvent +from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -454,43 +456,67 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw # Inject special framework-provided parameters self._metadata.inject_special_parameters(validated_input, tool_use, invocation_state) - # "Too few arguments" expected, hence the type ignore - if inspect.iscoroutinefunction(self._tool_func): + # Note: "Too few arguments" expected for the _tool_func calls, hence the type ignore + + # Async-generators, yield streaming events and final tool result + if inspect.isasyncgenfunction(self._tool_func): + sub_events = self._tool_func(**validated_input) # type: ignore + async for sub_event in sub_events: + yield ToolStreamEvent(tool_use, sub_event) + + # The last event is the result + yield self._wrap_tool_result(tool_use_id, sub_event) + + # Async functions, yield only the result + elif inspect.iscoroutinefunction(self._tool_func): result = await self._tool_func(**validated_input) # type: ignore - else: - result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + yield self._wrap_tool_result(tool_use_id, result) - # FORMAT THE RESULT for Strands Agent - if isinstance(result, dict) and "status" in result and "content" in result: - # 
Result is already in the expected format, just add toolUseId - result["toolUseId"] = tool_use_id - yield result + # Other functions, yield only the result else: - # Wrap any other return value in the standard format - # Always include at least one content item for consistency - yield { - "toolUseId": tool_use_id, - "status": "success", - "content": [{"text": str(result)}], - } + result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore + yield self._wrap_tool_result(tool_use_id, result) except ValueError as e: # Special handling for validation errors error_msg = str(e) - yield { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error_msg}"}], - } + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error_msg}"}], + }, + ) except Exception as e: # Return error result with exception details for any other error error_type = type(e).__name__ error_msg = str(e) - yield { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error_type} - {error_msg}"}], - } + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error_type} - {error_msg}"}], + }, + ) + + def _wrap_tool_result(self, tool_use_id: str, result: Any) -> ToolResultEvent: + # FORMAT THE RESULT for Strands Agent + if isinstance(result, dict) and "status" in result and "content" in result: + # Result is already in the expected format, just add toolUseId + result["toolUseId"] = tool_use_id + return ToolResultEvent(cast(ToolResult, result)) + else: + # Wrap any other return value in the standard format + # Always include at least one content item for consistency + return ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": str(result)}], + } + ) @property def supports_hot_reload(self) -> bool: diff --git
a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 701a3bac0..5354991c3 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -119,7 +119,20 @@ async def _stream( return async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - yield ToolStreamEvent(tool_use, event) + # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() + # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. + # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent + # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in + # ToolStreamEvent and the last event is just the result + + if isinstance(event, ToolResultEvent): + # below the last "event" must point to the tool_result + event = event.tool_result + break + elif isinstance(event, ToolStreamEvent): + yield event + else: + yield ToolStreamEvent(tool_use, event) result = cast(ToolResult, event) diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index f9c8d6061..f15bb1718 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -11,6 +11,7 @@ from mcp.types import Tool as MCPTool from typing_extensions import override +from ...types._events import ToolResultEvent from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse if TYPE_CHECKING: @@ -96,4 +97,4 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw name=self.tool_name, arguments=tool_use["input"], ) - yield result + yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 5fafed5dc..66eda08ae 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -9,7 +9,7 @@ from mcp.shared.message import SessionMessage from typing_extensions
import NotRequired -from strands.types.tools import ToolResult +from ...types.tools import ToolResult """ MCPTransport defines the interface for MCP transport implementations. This abstracts diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 465063095..9e1c0e608 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -12,6 +12,7 @@ from typing_extensions import override +from ..types._events import ToolResultEvent from ..types.tools import AgentTool, ToolFunc, ToolGenerator, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -211,7 +212,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw """ if inspect.iscoroutinefunction(self._tool_func): result = await self._tool_func(tool_use, **invocation_state) + yield ToolResultEvent(result) else: result = await asyncio.to_thread(self._tool_func, tool_use, **invocation_state) - - yield result + yield ToolResultEvent(result) diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index 140537add..f7fc64b25 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,7 +1,7 @@ import pytest from strands.tools.executors import ConcurrentToolExecutor -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolResultEvent from strands.types.tools import ToolUse @@ -22,13 +22,11 @@ async def test_concurrent_executor_execute( tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) exp_events = [ - ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), - ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": 
"75F"}]}), ] assert tru_events == exp_events tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) - exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] + exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 56caa950a..903a11e5a 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -1,4 +1,5 @@ import unittest.mock +from unittest.mock import MagicMock import pytest @@ -39,7 +40,6 @@ async def test_executor_stream_yields_result( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events @@ -67,6 +67,76 @@ async def test_executor_stream_yields_result( assert tru_hook_events == exp_hook_events +@pytest.mark.asyncio +async def test_executor_stream_wraps_results( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + weather_tool.stream.return_value = agenerator( + ["value 1", {"nested": True}, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}] + ) + + tru_events = await alist(stream) + exp_events = [ + ToolStreamEvent(tool_use, "value 1"), + ToolStreamEvent(tool_use, {"nested": True}), + ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), + ] + assert tru_events == exp_events + + 
+@pytest.mark.asyncio +async def test_executor_stream_passes_through_typed_events( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + event_1 = ToolStreamEvent(tool_use, "value 1") + event_2 = ToolStreamEvent(tool_use, {"nested": True}) + event_3 = ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}) + weather_tool.stream.return_value = agenerator( + [ + event_1, + event_2, + event_3, + ] + ) + + tru_events = await alist(stream) + assert tru_events[0] is event_1 + assert tru_events[1] is event_2 + + # ToolResults are not passed through directly, they're unwrapped then wrapped again + assert tru_events[2] == event_3 + + +@pytest.mark.asyncio +async def test_executor_stream_wraps_stream_events_if_no_result( + executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist, agenerator +): + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + weather_tool.stream = MagicMock() + last_event = ToolStreamEvent(tool_use, "value 1") + # Only ToolResultEvent can be the last value; all others are wrapped in ToolResultEvent + weather_tool.stream.return_value = agenerator( + [ + last_event, + ] + ) + + tru_events = await alist(stream) + exp_events = [last_event, ToolResultEvent(last_event)] + assert tru_events == exp_events + + @pytest.mark.asyncio async def test_executor_stream_yields_tool_error( executor, agent, tool_results, invocation_state, hook_events, exception_tool, alist @@ -129,7 +199,6 @@ async def test_executor_stream_with_trace( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_use, {"toolUseId": "1", "status": "success", "content": [{"text":
"sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ] assert tru_events == exp_events diff --git a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index d4e98223e..37e098142 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,7 +1,7 @@ import pytest from strands.tools.executors import SequentialToolExecutor -from strands.types._events import ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolResultEvent @pytest.fixture @@ -21,13 +21,11 @@ async def test_sequential_executor_execute( tru_events = await alist(stream) exp_events = [ - ToolStreamEvent(tool_uses[0], {"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), ToolResultEvent({"toolUseId": "1", "status": "success", "content": [{"text": "sunny"}]}), - ToolStreamEvent(tool_uses[1], {"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ToolResultEvent({"toolUseId": "2", "status": "success", "content": [{"text": "75F"}]}), ] assert tru_events == exp_events tru_results = tool_results - exp_results = [exp_events[1].tool_result, exp_events[3].tool_result] + exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index 874006683..1c025f5f2 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -4,6 +4,7 @@ from mcp.types import Tool as MCPTool from strands.tools.mcp import MCPAgentTool, MCPClient +from strands.types._events import ToolResultEvent @pytest.fixture @@ -62,7 +63,7 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}} tru_events = await 
alist(mcp_agent_tool.stream(tool_use, {})) - exp_events = [mock_mcp_client.call_tool_async.return_value] + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] assert tru_events == exp_events mock_mcp_client.call_tool_async.assert_called_once_with( diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 02e7eb445..a13c2833e 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -10,6 +10,7 @@ import strands from strands import Agent +from strands.types._events import ToolResultEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -117,7 +118,7 @@ async def test_stream(identity_tool, alist): stream = identity_tool.stream({"toolUseId": "t1", "input": {"a": 2}}, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]}] + exp_events = [ToolResultEvent({"toolUseId": "t1", "status": "success", "content": [{"text": "2"}]})] assert tru_events == exp_events @@ -131,7 +132,9 @@ def identity(a: int, agent: dict = None): stream = identity.stream({"input": {"a": 2}}, {"agent": {"state": 1}}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}) + ] assert tru_events == exp_events @@ -180,7 +183,9 @@ def test_tool(param1: str, param2: int) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) + ] assert tru_events == exp_events # Make sure these are set properly @@ -229,7 +234,9 @@ def test_tool(required: str, optional: 
Optional[int] = None) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello"}]}) + ] assert tru_events == exp_events # Test with both params @@ -237,7 +244,9 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: stream = test_tool.stream(tool_use, {}) tru_events = await alist(stream) - exp_events = [{"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}] + exp_events = [ + ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) + ] @pytest.mark.asyncio @@ -256,8 +265,8 @@ def test_tool(required: str) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "validation error for test_tooltool\nrequired\n" in result["content"][0]["text"].lower(), ( + assert result["tool_result"]["status"] == "error" + assert "validation error for test_tooltool\nrequired\n" in result["tool_result"]["content"][0]["text"].lower(), ( "Validation error should indicate which argument is missing" ) @@ -266,8 +275,8 @@ def test_tool(required: str) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "test error" in result["content"][0]["text"].lower(), ( + assert result["tool_result"]["status"] == "error" + assert "test error" in result["tool_result"]["content"][0]["text"].lower(), ( "Runtime error should contain the original error message" ) @@ -313,14 +322,14 @@ def test_tool(param: str, agent=None) -> str: stream = test_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "Param: test" + assert result["tool_result"]["content"][0]["text"] == "Param: 
test" # Test with agent stream = test_tool.stream(tool_use, {"agent": mock_agent}) result = (await alist(stream))[-1] - assert "Agent:" in result["content"][0]["text"] - assert "test" in result["content"][0]["text"] + assert "Agent:" in result["tool_result"]["content"][0]["text"] + assert "test" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -350,23 +359,23 @@ def none_return_tool(param: str) -> None: stream = dict_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Result: test" - assert result["toolUseId"] == "test-id" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Result: test" + assert result["tool_result"]["toolUseId"] == "test-id" # Test the string return - should wrap in standard format stream = string_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Result: test" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Result: test" # Test None return - should still create valid ToolResult with "None" text stream = none_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" @pytest.mark.asyncio @@ -403,7 +412,7 @@ def test_method(self, param: str) -> str: stream = instance.test_method.stream(tool_use, {}) result = (await alist(stream))[-1] - assert "Test: tool-value" in result["content"][0]["text"] + assert "Test: tool-value" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -422,7 +431,9 @@ class MyThing: ... 
stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) result2 = (await alist(stream))[-1] - assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + assert result2 == ToolResultEvent( + {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + ) @pytest.mark.asyncio @@ -444,7 +455,9 @@ def test_method(param: str) -> str: stream = instance.field.stream({"toolUseId": "test-id", "input": {"param": "example"}}, {}) result2 = (await alist(stream))[-1] - assert result2 == {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + assert result2 == ToolResultEvent( + {"content": [{"text": "param: example"}], "status": "success", "toolUseId": "test-id"} + ) @pytest.mark.asyncio @@ -474,14 +487,14 @@ def tool_with_defaults(required: str, optional: str = "default", number: int = 4 stream = tool_with_defaults.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "hello default 42" + assert result["tool_result"]["content"][0]["text"] == "hello default 42" # Call with some but not all optional parameters tool_use = {"toolUseId": "test-id", "input": {"required": "hello", "number": 100}} stream = tool_with_defaults.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["content"][0]["text"] == "hello default 100" + assert result["tool_result"]["content"][0]["text"] == "hello default 100" @pytest.mark.asyncio @@ -496,14 +509,15 @@ def test_tool(required: str) -> str: # Test with completely empty tool use stream = test_tool.stream({}, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "unknown" in result["toolUseId"] + print(result) + assert result["tool_result"]["status"] == "error" + assert "unknown" in result["tool_result"]["toolUseId"] # Test with missing input stream = test_tool.stream({"toolUseId": "test-id"}, {}) result = (await 
alist(stream))[-1] - assert result["status"] == "error" - assert "test-id" in result["toolUseId"] + assert result["tool_result"]["status"] == "error" + assert "test-id" in result["tool_result"]["toolUseId"] @pytest.mark.asyncio @@ -529,8 +543,8 @@ def add_numbers(a: int, b: int) -> int: stream = add_numbers.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "5" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "5" @pytest.mark.asyncio @@ -565,8 +579,8 @@ def multi_default_tool( stream = multi_default_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "hello, default_str, 42, True, 3.14" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "hello, default_str, 42, True, 3.14" in result["tool_result"]["content"][0]["text"] # Test calling with some optional parameters tool_use = { @@ -576,7 +590,7 @@ def multi_default_tool( stream = multi_default_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert "hello, default_str, 100, True, 2.718" in result["content"][0]["text"] + assert "hello, default_str, 100, True, 2.718" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -603,8 +617,8 @@ def int_return_tool(param: str) -> int: stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "42" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "42" # Test with return that doesn't match declared type # Note: This should still work because Python doesn't enforce return types at runtime @@ -613,16 +627,16 @@ def int_return_tool(param: str) -> int: stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert 
result["status"] == "success" - assert result["content"][0]["text"] == "not an int" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "not an int" # Test with None return from a non-None return type tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} stream = int_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" # Define tool with Union return type @strands.tool @@ -644,22 +658,25 @@ def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "{'key': 'value'}" in result["content"][0]["text"] or '{"key": "value"}' in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert ( + "{'key': 'value'}" in result["tool_result"]["content"][0]["text"] + or '{"key": "value"}' in result["tool_result"]["content"][0]["text"] + ) tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "string result" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "string result" tool_use = {"toolUseId": "test-id", "input": {"param": "none"}} stream = union_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" @pytest.mark.asyncio @@ -682,8 +699,8 @@ def no_params_tool() -> str: stream = 
no_params_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "Success - no parameters needed" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "Success - no parameters needed" # Test direct call direct_result = no_params_tool() @@ -711,8 +728,8 @@ def complex_type_tool(config: Dict[str, Any]) -> str: stream = complex_type_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "Got config with 3 keys" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "Got config with 3 keys" in result["tool_result"]["content"][0]["text"] # Direct call direct_result = complex_type_tool(nested_dict) @@ -742,12 +759,12 @@ def custom_result_tool(param: str) -> Dict[str, Any]: # The wrapper should preserve our format and just add the toolUseId result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["toolUseId"] == "custom-id" - assert len(result["content"]) == 2 - assert result["content"][0]["text"] == "First line: test" - assert result["content"][1]["text"] == "Second line" - assert result["content"][1]["type"] == "markdown" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["toolUseId"] == "custom-id" + assert len(result["tool_result"]["content"]) == 2 + assert result["tool_result"]["content"][0]["text"] == "First line: test" + assert result["tool_result"]["content"][1]["text"] == "Second line" + assert result["tool_result"]["content"][1]["type"] == "markdown" def test_docstring_parsing(): @@ -816,8 +833,8 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: stream = validation_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "int_param" in result["content"][0]["text"] + assert 
result["tool_result"]["status"] == "error" + assert "int_param" in result["tool_result"]["content"][0]["text"] # Test missing required parameter tool_use = { @@ -831,8 +848,8 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: stream = validation_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" - assert "int_param" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "error" + assert "int_param" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -855,16 +872,16 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "None" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "None" # Test with empty dict tool_use = {"toolUseId": "test-id", "input": {"param": {}}} stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert result["content"][0]["text"] == "{}" + assert result["tool_result"]["status"] == "success" + assert result["tool_result"]["content"][0]["text"] == "{}" # Test with a complex nested dictionary nested_dict = {"key1": {"nested": [1, 2, 3]}, "key2": None} @@ -872,9 +889,9 @@ def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: stream = edge_case_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "key1" in result["content"][0]["text"] - assert "nested" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "key1" in result["tool_result"]["content"][0]["text"] + assert "nested" in result["tool_result"]["content"][0]["text"] @pytest.mark.asyncio @@ -922,8 +939,8 @@ def test_method(self): stream = instance.test_method.stream({"toolUseId": 
"test-id", "input": {"param": "direct"}}, {}) direct_result = (await alist(stream))[-1] - assert direct_result["status"] == "success" - assert direct_result["content"][0]["text"] == "Method Got: direct" + assert direct_result["tool_result"]["status"] == "success" + assert direct_result["tool_result"]["content"][0]["text"] == "Method Got: direct" # Create a standalone function to test regular function calls @strands.tool @@ -944,8 +961,8 @@ def standalone_tool(p1: str, p2: str = "default") -> str: stream = standalone_tool.stream({"toolUseId": "test-id", "input": {"p1": "value1"}}, {}) tool_use_result = (await alist(stream))[-1] - assert tool_use_result["status"] == "success" - assert tool_use_result["content"][0]["text"] == "Standalone: value1, default" + assert tool_use_result["tool_result"]["status"] == "success" + assert tool_use_result["tool_result"]["content"][0]["text"] == "Standalone: value1, default" @pytest.mark.asyncio @@ -976,9 +993,9 @@ def failing_tool(param: str) -> str: stream = failing_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "error" + assert result["tool_result"]["status"] == "error" - error_message = result["content"][0]["text"] + error_message = result["tool_result"]["content"][0]["text"] # Check that error type is included if error_type == "value_error": @@ -1011,33 +1028,33 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None] stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "list: [1, 2, 3]" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "list: [1, 2, 3]" in result["tool_result"]["content"][0]["text"] # Test with a dict tool_use = {"toolUseId": "test-id", "input": {"union_param": {"key": "value"}}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "dict:" 
in result["content"][0]["text"] - assert "key" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "dict:" in result["tool_result"]["content"][0]["text"] + assert "key" in result["tool_result"]["content"][0]["text"] # Test with a string tool_use = {"toolUseId": "test-id", "input": {"union_param": "test_string"}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "str: test_string" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "str: test_string" in result["tool_result"]["content"][0]["text"] # Test with None tool_use = {"toolUseId": "test-id", "input": {"union_param": None}} stream = complex_schema_tool.stream(tool_use, {}) result = (await alist(stream))[-1] - assert result["status"] == "success" - assert "NoneType: None" in result["content"][0]["text"] + assert result["tool_result"]["status"] == "success" + assert "NoneType: None" in result["tool_result"]["content"][0]["text"] async def _run_context_injection_test(context_tool: AgentTool, additional_context=None): @@ -1061,15 +1078,17 @@ async def _run_context_injection_test(context_tool: AgentTool, additional_contex assert len(tool_results) == 1 tool_result = tool_results[0] - assert tool_result == { - "status": "success", - "content": [ - {"text": "Tool 'context_tool' (ID: test-id)"}, - {"text": "injected agent 'test_agent' processed: some_message"}, - {"text": "context agent 'test_agent'"}, - ], - "toolUseId": "test-id", - } + assert tool_result == ToolResultEvent( + { + "status": "success", + "content": [ + {"text": "Tool 'context_tool' (ID: test-id)"}, + {"text": "injected agent 'test_agent' processed: some_message"}, + {"text": "context agent 'test_agent'"}, + ], + "toolUseId": "test-id", + } + ) @pytest.mark.asyncio @@ -1164,9 +1183,9 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict: tool_result = 
tool_results[0] # Should get a validation error because tool_context is required but not provided - assert tool_result["status"] == "error" - assert "tool_context" in tool_result["content"][0]["text"].lower() - assert "validation" in tool_result["content"][0]["text"].lower() + assert tool_result["tool_result"]["status"] == "error" + assert "tool_context" in tool_result["tool_result"]["content"][0]["text"].lower() + assert "validation" in tool_result["tool_result"]["content"][0]["text"].lower() @pytest.mark.asyncio @@ -1196,8 +1215,10 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: tool_result = tool_results[0] # Should succeed with the string parameter - assert tool_result == { - "status": "success", - "content": [{"text": "success"}], - "toolUseId": "test-id-2", - } + assert tool_result == ToolResultEvent( + { + "status": "success", + "content": [{"text": "success"}], + "toolUseId": "test-id-2", + } + ) diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 240c24717..b305a1a90 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -9,6 +9,7 @@ validate_tool_use, validate_tool_use_name, ) +from strands.types._events import ToolResultEvent from strands.types.tools import ToolUse @@ -506,5 +507,5 @@ async def test_stream(identity_tool, alist): stream = identity_tool.stream({"tool_use": 1}, {"a": 2}) tru_events = await alist(stream) - exp_events = [({"tool_use": 1}, 2)] + exp_events = [ToolResultEvent(({"tool_use": 1}, 2))] assert tru_events == exp_events