Skip to content

Commit f3fbd75

Browse files
authored
refactor(multiagent): Swarm - Remove unnecessary complete_swarm_task tool (#473)
1 parent 680f17a commit f3fbd75

File tree

2 files changed

+42
-75
lines changed

2 files changed

+42
-75
lines changed

src/strands/multiagent/swarm.py

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,6 @@ def _inject_swarm_tools(self) -> None:
327327
# Create tool functions with proper closures
328328
swarm_tools = [
329329
self._create_handoff_tool(),
330-
self._create_complete_tool(),
331330
]
332331

333332
for node in self.nodes.values():
@@ -337,8 +336,6 @@ def _inject_swarm_tools(self) -> None:
337336

338337
if "handoff_to_agent" in existing_tools:
339338
conflicting_tools.append("handoff_to_agent")
340-
if "complete_swarm_task" in existing_tools:
341-
conflicting_tools.append("complete_swarm_task")
342339

343340
if conflicting_tools:
344341
raise ValueError(
@@ -388,27 +385,6 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No
388385

389386
return handoff_to_agent
390387

391-
def _create_complete_tool(self) -> Callable[..., Any]:
392-
"""Create completion tool for task completion."""
393-
swarm_ref = self # Capture swarm reference
394-
395-
@tool
396-
def complete_swarm_task() -> dict[str, Any]:
397-
"""Mark the task as complete. No more agents will be called.
398-
399-
Returns:
400-
Task completion confirmation
401-
"""
402-
try:
403-
# Mark swarm as complete
404-
swarm_ref._handle_completion()
405-
406-
return {"status": "success", "content": [{"text": "Task completed"}]}
407-
except Exception as e:
408-
return {"status": "error", "content": [{"text": f"Error completing task: {str(e)}"}]}
409-
410-
return complete_swarm_task
411-
412388
def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[str, Any]) -> None:
413389
"""Handle handoff to another agent."""
414390
# If task is already completed, don't allow further handoffs
@@ -437,12 +413,6 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
437413
target_node.node_id,
438414
)
439415

440-
def _handle_completion(self) -> None:
441-
"""Handle task completion."""
442-
self.state.completion_status = Status.COMPLETED
443-
444-
logger.debug("swarm task completed")
445-
446416
def _build_node_input(self, target_node: SwarmNode) -> str:
447417
"""Build input text for a node based on shared context and handoffs.
448418
@@ -463,7 +433,7 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
463433
Agent name: code_reviewer.
464434
Agent name: security_specialist. Agent description: Focuses on secure coding practices and vulnerability assessment
465435
466-
You have access to swarm coordination tools if you need help from other agents or want to complete the task.
436+
You have access to swarm coordination tools if you need help from other agents. If you don't hand off to another agent, the swarm will consider the task complete.
467437
```
468438
""" # noqa: E501
469439
context_info: dict[str, Any] = {
@@ -511,8 +481,8 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
511481
context_text += "\n"
512482

513483
context_text += (
514-
"You have access to swarm coordination tools if you need help from other agents "
515-
"or want to complete the task."
484+
"You have access to swarm coordination tools if you need help from other agents. "
485+
"If you don't hand off to another agent, the swarm will consider the task complete."
516486
)
517487

518488
return context_text
@@ -564,9 +534,11 @@ async def _execute_swarm(self) -> None:
564534

565535
logger.debug("node=<%s> | node execution completed", current_node.node_id)
566536

567-
# Immediate check for completion after node execution
568-
if self.state.completion_status != Status.EXECUTING:
569-
logger.debug("status=<%s> | task completed with status", self.state.completion_status) # type: ignore[unreachable]
537+
# Check if the current node is still the same after execution
538+
# If it is, then no handoff occurred and we consider the swarm complete
539+
if self.state.current_node == current_node:
540+
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
541+
self.state.completion_status = Status.COMPLETED
570542
break
571543

572544
except asyncio.TimeoutError:

tests/strands/multiagent/test_swarm.py

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
import time
32
from unittest.mock import MagicMock, Mock, patch
43

@@ -14,9 +13,7 @@
1413
from strands.types.content import ContentBlock
1514

1615

17-
def create_mock_agent(
18-
name, response_text="Default response", metrics=None, agent_id=None, complete_after_calls=1, should_fail=False
19-
):
16+
def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None, should_fail=False):
2017
"""Create a mock Agent with specified properties."""
2118
agent = Mock(spec=Agent)
2219
agent.name = name
@@ -27,8 +24,6 @@ def create_mock_agent(
2724
agent.tool_registry.registry = {}
2825
agent.tool_registry.process_tools = Mock()
2926
agent._call_count = 0
30-
agent._complete_after = complete_after_calls
31-
agent._swarm_ref = None # Will be set by the swarm
3227
agent._should_fail = should_fail
3328
agent._session_manager = None
3429
agent.hooks = HookRegistry()
@@ -46,11 +41,6 @@ def create_mock_result():
4641
if agent._should_fail:
4742
raise Exception("Simulated agent failure")
4843

49-
# After specified calls, complete the task
50-
if agent._call_count >= agent._complete_after and agent._swarm_ref:
51-
# Directly call the completion handler
52-
agent._swarm_ref._handle_completion()
53-
5444
return AgentResult(
5545
message={"role": "assistant", "content": [{"text": response_text}]},
5646
stop_reason="end_turn",
@@ -73,9 +63,9 @@ async def mock_invoke_async(*args, **kwargs):
7363
def mock_agents():
7464
"""Create a set of mock agents for testing."""
7565
return {
76-
"coordinator": create_mock_agent("coordinator", "Coordinating task", complete_after_calls=1),
77-
"specialist": create_mock_agent("specialist", "Specialized response", complete_after_calls=1),
78-
"reviewer": create_mock_agent("reviewer", "Review complete", complete_after_calls=1),
66+
"coordinator": create_mock_agent("coordinator", "Coordinating task"),
67+
"specialist": create_mock_agent("specialist", "Specialized response"),
68+
"reviewer": create_mock_agent("reviewer", "Review complete"),
7969
}
8070

8171

@@ -91,10 +81,6 @@ def mock_swarm(mock_agents):
9181
node_timeout=10.0,
9282
)
9383

94-
# Set swarm reference on agents so they can call completion
95-
for agent in agents:
96-
agent._swarm_ref = swarm
97-
9884
return swarm
9985

10086

@@ -273,10 +259,6 @@ def test_swarm_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag
273259
node_timeout=5.0,
274260
)
275261

276-
# Set swarm reference on agents so they can call completion
277-
for agent in agents:
278-
agent._swarm_ref = swarm
279-
280262
# Test synchronous execution
281263
result = swarm("Test synchronous swarm execution")
282264

@@ -335,27 +317,25 @@ def test_swarm_builder_validation(mock_agents):
335317
with pytest.raises(ValueError, match="already has tools with names that conflict"):
336318
Swarm(nodes=[conflicting_agent])
337319

338-
# Test tool name conflicts - complete tool
339-
conflicting_complete_agent = create_mock_agent("conflicting_complete")
340-
conflicting_complete_agent.tool_registry.registry = {"complete_swarm_task": Mock()}
341-
342-
with pytest.raises(ValueError, match="already has tools with names that conflict"):
343-
Swarm(nodes=[conflicting_complete_agent])
344-
345320

346321
def test_swarm_handoff_functionality():
347322
"""Test swarm handoff functionality."""
348323

349324
# Create an agent that will hand off to another agent
350325
def create_handoff_agent(name, target_agent_name, response_text="Handing off"):
351326
"""Create a mock agent that performs handoffs."""
352-
agent = create_mock_agent(name, response_text, complete_after_calls=math.inf) # Never complete naturally
327+
agent = create_mock_agent(name, response_text)
353328
agent._handoff_done = False # Track if handoff has been performed
354329

355330
def create_handoff_result():
356331
agent._call_count += 1
357332
# Perform handoff on first execution call (not setup calls)
358-
if not agent._handoff_done and agent._swarm_ref and hasattr(agent._swarm_ref.state, "completion_status"):
333+
if (
334+
not agent._handoff_done
335+
and hasattr(agent, "_swarm_ref")
336+
and agent._swarm_ref
337+
and hasattr(agent._swarm_ref.state, "completion_status")
338+
):
359339
target_node = agent._swarm_ref.nodes.get(target_agent_name)
360340
if target_node:
361341
agent._swarm_ref._handle_handoff(
@@ -382,9 +362,9 @@ async def mock_invoke_async(*args, **kwargs):
382362
agent.invoke_async = MagicMock(side_effect=mock_invoke_async)
383363
return agent
384364

385-
# Create agents - first one hands off, second one completes
365+
# Create agents - first one hands off, second one completes by not handing off
386366
handoff_agent = create_handoff_agent("handoff_agent", "completion_agent")
387-
completion_agent = create_mock_agent("completion_agent", "Task completed", complete_after_calls=1)
367+
completion_agent = create_mock_agent("completion_agent", "Task completed")
388368

389369
# Create a swarm with reasonable limits
390370
handoff_swarm = Swarm(nodes=[handoff_agent, completion_agent], max_handoffs=10, max_iterations=10)
@@ -427,18 +407,13 @@ def test_swarm_tool_creation_and_execution():
427407
assert error_result["status"] == "error"
428408
assert "not found" in error_result["content"][0]["text"]
429409

430-
completion_tool = error_swarm._create_complete_tool()
431-
completion_result = completion_tool()
432-
assert completion_result["status"] == "success"
433-
434410

435411
def test_swarm_failure_handling(mock_strands_tracer, mock_use_span):
436412
"""Test swarm execution with agent failures."""
437413
# Test execution with agent failures
438414
failing_agent = create_mock_agent("failing_agent")
439415
failing_agent._should_fail = True # Set failure flag after creation
440416
failing_swarm = Swarm(nodes=[failing_agent], node_timeout=1.0)
441-
failing_agent._swarm_ref = failing_swarm
442417

443418
# The swarm catches exceptions internally and sets status to FAILED
444419
result = failing_swarm("Test failure handling")
@@ -451,12 +426,32 @@ def test_swarm_metrics_handling():
451426
"""Test swarm metrics handling with missing metrics."""
452427
no_metrics_agent = create_mock_agent("no_metrics", metrics=None)
453428
no_metrics_swarm = Swarm(nodes=[no_metrics_agent])
454-
no_metrics_agent._swarm_ref = no_metrics_swarm
455429

456430
result = no_metrics_swarm("Test no metrics")
457431
assert result.status == Status.COMPLETED
458432

459433

434+
def test_swarm_auto_completion_without_handoff():
435+
"""Test swarm auto-completion when no handoff occurs."""
436+
# Create a simple agent that doesn't hand off
437+
no_handoff_agent = create_mock_agent("no_handoff_agent", "Task completed without handoff")
438+
439+
# Create a swarm with just this agent
440+
auto_complete_swarm = Swarm(nodes=[no_handoff_agent])
441+
442+
# Execute swarm - this should complete automatically since there's no handoff
443+
result = auto_complete_swarm("Test auto-completion without handoff")
444+
445+
# Verify the swarm completed successfully
446+
assert result.status == Status.COMPLETED
447+
assert result.execution_count == 1
448+
assert len(result.node_history) == 1
449+
assert result.node_history[0].node_id == "no_handoff_agent"
450+
451+
# Verify the agent was called
452+
no_handoff_agent.invoke_async.assert_called()
453+
454+
460455
def test_swarm_validate_unsupported_features():
461456
"""Test Swarm validation for session persistence and callbacks."""
462457
# Test with normal agent (should work)

0 commit comments

Comments
 (0)