Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ async def _handle_tool_execution(

agent.messages.append(tool_result_message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
yield ToolResultMessageEvent(message=message)
yield ToolResultMessageEvent(message=tool_result_message)

if cycle_span:
tracer = get_tracer()
Expand Down
329 changes: 303 additions & 26 deletions tests/strands/agent/hooks/test_agent_events.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import unittest.mock
from unittest.mock import ANY, MagicMock, call

Expand All @@ -11,49 +12,333 @@
from tests.fixtures.mocked_model_provider import MockedModelProvider


@strands.tool
def normal_tool(agent: Agent):
return f"Done with synchronous {agent.name}!"


@strands.tool
async def async_tool(agent: Agent):
await asyncio.sleep(0.1)
return f"Done with asynchronous {agent.name}!"


@strands.tool
async def streaming_tool():
await asyncio.sleep(0.2)
yield {"tool_streaming": True}
yield "Final result"


@pytest.fixture
def mock_time():
with unittest.mock.patch.object(strands.event_loop.event_loop, "time") as mock:
yield mock


@pytest.mark.asyncio
async def test_stream_async_e2e(alist, mock_time):
@strands.tool
def fake_tool(agent: Agent):
return "Done!"
any_props = {
"agent": ANY,
"event_loop_cycle_id": ANY,
"event_loop_cycle_span": ANY,
"event_loop_cycle_trace": ANY,
"request_state": {},
}


@pytest.mark.asyncio
async def test_stream_e2e_success(alist):
mock_provider = MockedModelProvider(
[
{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"},
{"role": "assistant", "content": [{"text": "Okay invoking tool!"}]},
{
"role": "assistant",
"content": [{"toolUse": {"name": "fake_tool", "toolUseId": "123", "input": {}}}],
"content": [
{"text": "Okay invoking normal tool"},
{"toolUse": {"name": "normal_tool", "toolUseId": "123", "input": {}}},
],
},
{
"role": "assistant",
"content": [
{"text": "Invoking async tool"},
{"toolUse": {"name": "async_tool", "toolUseId": "1234", "input": {}}},
],
},
{
"role": "assistant",
"content": [
{"text": "Invoking streaming tool"},
{"toolUse": {"name": "streaming_tool", "toolUseId": "12345", "input": {}}},
],
},
{
"role": "assistant",
"content": [
{"text": "I invoked the tools!"},
],
},
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
]
)

mock_callback = unittest.mock.Mock()
agent = Agent(model=mock_provider, tools=[async_tool, normal_tool, streaming_tool], callback_handler=mock_callback)

stream = agent.stream_async("Do the stuff", arg1=1013)

tool_config = {
"toolChoice": {"auto": {}},
"tools": [
{
"toolSpec": {
"description": "async_tool",
"inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}},
"name": "async_tool",
}
},
{
"toolSpec": {
"description": "normal_tool",
"inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}},
"name": "normal_tool",
}
},
{
"toolSpec": {
"description": "streaming_tool",
"inputSchema": {"json": {"properties": {}, "required": [], "type": "object"}},
"name": "streaming_tool",
}
},
],
}

tru_events = await alist(stream)
exp_events = [
# Cycle 1: Initialize and invoke normal_tool
{"arg1": 1013, "init_event_loop": True},
{"start": True},
{"start_event_loop": True},
{"event": {"messageStart": {"role": "assistant"}}},
{"event": {"contentBlockStart": {"start": {}}}},
{"event": {"contentBlockDelta": {"delta": {"text": "Okay invoking normal tool"}}}},
{
**any_props,
"arg1": 1013,
"data": "Okay invoking normal tool",
"delta": {"text": "Okay invoking normal tool"},
},
{"event": {"contentBlockStop": {}}},
{"event": {"contentBlockStart": {"start": {"toolUse": {"name": "normal_tool", "toolUseId": "123"}}}}},
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}},
{
**any_props,
"arg1": 1013,
"current_tool_use": {"input": {}, "name": "normal_tool", "toolUseId": "123"},
"delta": {"toolUse": {"input": "{}"}},
},
{"event": {"contentBlockStop": {}}},
{"event": {"messageStop": {"stopReason": "tool_use"}}},
{
"message": {
"content": [
{"text": "Okay invoking normal tool"},
{"toolUse": {"input": {}, "name": "normal_tool", "toolUseId": "123"}},
],
"role": "assistant",
}
},
{
"message": {
"content": [
{
"toolResult": {
"content": [{"text": "Done with synchronous Strands Agents!"}],
"status": "success",
"toolUseId": "123",
}
},
],
"role": "user",
}
},
# Cycle 2: Invoke async_tool
{"start": True},
{"start": True},
{"start_event_loop": True},
{"event": {"messageStart": {"role": "assistant"}}},
{"event": {"contentBlockStart": {"start": {}}}},
{"event": {"contentBlockDelta": {"delta": {"text": "Invoking async tool"}}}},
{
**any_props,
"arg1": 1013,
"data": "Invoking async tool",
"delta": {"text": "Invoking async tool"},
"event_loop_parent_cycle_id": ANY,
"messages": ANY,
"model": ANY,
"system_prompt": None,
"tool_config": tool_config,
},
{"event": {"contentBlockStop": {}}},
{"event": {"contentBlockStart": {"start": {"toolUse": {"name": "async_tool", "toolUseId": "1234"}}}}},
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}},
{
**any_props,
"arg1": 1013,
"current_tool_use": {"input": {}, "name": "async_tool", "toolUseId": "1234"},
"delta": {"toolUse": {"input": "{}"}},
"event_loop_parent_cycle_id": ANY,
"messages": ANY,
"model": ANY,
"system_prompt": None,
"tool_config": tool_config,
},
{"event": {"contentBlockStop": {}}},
{"event": {"messageStop": {"stopReason": "tool_use"}}},
{
"message": {
"content": [
{"text": "Invoking async tool"},
{"toolUse": {"input": {}, "name": "async_tool", "toolUseId": "1234"}},
],
"role": "assistant",
}
},
{
"message": {
"content": [
{
"toolResult": {
"content": [{"text": "Done with asynchronous Strands Agents!"}],
"status": "success",
"toolUseId": "1234",
}
},
],
"role": "user",
}
},
# Cycle 3: Invoke streaming_tool
{"start": True},
{"start": True},
{"start_event_loop": True},
{"event": {"messageStart": {"role": "assistant"}}},
{"event": {"contentBlockStart": {"start": {}}}},
{"event": {"contentBlockDelta": {"delta": {"text": "Invoking streaming tool"}}}},
{
**any_props,
"arg1": 1013,
"data": "Invoking streaming tool",
"delta": {"text": "Invoking streaming tool"},
"event_loop_parent_cycle_id": ANY,
"messages": ANY,
"model": ANY,
"system_prompt": None,
"tool_config": tool_config,
},
{"event": {"contentBlockStop": {}}},
{"event": {"contentBlockStart": {"start": {"toolUse": {"name": "streaming_tool", "toolUseId": "12345"}}}}},
{"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}},
{
**any_props,
"arg1": 1013,
"current_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"},
"delta": {"toolUse": {"input": "{}"}},
"event_loop_parent_cycle_id": ANY,
"messages": ANY,
"model": ANY,
"system_prompt": None,
"tool_config": tool_config,
},
{"event": {"contentBlockStop": {}}},
{"event": {"messageStop": {"stopReason": "tool_use"}}},
{
"message": {
"content": [
{"text": "Invoking streaming tool"},
{"toolUse": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}},
],
"role": "assistant",
}
},
{
"message": {
"content": [
{
"toolResult": {
# TODO update this text when we get tool streaming implemented; right now this
# TODO is of the form '<async_generator object streaming_tool at 0x107d18a00>'
"content": [{"text": ANY}],
"status": "success",
"toolUseId": "12345",
}
},
],
"role": "user",
}
},
# Cycle 4: Final response
{"start": True},
{"start": True},
{"start_event_loop": True},
{"event": {"messageStart": {"role": "assistant"}}},
{"event": {"contentBlockStart": {"start": {}}}},
{"event": {"contentBlockDelta": {"delta": {"text": "I invoked the tools!"}}}},
{
**any_props,
"arg1": 1013,
"data": "I invoked the tools!",
"delta": {"text": "I invoked the tools!"},
"event_loop_parent_cycle_id": ANY,
"messages": ANY,
"model": ANY,
"system_prompt": None,
"tool_config": tool_config,
},
{"event": {"contentBlockStop": {}}},
{"event": {"messageStop": {"stopReason": "end_turn"}}},
{"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant"}},
{
"result": AgentResult(
stop_reason="end_turn",
message={"content": [{"text": "I invoked the tools!"}], "role": "assistant"},
metrics=ANY,
state={},
)
},
]
assert tru_events == exp_events

exp_calls = [call(**event) for event in exp_events]
act_calls = mock_callback.call_args_list
assert act_calls == exp_calls

# Ensure that all events coming out of the agent are *not* typed events
typed_events = [event for event in tru_events if isinstance(event, TypedEvent)]
assert typed_events == []


@pytest.mark.asyncio
async def test_stream_e2e_throttle_and_redact(alist, mock_time):
model = MagicMock()
model.stream.side_effect = [
ModelThrottledException("ThrottlingException | ConverseStream"),
ModelThrottledException("ThrottlingException | ConverseStream"),
mock_provider.stream([]),
MockedModelProvider(
[
{"redactedUserContent": "BLOCKED!", "redactedAssistantContent": "INPUT BLOCKED!"},
]
).stream([]),
]

mock_callback = unittest.mock.Mock()
agent = Agent(model=model, tools=[fake_tool], callback_handler=mock_callback)
agent = Agent(model=model, tools=[normal_tool], callback_handler=mock_callback)

stream = agent.stream_async("Do the stuff", arg1=1013)

# Base object with common properties
throttle_props = {
"agent": ANY,
"event_loop_cycle_id": ANY,
"event_loop_cycle_span": ANY,
"event_loop_cycle_trace": ANY,
**any_props,
"arg1": 1013,
"request_state": {},
}

tru_events = await alist(stream)
Expand All @@ -68,14 +353,10 @@ def fake_tool(agent: Agent):
{"event": {"contentBlockStart": {"start": {}}}},
{"event": {"contentBlockDelta": {"delta": {"text": "INPUT BLOCKED!"}}}},
{
"agent": ANY,
**any_props,
"arg1": 1013,
"data": "INPUT BLOCKED!",
"delta": {"text": "INPUT BLOCKED!"},
"event_loop_cycle_id": ANY,
"event_loop_cycle_span": ANY,
"event_loop_cycle_trace": ANY,
"request_state": {},
},
{"event": {"contentBlockStop": {}}},
{"event": {"messageStop": {"stopReason": "guardrail_intervened"}}},
Expand Down Expand Up @@ -128,12 +409,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end(

# Base object with common properties
common_props = {
"agent": ANY,
"event_loop_cycle_id": ANY,
"event_loop_cycle_span": ANY,
"event_loop_cycle_trace": ANY,
**any_props,
"arg1": 1013,
"request_state": {},
}

exp_events = [
Expand Down
Loading