From 76ef69dbbd12628f8e949b2fc2576db1935d8c3e Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Fri, 29 Aug 2025 13:41:53 -0400 Subject: [PATCH] fix: Fix tool result message event Expand the Unit Tests for the yielded event to verify actual tool calls - previous to this, the events were not being emitted because the test was bailing out due to mocked guard rails. To better test the situation, we now have a much more extensive test for the successful tool call --- src/strands/event_loop/event_loop.py | 2 +- .../strands/agent/hooks/test_agent_events.py | 329 ++++++++++++++++-- 2 files changed, 304 insertions(+), 27 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index a99ecc8a6..5d5085101 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -361,7 +361,7 @@ async def _handle_tool_execution( agent.messages.append(tool_result_message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) - yield ToolResultMessageEvent(message=message) + yield ToolResultMessageEvent(message=tool_result_message) if cycle_span: tracer = get_tracer() diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index d63dd97d4..04b832259 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -1,3 +1,4 @@ +import asyncio import unittest.mock from unittest.mock import ANY, MagicMock, call @@ -11,49 +12,333 @@ from tests.fixtures.mocked_model_provider import MockedModelProvider +@strands.tool +def normal_tool(agent: Agent): + return f"Done with synchronous {agent.name}!" + + +@strands.tool +async def async_tool(agent: Agent): + await asyncio.sleep(0.1) + return f"Done with asynchronous {agent.name}!" 
+ + +@strands.tool +async def streaming_tool(): + await asyncio.sleep(0.2) + yield {"tool_streaming": True} + yield "Final result" + + @pytest.fixture def mock_time(): with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock: yield mock -@pytest.mark.asyncio -async def test_stream_async_e2e(alist, mock_time): - @strands.tool - def fake_tool(agent: Agent): - return "Done!" +any_props = { + "agent": ANY, + "event_loop_cycle_id": ANY, + "event_loop_cycle_span": ANY, + "event_loop_cycle_trace": ANY, + "request_state": {}, +} + +@pytest.mark.asyncio +async def test_stream_e2e_success(alist): mock_provider = MockedModelProvider( [ - {"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}, - {"role": "assistant", "content": [{"text": "Okay invoking tool!"}]}, { "role": "assistant", - "content": [{"toolUse": {"name": "fake_tool", "toolUseId": "123", "input": {}}}], + "content": [ + {"text": "Okay invoking normal tool"}, + {"toolUse": {"name": "normal_tool", "toolUseId": "123", "input": {}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Invoking async tool"}, + {"toolUse": {"name": "async_tool", "toolUseId": "1234", "input": {}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "Invoking streaming tool"}, + {"toolUse": {"name": "streaming_tool", "toolUseId": "12345", "input": {}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"text": "I invoked the tools!"}, + ], }, - {"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, ] ) + + mock_callback = unittest.mock.Mock() + agent = Agent(model=mock_provider, tools=[async_tool, normal_tool, streaming_tool], callback_handler=mock_callback) + + stream = agent.stream_async("Do the stuff", arg1=1013) + + tool_config = { + "toolChoice": {"auto": {}}, + "tools": [ + { + "toolSpec": { + "description": "async_tool", + "inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}}, + "name": "async_tool", + } + }, + { + 
"toolSpec": { + "description": "normal_tool", + "inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}}, + "name": "normal_tool", + } + }, + { + "toolSpec": { + "description": "streaming_tool", + "inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}}, + "name": "streaming_tool", + } + }, + ], + } + + tru_events = await alist(stream) + exp_events = [ + # Cycle 1: Initialize and invoke normal_tool + {"arg1": 1013, "init_event_loop": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "Okay invoking normal tool"}}}}, + { + **any_props, + "arg1": 1013, + "data": "Okay invoking normal tool", + "delta": {"text": "Okay invoking normal tool"}, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"contentBlockStart": {"start": {"toolUse": {"name": "normal_tool", "toolUseId": "123"}}}}}, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}}, + { + **any_props, + "arg1": 1013, + "current_tool_use": {"input": {}, "name": "normal_tool", "toolUseId": "123"}, + "delta": {"toolUse": {"input": "{}"}}, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "tool_use"}}}, + { + "message": { + "content": [ + {"text": "Okay invoking normal tool"}, + {"toolUse": {"input": {}, "name": "normal_tool", "toolUseId": "123"}}, + ], + "role": "assistant", + } + }, + { + "message": { + "content": [ + { + "toolResult": { + "content": [{"text": "Done with synchronous Strands Agents!"}], + "status": "success", + "toolUseId": "123", + } + }, + ], + "role": "user", + } + }, + # Cycle 2: Invoke async_tool + {"start": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "Invoking 
async tool"}}}}, + { + **any_props, + "arg1": 1013, + "data": "Invoking async tool", + "delta": {"text": "Invoking async tool"}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"contentBlockStart": {"start": {"toolUse": {"name": "async_tool", "toolUseId": "1234"}}}}}, + {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}}, + { + **any_props, + "arg1": 1013, + "current_tool_use": {"input": {}, "name": "async_tool", "toolUseId": "1234"}, + "delta": {"toolUse": {"input": "{}"}}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "tool_use"}}}, + { + "message": { + "content": [ + {"text": "Invoking async tool"}, + {"toolUse": {"input": {}, "name": "async_tool", "toolUseId": "1234"}}, + ], + "role": "assistant", + } + }, + { + "message": { + "content": [ + { + "toolResult": { + "content": [{"text": "Done with asynchronous Strands Agents!"}], + "status": "success", + "toolUseId": "1234", + } + }, + ], + "role": "user", + } + }, + # Cycle 3: Invoke streaming_tool + {"start": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "Invoking streaming tool"}}}}, + { + **any_props, + "arg1": 1013, + "data": "Invoking streaming tool", + "delta": {"text": "Invoking streaming tool"}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"contentBlockStart": {"start": {"toolUse": {"name": "streaming_tool", "toolUseId": "12345"}}}}}, + {"event": {"contentBlockDelta": {"delta": 
{"toolUse": {"input": "{}"}}}}}, + { + **any_props, + "arg1": 1013, + "current_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, + "delta": {"toolUse": {"input": "{}"}}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "tool_use"}}}, + { + "message": { + "content": [ + {"text": "Invoking streaming tool"}, + {"toolUse": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}}, + ], + "role": "assistant", + } + }, + { + "message": { + "content": [ + { + "toolResult": { + # TODO update this text when we get tool streaming implemented; right now this + # TODO is of the form '<async_generator object streaming_tool at 0x...>' + "content": [{"text": ANY}], + "status": "success", + "toolUseId": "12345", + } + }, + ], + "role": "user", + } + }, + # Cycle 4: Final response + {"start": True}, + {"start": True}, + {"start_event_loop": True}, + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "I invoked the tools!"}}}}, + { + **any_props, + "arg1": 1013, + "data": "I invoked the tools!", + "delta": {"text": "I invoked the tools!"}, + "event_loop_parent_cycle_id": ANY, + "messages": ANY, + "model": ANY, + "system_prompt": None, + "tool_config": tool_config, + }, + {"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "end_turn"}}}, + {"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant"}}, + { + "result": AgentResult( + stop_reason="end_turn", + message={"content": [{"text": "I invoked the tools!"}], "role": "assistant"}, + metrics=ANY, + state={}, + ) + }, + ] + assert tru_events == exp_events + + exp_calls = [call(**event) for event in exp_events] + act_calls = mock_callback.call_args_list + assert act_calls == exp_calls + + # Ensure that all events coming out of the agent are 
*not* typed events + typed_events = [event for event in tru_events if isinstance(event, TypedEvent)] + assert typed_events == [] + + +@pytest.mark.asyncio +async def test_stream_e2e_throttle_and_redact(alist, mock_time): model = MagicMock() model.stream.side_effect = [ ModelThrottledException("ThrottlingException | ConverseStream"), ModelThrottledException("ThrottlingException | ConverseStream"), - mock_provider.stream([]), + MockedModelProvider( + [ + {"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"}, + ] + ).stream([]), ] mock_callback = unittest.mock.Mock() - agent = Agent(model=model, tools=[fake_tool], callback_handler=mock_callback) + agent = Agent(model=model, tools=[normal_tool], callback_handler=mock_callback) stream = agent.stream_async("Do the stuff", arg1=1013) # Base object with common properties throttle_props = { - "agent": ANY, - "event_loop_cycle_id": ANY, - "event_loop_cycle_span": ANY, - "event_loop_cycle_trace": ANY, + **any_props, "arg1": 1013, - "request_state": {}, } tru_events = await alist(stream) @@ -68,14 +353,10 @@ def fake_tool(agent: Agent): {"event": {"contentBlockStart": {"start": {}}}}, {"event": {"contentBlockDelta": {"delta": {"text": "INPUT BLOCKED!"}}}}, { - "agent": ANY, + **any_props, "arg1": 1013, "data": "INPUT BLOCKED!", "delta": {"text": "INPUT BLOCKED!"}, - "event_loop_cycle_id": ANY, - "event_loop_cycle_span": ANY, - "event_loop_cycle_trace": ANY, - "request_state": {}, }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, @@ -128,12 +409,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end( # Base object with common properties common_props = { - "agent": ANY, - "event_loop_cycle_id": ANY, - "event_loop_cycle_span": ANY, - "event_loop_cycle_trace": ANY, + **any_props, "arg1": 1013, - "request_state": {}, } exp_events = [