Skip to content

Commit 94b41b4

Browse files
authored
feat: Enable hooks for MultiAgents (#760)
It's been a customer ask and we don't have a pressing need to keep it restricted. The primary concern is that because agent's state is manipulated between invocations (state is reset) hooks designed for a single agent may not work for multi-agents. With documentation, we can guide folks to be aware of what happens rather than restricting it outright. --------- Co-authored-by: Mackenzie Zastrow <[email protected]>
1 parent 47faba0 commit 94b41b4

File tree

7 files changed

+106
-55
lines changed

7 files changed

+106
-55
lines changed

src/strands/multiagent/graph.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,6 @@ def _validate_node_executor(
201201
if executor._session_manager is not None:
202202
raise ValueError("Session persistence is not supported for Graph agents yet.")
203203

204-
# Check for callbacks
205-
if executor.hooks.has_callbacks():
206-
raise ValueError("Agent callbacks are not supported for Graph agents yet.")
207-
208204

209205
class GraphBuilder:
210206
"""Builder pattern for constructing graphs."""

src/strands/multiagent/swarm.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,10 +318,6 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
318318
if node._session_manager is not None:
319319
raise ValueError("Session persistence is not supported for Swarm agents yet.")
320320

321-
# Check for callbacks
322-
if node.hooks.has_callbacks():
323-
raise ValueError("Agent callbacks are not supported for Swarm agents yet.")
324-
325321
def _inject_swarm_tools(self) -> None:
326322
"""Add swarm coordination tools to each agent."""
327323
# Create tool functions with proper closures

tests/fixtures/mock_hook_provider.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,44 @@
1-
from typing import Iterator, Tuple, Type
1+
from typing import Iterator, Literal, Tuple, Type
22

3-
from strands.hooks import HookEvent, HookProvider, HookRegistry
3+
from strands import Agent
4+
from strands.experimental.hooks import (
5+
AfterModelInvocationEvent,
6+
AfterToolInvocationEvent,
7+
BeforeModelInvocationEvent,
8+
BeforeToolInvocationEvent,
9+
)
10+
from strands.hooks import (
11+
AfterInvocationEvent,
12+
AgentInitializedEvent,
13+
BeforeInvocationEvent,
14+
HookEvent,
15+
HookProvider,
16+
HookRegistry,
17+
MessageAddedEvent,
18+
)
419

520

621
class MockHookProvider(HookProvider):
7-
def __init__(self, event_types: list[Type]):
22+
def __init__(self, event_types: list[Type] | Literal["all"]):
23+
if event_types == "all":
24+
event_types = [
25+
AgentInitializedEvent,
26+
BeforeInvocationEvent,
27+
AfterInvocationEvent,
28+
AfterToolInvocationEvent,
29+
BeforeToolInvocationEvent,
30+
BeforeModelInvocationEvent,
31+
AfterModelInvocationEvent,
32+
MessageAddedEvent,
33+
]
34+
835
self.events_received = []
936
self.events_types = event_types
1037

38+
@property
39+
def event_types_received(self):
40+
return [type(event) for event in self.events_received]
41+
1142
def get_events(self) -> Tuple[int, Iterator[HookEvent]]:
1243
return len(self.events_received), iter(self.events_received)
1344

@@ -17,3 +48,11 @@ def register_hooks(self, registry: HookRegistry) -> None:
1748

1849
def add_event(self, event: HookEvent) -> None:
1950
self.events_received.append(event)
51+
52+
def extract_for(self, agent: Agent) -> "MockHookProvider":
53+
"""Extracts a hook provider for the given agent, including the events that were fired for that agent.
54+
55+
Convenience method when sharing a hook provider between multiple agents."""
56+
child_provider = MockHookProvider(self.events_types)
57+
child_provider.events_received = [event for event in self.events_received if event.agent == agent]
58+
return child_provider

tests/strands/multiagent/test_graph.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -873,15 +873,6 @@ class TestHookProvider(HookProvider):
873873
def register_hooks(self, registry, **kwargs):
874874
registry.add_callback(AgentInitializedEvent, lambda e: None)
875875

876-
agent_with_hooks = create_mock_agent("agent_with_hooks")
877-
agent_with_hooks._session_manager = None
878-
agent_with_hooks.hooks = HookRegistry()
879-
agent_with_hooks.hooks.add_hook(TestHookProvider())
880-
881-
builder = GraphBuilder()
882-
with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"):
883-
builder.add_node(agent_with_hooks)
884-
885876
# Test validation in Graph constructor (when nodes are passed directly)
886877
# Test with session manager in Graph constructor
887878
node_with_session = GraphNode("node_with_session", agent_with_session)
@@ -892,15 +883,6 @@ def register_hooks(self, registry, **kwargs):
892883
entry_points=set(),
893884
)
894885

895-
# Test with callbacks in Graph constructor
896-
node_with_hooks = GraphNode("node_with_hooks", agent_with_hooks)
897-
with pytest.raises(ValueError, match="Agent callbacks are not supported for Graph agents yet"):
898-
Graph(
899-
nodes={"node_with_hooks": node_with_hooks},
900-
edges=set(),
901-
entry_points=set(),
902-
)
903-
904886

905887
@pytest.mark.asyncio
906888
async def test_controlled_cyclic_execution():

tests/strands/multiagent/test_swarm.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
from strands.agent import Agent, AgentResult
77
from strands.agent.state import AgentState
8-
from strands.hooks import AgentInitializedEvent
9-
from strands.hooks.registry import HookProvider, HookRegistry
8+
from strands.hooks.registry import HookRegistry
109
from strands.multiagent.base import Status
1110
from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState
1211
from strands.session.session_manager import SessionManager
@@ -470,16 +469,3 @@ def test_swarm_validate_unsupported_features():
470469

471470
with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"):
472471
Swarm([agent_with_session])
473-
474-
# Test with callbacks (should fail)
475-
class TestHookProvider(HookProvider):
476-
def register_hooks(self, registry, **kwargs):
477-
registry.add_callback(AgentInitializedEvent, lambda e: None)
478-
479-
agent_with_hooks = create_mock_agent("agent_with_hooks")
480-
agent_with_hooks._session_manager = None
481-
agent_with_hooks.hooks = HookRegistry()
482-
agent_with_hooks.hooks.add_hook(TestHookProvider())
483-
484-
with pytest.raises(ValueError, match="Agent callbacks are not supported for Swarm agents yet"):
485-
Swarm([agent_with_hooks])

tests_integ/test_multiagent_graph.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import pytest
22

33
from strands import Agent, tool
4+
from strands.experimental.hooks import AfterModelInvocationEvent, BeforeModelInvocationEvent
5+
from strands.hooks import AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, MessageAddedEvent
46
from strands.multiagent.graph import GraphBuilder
57
from strands.types.content import ContentBlock
8+
from tests.fixtures.mock_hook_provider import MockHookProvider
69

710

811
@tool
@@ -18,49 +21,59 @@ def multiply_numbers(x: int, y: int) -> int:
1821

1922

2023
@pytest.fixture
21-
def math_agent():
24+
def hook_provider():
25+
return MockHookProvider("all")
26+
27+
28+
@pytest.fixture
29+
def math_agent(hook_provider):
2230
"""Create an agent specialized in mathematical operations."""
2331
return Agent(
2432
model="us.amazon.nova-pro-v1:0",
2533
system_prompt="You are a mathematical assistant. Always provide clear, step-by-step calculations.",
34+
hooks=[hook_provider],
2635
tools=[calculate_sum, multiply_numbers],
2736
)
2837

2938

3039
@pytest.fixture
31-
def analysis_agent():
40+
def analysis_agent(hook_provider):
3241
"""Create an agent specialized in data analysis."""
3342
return Agent(
3443
model="us.amazon.nova-pro-v1:0",
44+
hooks=[hook_provider],
3545
system_prompt="You are a data analysis expert. Provide insights and interpretations of numerical results.",
3646
)
3747

3848

3949
@pytest.fixture
40-
def summary_agent():
50+
def summary_agent(hook_provider):
4151
"""Create an agent specialized in summarization."""
4252
return Agent(
4353
model="us.amazon.nova-lite-v1:0",
54+
hooks=[hook_provider],
4455
system_prompt="You are a summarization expert. Create concise, clear summaries of complex information.",
4556
)
4657

4758

4859
@pytest.fixture
49-
def validation_agent():
60+
def validation_agent(hook_provider):
5061
"""Create an agent specialized in validation."""
5162
return Agent(
5263
model="us.amazon.nova-pro-v1:0",
64+
hooks=[hook_provider],
5365
system_prompt="You are a validation expert. Check results for accuracy and completeness.",
5466
)
5567

5668

5769
@pytest.fixture
58-
def image_analysis_agent():
70+
def image_analysis_agent(hook_provider):
5971
"""Create an agent specialized in image analysis."""
6072
return Agent(
73+
hooks=[hook_provider],
6174
system_prompt=(
6275
"You are an image analysis expert. Describe what you see in images and provide detailed analysis."
63-
)
76+
),
6477
)
6578

6679

@@ -149,7 +162,7 @@ def proceed_to_second_summary(state):
149162

150163

151164
@pytest.mark.asyncio
152-
async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img):
165+
async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img, hook_provider):
153166
"""Test graph execution with multi-modal image input."""
154167
builder = GraphBuilder()
155168

@@ -186,3 +199,16 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y
186199
# Verify both nodes completed
187200
assert "image_analyzer" in result.results
188201
assert "summarizer" in result.results
202+
203+
expected_hook_events = [
204+
AgentInitializedEvent,
205+
BeforeInvocationEvent,
206+
MessageAddedEvent,
207+
BeforeModelInvocationEvent,
208+
AfterModelInvocationEvent,
209+
MessageAddedEvent,
210+
AfterInvocationEvent,
211+
]
212+
213+
assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events
214+
assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events

tests_integ/test_multiagent_swarm.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
import pytest
22

33
from strands import Agent, tool
4+
from strands.experimental.hooks import (
5+
AfterModelInvocationEvent,
6+
AfterToolInvocationEvent,
7+
BeforeModelInvocationEvent,
8+
BeforeToolInvocationEvent,
9+
)
10+
from strands.hooks import AfterInvocationEvent, BeforeInvocationEvent, MessageAddedEvent
411
from strands.multiagent.swarm import Swarm
512
from strands.types.content import ContentBlock
13+
from tests.fixtures.mock_hook_provider import MockHookProvider
614

715

816
@tool
@@ -22,44 +30,52 @@ def calculate(expression: str) -> str:
2230

2331

2432
@pytest.fixture
25-
def researcher_agent():
33+
def hook_provider():
34+
return MockHookProvider("all")
35+
36+
37+
@pytest.fixture
38+
def researcher_agent(hook_provider):
2639
"""Create an agent specialized in research."""
2740
return Agent(
2841
name="researcher",
2942
system_prompt=(
3043
"You are a research specialist who excels at finding information. When you need to perform calculations or"
3144
" format documents, hand off to the appropriate specialist."
3245
),
46+
hooks=[hook_provider],
3347
tools=[web_search],
3448
)
3549

3650

3751
@pytest.fixture
38-
def analyst_agent():
52+
def analyst_agent(hook_provider):
3953
"""Create an agent specialized in data analysis."""
4054
return Agent(
4155
name="analyst",
4256
system_prompt=(
4357
"You are a data analyst who excels at calculations and numerical analysis. When you need"
4458
" research or document formatting, hand off to the appropriate specialist."
4559
),
60+
hooks=[hook_provider],
4661
tools=[calculate],
4762
)
4863

4964

5065
@pytest.fixture
51-
def writer_agent():
66+
def writer_agent(hook_provider):
5267
"""Create an agent specialized in writing and formatting."""
5368
return Agent(
5469
name="writer",
70+
hooks=[hook_provider],
5571
system_prompt=(
5672
"You are a professional writer who excels at formatting and presenting information. When you need research"
5773
" or calculations, hand off to the appropriate specialist."
5874
),
5975
)
6076

6177

62-
def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent):
78+
def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider):
6379
"""Test swarm execution with string input."""
6480
# Create the swarm
6581
swarm = Swarm([researcher_agent, analyst_agent, writer_agent])
@@ -82,6 +98,16 @@ def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_age
8298
# Verify agent history - at least one agent should have been used
8399
assert len(result.node_history) > 0
84100

101+
# Just ensure that hooks are emitted; actual content is not verified
102+
researcher_hooks = hook_provider.extract_for(researcher_agent).event_types_received
103+
assert BeforeInvocationEvent in researcher_hooks
104+
assert MessageAddedEvent in researcher_hooks
105+
assert BeforeModelInvocationEvent in researcher_hooks
106+
assert BeforeToolInvocationEvent in researcher_hooks
107+
assert AfterToolInvocationEvent in researcher_hooks
108+
assert AfterModelInvocationEvent in researcher_hooks
109+
assert AfterInvocationEvent in researcher_hooks
110+
85111

86112
@pytest.mark.asyncio
87113
async def test_swarm_execution_with_image(researcher_agent, analyst_agent, writer_agent, yellow_img):

0 commit comments

Comments
 (0)