Skip to content

Commit 76ef69d

Browse files
committed
fix: Fix tool result message event
Expand the Unit Tests for the yielded event to verify actual tool calls - previous to this, the events were not being emitted because the test was bailing out due to mocked guard rails. To better test the situation, we now have a much more extensive test for the successful tool call
1 parent 7a5caad commit 76ef69d

File tree

2 files changed

+304
-27
lines changed

2 files changed

+304
-27
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ async def _handle_tool_execution(
361361

362362
agent.messages.append(tool_result_message)
363363
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
364-
yield ToolResultMessageEvent(message=message)
364+
yield ToolResultMessageEvent(message=tool_result_message)
365365

366366
if cycle_span:
367367
tracer = get_tracer()

tests/strands/agent/hooks/test_agent_events.py

Lines changed: 303 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import unittest.mock
23
from unittest.mock import ANY, MagicMock, call
34

@@ -11,49 +12,333 @@
1112
from tests.fixtures.mocked_model_provider import MockedModelProvider
1213

1314

15+
@strands.tool
16+
def normal_tool(agent: Agent):
17+
return f"Done with synchronous {agent.name}!"
18+
19+
20+
@strands.tool
21+
async def async_tool(agent: Agent):
22+
await asyncio.sleep(0.1)
23+
return f"Done with asynchronous {agent.name}!"
24+
25+
26+
@strands.tool
27+
async def streaming_tool():
28+
await asyncio.sleep(0.2)
29+
yield {"tool_streaming": True}
30+
yield "Final result"
31+
32+
1433
@pytest.fixture
1534
def mock_time():
1635
with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock:
1736
yield mock
1837

1938

20-
@pytest.mark.asyncio
21-
async def test_stream_async_e2e(alist, mock_time):
22-
@strands.tool
23-
def fake_tool(agent: Agent):
24-
return "Done!"
39+
any_props = {
40+
"agent": ANY,
41+
"event_loop_cycle_id": ANY,
42+
"event_loop_cycle_span": ANY,
43+
"event_loop_cycle_trace": ANY,
44+
"request_state": {},
45+
}
46+
2547

48+
@pytest.mark.asyncio
49+
async def test_stream_e2e_success(alist):
2650
mock_provider = MockedModelProvider(
2751
[
28-
{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"},
29-
{"role": "assistant", "content": [{"text": "Okay invoking tool!"}]},
3052
{
3153
"role": "assistant",
32-
"content": [{"toolUse": {"name": "fake_tool", "toolUseId": "123", "input": {}}}],
54+
"content": [
55+
{"text": "Okay invoking normal tool"},
56+
{"toolUse": {"name": "normal_tool", "toolUseId": "123", "input": {}}},
57+
],
58+
},
59+
{
60+
"role": "assistant",
61+
"content": [
62+
{"text": "Invoking async tool"},
63+
{"toolUse": {"name": "async_tool", "toolUseId": "1234", "input": {}}},
64+
],
65+
},
66+
{
67+
"role": "assistant",
68+
"content": [
69+
{"text": "Invoking streaming tool"},
70+
{"toolUse": {"name": "streaming_tool", "toolUseId": "12345", "input": {}}},
71+
],
72+
},
73+
{
74+
"role": "assistant",
75+
"content": [
76+
{"text": "I invoked the tools!"},
77+
],
3378
},
34-
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
3579
]
3680
)
81+
82+
mock_callback = unittest.mock.Mock()
83+
agent = Agent(model=mock_provider, tools=[async_tool, normal_tool, streaming_tool], callback_handler=mock_callback)
84+
85+
stream = agent.stream_async("Do the stuff", arg1=1013)
86+
87+
tool_config = {
88+
"toolChoice": {"auto": {}},
89+
"tools": [
90+
{
91+
"toolSpec": {
92+
"description": "async_tool",
93+
"inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}},
94+
"name": "async_tool",
95+
}
96+
},
97+
{
98+
"toolSpec": {
99+
"description": "normal_tool",
100+
"inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}},
101+
"name": "normal_tool",
102+
}
103+
},
104+
{
105+
"toolSpec": {
106+
"description": "streaming_tool",
107+
"inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}},
108+
"name": "streaming_tool",
109+
}
110+
},
111+
],
112+
}
113+
114+
tru_events = await alist(stream)
115+
exp_events = [
116+
# Cycle 1: Initialize and invoke normal_tool
117+
{"arg1": 1013, "init_event_loop": True},
118+
{"start": True},
119+
{"start_event_loop": True},
120+
{"event": {"messageStart": {"role": "assistant"}}},
121+
{"event": {"contentBlockStart": {"start": {}}}},
122+
{"event": {"contentBlockDelta": {"delta": {"text": "Okay invoking normal tool"}}}},
123+
{
124+
**any_props,
125+
"arg1": 1013,
126+
"data": "Okay invoking normal tool",
127+
"delta": {"text": "Okay invoking normal tool"},
128+
},
129+
{"event": {"contentBlockStop": {}}},
130+
{"event": {"contentBlockStart": {"start": {"toolUse": {"name": "normal_tool", "toolUseId": "123"}}}}},
131+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}},
132+
{
133+
**any_props,
134+
"arg1": 1013,
135+
"current_tool_use": {"input": {}, "name": "normal_tool", "toolUseId": "123"},
136+
"delta": {"toolUse": {"input": "{}"}},
137+
},
138+
{"event": {"contentBlockStop": {}}},
139+
{"event": {"messageStop": {"stopReason": "tool_use"}}},
140+
{
141+
"message": {
142+
"content": [
143+
{"text": "Okay invoking normal tool"},
144+
{"toolUse": {"input": {}, "name": "normal_tool", "toolUseId": "123"}},
145+
],
146+
"role": "assistant",
147+
}
148+
},
149+
{
150+
"message": {
151+
"content": [
152+
{
153+
"toolResult": {
154+
"content": [{"text": "Done with synchronous Strands Agents!"}],
155+
"status": "success",
156+
"toolUseId": "123",
157+
}
158+
},
159+
],
160+
"role": "user",
161+
}
162+
},
163+
# Cycle 2: Invoke async_tool
164+
{"start": True},
165+
{"start": True},
166+
{"start_event_loop": True},
167+
{"event": {"messageStart": {"role": "assistant"}}},
168+
{"event": {"contentBlockStart": {"start": {}}}},
169+
{"event": {"contentBlockDelta": {"delta": {"text": "Invoking async tool"}}}},
170+
{
171+
**any_props,
172+
"arg1": 1013,
173+
"data": "Invoking async tool",
174+
"delta": {"text": "Invoking async tool"},
175+
"event_loop_parent_cycle_id": ANY,
176+
"messages": ANY,
177+
"model": ANY,
178+
"system_prompt": None,
179+
"tool_config": tool_config,
180+
},
181+
{"event": {"contentBlockStop": {}}},
182+
{"event": {"contentBlockStart": {"start": {"toolUse": {"name": "async_tool", "toolUseId": "1234"}}}}},
183+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}},
184+
{
185+
**any_props,
186+
"arg1": 1013,
187+
"current_tool_use": {"input": {}, "name": "async_tool", "toolUseId": "1234"},
188+
"delta": {"toolUse": {"input": "{}"}},
189+
"event_loop_parent_cycle_id": ANY,
190+
"messages": ANY,
191+
"model": ANY,
192+
"system_prompt": None,
193+
"tool_config": tool_config,
194+
},
195+
{"event": {"contentBlockStop": {}}},
196+
{"event": {"messageStop": {"stopReason": "tool_use"}}},
197+
{
198+
"message": {
199+
"content": [
200+
{"text": "Invoking async tool"},
201+
{"toolUse": {"input": {}, "name": "async_tool", "toolUseId": "1234"}},
202+
],
203+
"role": "assistant",
204+
}
205+
},
206+
{
207+
"message": {
208+
"content": [
209+
{
210+
"toolResult": {
211+
"content": [{"text": "Done with asynchronous Strands Agents!"}],
212+
"status": "success",
213+
"toolUseId": "1234",
214+
}
215+
},
216+
],
217+
"role": "user",
218+
}
219+
},
220+
# Cycle 3: Invoke streaming_tool
221+
{"start": True},
222+
{"start": True},
223+
{"start_event_loop": True},
224+
{"event": {"messageStart": {"role": "assistant"}}},
225+
{"event": {"contentBlockStart": {"start": {}}}},
226+
{"event": {"contentBlockDelta": {"delta": {"text": "Invoking streaming tool"}}}},
227+
{
228+
**any_props,
229+
"arg1": 1013,
230+
"data": "Invoking streaming tool",
231+
"delta": {"text": "Invoking streaming tool"},
232+
"event_loop_parent_cycle_id": ANY,
233+
"messages": ANY,
234+
"model": ANY,
235+
"system_prompt": None,
236+
"tool_config": tool_config,
237+
},
238+
{"event": {"contentBlockStop": {}}},
239+
{"event": {"contentBlockStart": {"start": {"toolUse": {"name": "streaming_tool", "toolUseId": "12345"}}}}},
240+
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}},
241+
{
242+
**any_props,
243+
"arg1": 1013,
244+
"current_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"},
245+
"delta": {"toolUse": {"input": "{}"}},
246+
"event_loop_parent_cycle_id": ANY,
247+
"messages": ANY,
248+
"model": ANY,
249+
"system_prompt": None,
250+
"tool_config": tool_config,
251+
},
252+
{"event": {"contentBlockStop": {}}},
253+
{"event": {"messageStop": {"stopReason": "tool_use"}}},
254+
{
255+
"message": {
256+
"content": [
257+
{"text": "Invoking streaming tool"},
258+
{"toolUse": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}},
259+
],
260+
"role": "assistant",
261+
}
262+
},
263+
{
264+
"message": {
265+
"content": [
266+
{
267+
"toolResult": {
268+
# TODO update this text when we get tool streaming implemented; right now this
269+
# TODO is of the form '<async_generator object streaming_tool at 0x107d18a00>'
270+
"content": [{"text": ANY}],
271+
"status": "success",
272+
"toolUseId": "12345",
273+
}
274+
},
275+
],
276+
"role": "user",
277+
}
278+
},
279+
# Cycle 4: Final response
280+
{"start": True},
281+
{"start": True},
282+
{"start_event_loop": True},
283+
{"event": {"messageStart": {"role": "assistant"}}},
284+
{"event": {"contentBlockStart": {"start": {}}}},
285+
{"event": {"contentBlockDelta": {"delta": {"text": "I invoked the tools!"}}}},
286+
{
287+
**any_props,
288+
"arg1": 1013,
289+
"data": "I invoked the tools!",
290+
"delta": {"text": "I invoked the tools!"},
291+
"event_loop_parent_cycle_id": ANY,
292+
"messages": ANY,
293+
"model": ANY,
294+
"system_prompt": None,
295+
"tool_config": tool_config,
296+
},
297+
{"event": {"contentBlockStop": {}}},
298+
{"event": {"messageStop": {"stopReason": "end_turn"}}},
299+
{"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant"}},
300+
{
301+
"result": AgentResult(
302+
stop_reason="end_turn",
303+
message={"content": [{"text": "I invoked the tools!"}], "role": "assistant"},
304+
metrics=ANY,
305+
state={},
306+
)
307+
},
308+
]
309+
assert tru_events == exp_events
310+
311+
exp_calls = [call(**event) for event in exp_events]
312+
act_calls = mock_callback.call_args_list
313+
assert act_calls == exp_calls
314+
315+
# Ensure that all events coming out of the agent are *not* typed events
316+
typed_events = [event for event in tru_events if isinstance(event, TypedEvent)]
317+
assert typed_events == []
318+
319+
320+
@pytest.mark.asyncio
321+
async def test_stream_e2e_throttle_and_redact(alist, mock_time):
37322
model = MagicMock()
38323
model.stream.side_effect = [
39324
ModelThrottledException("ThrottlingException | ConverseStream"),
40325
ModelThrottledException("ThrottlingException | ConverseStream"),
41-
mock_provider.stream([]),
326+
MockedModelProvider(
327+
[
328+
{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"},
329+
]
330+
).stream([]),
42331
]
43332

44333
mock_callback = unittest.mock.Mock()
45-
agent = Agent(model=model, tools=[fake_tool], callback_handler=mock_callback)
334+
agent = Agent(model=model, tools=[normal_tool], callback_handler=mock_callback)
46335

47336
stream = agent.stream_async("Do the stuff", arg1=1013)
48337

49338
# Base object with common properties
50339
throttle_props = {
51-
"agent": ANY,
52-
"event_loop_cycle_id": ANY,
53-
"event_loop_cycle_span": ANY,
54-
"event_loop_cycle_trace": ANY,
340+
**any_props,
55341
"arg1": 1013,
56-
"request_state": {},
57342
}
58343

59344
tru_events = await alist(stream)
@@ -68,14 +353,10 @@ def fake_tool(agent: Agent):
68353
{"event": {"contentBlockStart": {"start": {}}}},
69354
{"event": {"contentBlockDelta": {"delta": {"text": "INPUT BLOCKED!"}}}},
70355
{
71-
"agent": ANY,
356+
**any_props,
72357
"arg1": 1013,
73358
"data": "INPUT BLOCKED!",
74359
"delta": {"text": "INPUT BLOCKED!"},
75-
"event_loop_cycle_id": ANY,
76-
"event_loop_cycle_span": ANY,
77-
"event_loop_cycle_trace": ANY,
78-
"request_state": {},
79360
},
80361
{"event": {"contentBlockStop": {}}},
81362
{"event": {"messageStop": {"stopReason": "guardrail_intervened"}}},
@@ -128,12 +409,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end(
128409

129410
# Base object with common properties
130411
common_props = {
131-
"agent": ANY,
132-
"event_loop_cycle_id": ANY,
133-
"event_loop_cycle_span": ANY,
134-
"event_loop_cycle_trace": ANY,
412+
**any_props,
135413
"arg1": 1013,
136-
"request_state": {},
137414
}
138415

139416
exp_events = [

0 commit comments

Comments
 (0)