Skip to content

Commit 70b1d10

Browse files
authored
interrupts - multiagent - do not emit AfterNodeCallEvent on interrupt (#1539)
1 parent 78a1c28 commit 70b1d10

File tree

5 files changed

+30
-5
lines changed

5 files changed

+30
-5
lines changed

src/strands/multiagent/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1005,7 +1005,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
10051005
raise
10061006

10071007
finally:
1008-
await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state))
1008+
if node.execution_status != Status.INTERRUPTED:
1009+
await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state))
10091010

10101011
def _accumulate_metrics(self, node_result: NodeResult) -> None:
10111012
"""Accumulate metrics from a node result."""

src/strands/multiagent/swarm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -782,9 +782,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
782782
break
783783

784784
finally:
785-
await self.hooks.invoke_callbacks_async(
786-
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
787-
)
785+
if self.state.completion_status != Status.INTERRUPTED:
786+
await self.hooks.invoke_callbacks_async(
787+
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
788+
)
788789

789790
logger.debug("node=<%s> | node execution completed", current_node.node_id)
790791

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
import pytest
22

3-
from strands.hooks import BeforeNodeCallEvent, HookProvider
3+
from strands.hooks import AfterNodeCallEvent, BeforeNodeCallEvent, HookProvider
44

55

66
@pytest.fixture
77
def interrupt_hook():
88
class Hook(HookProvider):
9+
def __init__(self):
10+
self.after_count = 0
11+
912
def register_hooks(self, registry):
1013
registry.add_callback(BeforeNodeCallEvent, self.interrupt)
14+
registry.add_callback(AfterNodeCallEvent, self.cleanup)
1115

1216
def interrupt(self, event):
1317
return event.interrupt("test_name", reason="test_reason")
1418

19+
def cleanup(self, event):
20+
self.after_count += 1
21+
1522
return Hook()

tests/strands/multiagent/test_graph.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2126,6 +2126,10 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
21262126
]
21272127
assert tru_interrupts == exp_interrupts
21282128

2129+
tru_after_count = interrupt_hook.after_count
2130+
exp_after_count = 0
2131+
assert tru_after_count == exp_after_count
2132+
21292133
interrupt = multiagent_result.interrupts[0]
21302134
responses = [
21312135
{
@@ -2152,4 +2156,8 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
21522156
exp_message = "Task completed"
21532157
assert tru_message == exp_message
21542158

2159+
tru_after_count = interrupt_hook.after_count
2160+
exp_after_count = 1
2161+
assert tru_after_count == exp_after_count
2162+
21552163
assert multiagent_result.execution_time >= first_execution_time

tests/strands/multiagent/test_swarm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,6 +1259,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook):
12591259
]
12601260
assert tru_interrupts == exp_interrupts
12611261

1262+
tru_after_count = interrupt_hook.after_count
1263+
exp_after_count = 0
1264+
assert tru_after_count == exp_after_count
1265+
12621266
interrupt = multiagent_result.interrupts[0]
12631267
responses = [
12641268
{
@@ -1281,6 +1285,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook):
12811285
exp_message = "Task completed"
12821286
assert tru_message == exp_message
12831287

1288+
tru_after_count = interrupt_hook.after_count
1289+
exp_after_count = 1
1290+
assert tru_after_count == exp_after_count
1291+
12841292
assert multiagent_result.execution_time >= first_execution_time
12851293

12861294

0 commit comments

Comments
 (0)