28 changes: 27 additions & 1 deletion src/strands/multiagent/base.py
@@ -8,7 +8,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union
from typing import Any, AsyncIterator, Union

from ..agent import AgentResult
from ..types.content import ContentBlock
@@ -97,6 +97,32 @@ async def invoke_async(
"""
raise NotImplementedError("invoke_async not implemented")

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during multi-agent execution.

This default implementation provides backward compatibility by executing
invoke_async and yielding a single result event. Subclasses can override
this method to provide true streaming capabilities.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Additional keyword arguments passed to underlying agents.

Yields:
Dictionary events containing multi-agent execution information including:
- Multi-agent coordination events (node start/complete, handoffs)
- Forwarded single-agent events with node context
- Final result event
"""
# Default implementation for backward compatibility
# Execute invoke_async and yield the result as a single event
result = await self.invoke_async(task, invocation_state, **kwargs)
yield {"result": result}

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
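With this default in place, any MultiAgentBase subclass can be driven through the same streaming loop even before it overrides stream_async; the loop simply sees a single event carrying the final result. A minimal consumer sketch (the orchestrator instance and task string below are illustrative assumptions, not part of this change):

import asyncio

async def consume(orchestrator, task):
    """Drain stream_async and return the final result, mirroring invoke_async."""
    final_event = None
    async for event in orchestrator.stream_async(task):
        # With the default implementation this loop sees exactly one {"result": ...} event;
        # overriding subclasses may yield many intermediate events before it.
        final_event = event
    return final_event["result"]

# e.g. result = asyncio.run(consume(my_swarm, "Summarize the findings"))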
300 changes: 222 additions & 78 deletions src/strands/multiagent/graph.py

Large diffs are not rendered by default.

166 changes: 133 additions & 33 deletions src/strands/multiagent/swarm.py
@@ -19,14 +19,21 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple
from typing import Any, AsyncIterator, Callable, Tuple, cast

from opentelemetry import trace as trace_api

from ..agent import Agent, AgentResult
from ..agent import Agent
from ..agent.state import AgentState
from ..telemetry import get_tracer
from ..tools.decorator import tool
from ..types._events import (
MultiAgentHandoffEvent,
MultiAgentNodeCompleteEvent,
MultiAgentNodeStartEvent,
MultiAgentNodeStreamEvent,
MultiAgentResultEvent,
)
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
@@ -266,12 +273,39 @@ async def invoke_async(
) -> SwarmResult:
"""Invoke the swarm asynchronously.

This method uses stream_async internally and consumes all events until completion,
following the same pattern as the Agent class.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
events = self.stream_async(task, invocation_state, **kwargs)
async for event in events:
_ = event

return cast(SwarmResult, event["result"])

async def stream_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[dict[str, Any]]:
"""Stream events during swarm execution.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues; a new empty dict
is created if None is provided.
**kwargs: Keyword arguments allowing backward compatible future changes.

Yields:
Dictionary events during swarm execution, such as:
- multi_agent_node_start: When a node begins execution
- multi_agent_node_stream: Forwarded agent events with node context
- multi_agent_handoff: When control is handed off between agents
- multi_agent_node_complete: When a node completes execution
- result: Final swarm result
"""
if invocation_state is None:
invocation_state = {}
@@ -282,7 +316,7 @@
if self.entry_point:
initial_node = self.nodes[str(self.entry_point.name)]
else:
initial_node = next(iter(self.nodes.values())) # First SwarmNode
initial_node = next(iter(self.nodes.values()))

self.state = SwarmState(
current_node=initial_node,
@@ -303,15 +337,50 @@
self.execution_timeout,
)

await self._execute_swarm(invocation_state)
async for event in self._execute_swarm(invocation_state):
yield event.as_dict()

# Yield final result (consistent with Agent's AgentResultEvent format)
result = self._build_result()
yield MultiAgentResultEvent(result=result).as_dict()

except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED
raise
finally:
self.state.execution_time = round((time.time() - start_time) * 1000)

return self._build_result()
async def _stream_with_timeout(
self, async_generator: AsyncIterator[Any], timeout: float | None, timeout_message: str
) -> AsyncIterator[Any]:
"""Wrap an async generator with timeout functionality.

Args:
async_generator: The generator to wrap
timeout: Timeout in seconds, or None for no timeout
timeout_message: Message to include in timeout exception

Yields:
Events from the wrapped generator

Raises:
Exception: If timeout is exceeded (same as original behavior)
"""
if timeout is None:
# No timeout - just pass through
async for event in async_generator:
yield event
else:
# Apply timeout to each event
while True:
try:
event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout)
yield event
except StopAsyncIteration:
break
except asyncio.TimeoutError:
raise Exception(timeout_message) from None
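Note that the wrapper applies the timeout to the wait for each individual event rather than to the node's execution as a whole, so a node that keeps emitting events can run longer than node_timeout in total; only a silent stall between events trips it. A standalone sketch of the same pattern (the generator and function names here are illustrative only):

import asyncio
from typing import AsyncIterator

async def slow_events() -> AsyncIterator[int]:
    # Stand-in for a node's event stream.
    for i in range(3):
        await asyncio.sleep(0.1)
        yield i

async def with_per_event_timeout(gen: AsyncIterator[int], timeout: float) -> AsyncIterator[int]:
    # Bound the wait for each event, not the stream as a whole.
    while True:
        try:
            yield await asyncio.wait_for(gen.__anext__(), timeout=timeout)
        except StopAsyncIteration:
            break

async def main() -> None:
    async for event in with_per_event_timeout(slow_events(), timeout=1.0):
        print(event)

asyncio.run(main())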

def _setup_swarm(self, nodes: list[Agent]) -> None:
"""Initialize swarm configuration."""
@@ -533,14 +602,14 @@ def _build_node_input(self, target_node: SwarmNode) -> str:

return context_text

async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
"""Shared execution logic used by execute_async."""
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute swarm and yield TypedEvent objects."""
try:
# Main execution loop
while True:
if self.state.completion_status != Status.EXECUTING:
reason = f"Completion status is: {self.state.completion_status}"
logger.debug("reason=<%s> | stopping execution", reason)
logger.debug("reason=<%s> | stopping streaming execution", reason)
break

should_continue, reason = self.state.should_continue(
@@ -568,34 +637,44 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
len(self.state.node_history) + 1,
)

# Store the current node before execution to detect handoffs
previous_node = current_node

# Execute node with timeout protection
# TODO: Implement cancellation token to stop _execute_node from continuing
try:
await asyncio.wait_for(
# Execute with timeout wrapper for async generator streaming
node_stream = self._stream_with_timeout(
self._execute_node(current_node, self.state.task, invocation_state),
timeout=self.node_timeout,
self.node_timeout,
f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s",
)
async for event in node_stream:
yield event

self.state.node_history.append(current_node)

logger.debug("node=<%s> | node execution completed", current_node.node_id)

# Check if the current node is still the same after execution
# If it is, then no handoff occurred and we consider the swarm complete
if self.state.current_node == current_node:
# Check if handoff occurred during execution
if self.state.current_node != previous_node:
# Emit handoff event
handoff_event = MultiAgentHandoffEvent(
from_node=previous_node.node_id,
to_node=self.state.current_node.node_id,
message=self.state.handoff_message or "Agent handoff occurred",
)
yield handoff_event
logger.debug(
"from_node=<%s>, to_node=<%s> | handoff detected",
previous_node.node_id,
self.state.current_node.node_id,
)
else:
# No handoff occurred, mark swarm as complete
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
self.state.completion_status = Status.COMPLETED
break

except asyncio.TimeoutError:
logger.exception(
"node=<%s>, timeout=<%s>s | node execution timed out after timeout",
current_node.node_id,
self.node_timeout,
)
self.state.completion_status = Status.FAILED
break

except Exception:
logger.exception("node=<%s> | node execution failed", current_node.node_id)
self.state.completion_status = Status.FAILED
Expand All @@ -615,11 +694,15 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:

async def _execute_node(
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
) -> AgentResult:
"""Execute swarm node."""
) -> AsyncIterator[Any]:
"""Execute swarm node and yield TypedEvent objects."""
start_time = time.time()
node_name = node.node_id

# Emit node start event
start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent")
yield start_event

try:
# Prepare context for node
context_text = self._build_node_input(node)
Expand All @@ -632,11 +715,22 @@ async def _execute_node(
# Include additional ContentBlocks in node input
node_input = node_input + task

# Execute node
result = None
# Execute node with streaming
node.reset_executor_state()
# Unpacking since this is the agent class. Other executors should not unpack
result = await node.executor.invoke_async(node_input, **invocation_state)

# Stream agent events with node context and capture final result
result = None
async for event in node.executor.stream_async(node_input, **invocation_state):
# Forward agent events with node context
wrapped_event = MultiAgentNodeStreamEvent(node_name, event)
yield wrapped_event
# Capture the final result event
if isinstance(event, dict) and "result" in event:
Review comment (Member): Is there ever a case where this is not a dict?
result = event["result"]

# Use the captured result from streaming to avoid double execution
if result is None:
raise ValueError(f"Node '{node_name}' did not produce a result event")

execution_time = round((time.time() - start_time) * 1000)

@@ -664,15 +758,17 @@ async def _execute_node(
# Accumulate metrics
self._accumulate_metrics(node_result)

return result
# Emit node complete event
complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time)
yield complete_event

except Exception as e:
execution_time = round((time.time() - start_time) * 1000)
logger.exception("node=<%s> | node execution failed", node_name)

# Create a NodeResult for the failed node
node_result = NodeResult(
result=e, # Store exception as result
result=e,
execution_time=execution_time,
status=Status.FAILED,
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
Expand All @@ -683,6 +779,10 @@ async def _execute_node(
# Store result in state
self.state.results[node_name] = node_result

# Still emit complete event for failures
complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time)
yield complete_event

raise
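For a single node, the emission order is therefore: one MultiAgentNodeStartEvent, zero or more MultiAgentNodeStreamEvent wrappers around the agent's own events, then one MultiAgentNodeCompleteEvent, which is emitted on the failure path as well. A rough sketch of the dict shapes a consumer sees, assuming TypedEvent.as_dict() returns the payload passed to the constructor (the node name and "data" payload are made up):

from strands.types._events import (
    MultiAgentNodeCompleteEvent,
    MultiAgentNodeStartEvent,
    MultiAgentNodeStreamEvent,
)

# One node's events, in emission order.
for event in (
    MultiAgentNodeStartEvent(node_id="researcher", node_type="agent"),
    MultiAgentNodeStreamEvent("researcher", {"data": "partial text..."}),
    MultiAgentNodeCompleteEvent(node_id="researcher", execution_time=1250),
):
    print(event.as_dict())
# Approximate output:
# {"multi_agent_node_start": True, "node_id": "researcher", "node_type": "agent"}
# {"multi_agent_node_stream": True, "node_id": "researcher", "data": "partial text..."}
# {"multi_agent_node_complete": True, "node_id": "researcher", "execution_time": 1250}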

def _accumulate_metrics(self, node_result: NodeResult) -> None:
71 changes: 71 additions & 0 deletions src/strands/types/_events.py
@@ -351,3 +351,74 @@ def __init__(self, reason: str | Exception) -> None:
class AgentResultEvent(TypedEvent):
def __init__(self, result: "AgentResult"):
super().__init__({"result": result})


class MultiAgentResultEvent(TypedEvent):
"""Event emitted when multi-agent execution completes with final result."""

def __init__(self, result: Any) -> None:
Review comment (Member): Why is this Any? Why is this not typed as GraphResult?
"""Initialize with multi-agent result.

Args:
result: The final result from multi-agent execution (SwarmResult, GraphResult, etc.)
"""
super().__init__({"result": result})
Review comment (Member): How does the caller differentiate between an AgentResult (using the key result) and a MultiAgent result?



class MultiAgentNodeStartEvent(TypedEvent):
"""Event emitted when a node begins execution in multi-agent context."""

def __init__(self, node_id: str, node_type: str) -> None:
"""Initialize with node information.

Args:
node_id: Unique identifier for the node
node_type: Type of node ("agent", "swarm", "graph")
"""
super().__init__({"multi_agent_node_start": True, "node_id": node_id, "node_type": node_type})


class MultiAgentNodeCompleteEvent(TypedEvent):
"""Event emitted when a node completes execution."""

def __init__(self, node_id: str, execution_time: int) -> None:
"""Initialize with completion information.

Args:
node_id: Unique identifier for the node
execution_time: Execution time in milliseconds
"""
super().__init__({"multi_agent_node_complete": True, "node_id": node_id, "execution_time": execution_time})


class MultiAgentHandoffEvent(TypedEvent):
"""Event emitted during agent handoffs in Swarm."""

def __init__(self, from_node: str, to_node: str, message: str) -> None:
"""Initialize with handoff information.

Args:
from_node: Node ID handing off control
to_node: Node ID receiving control
message: Handoff message explaining the transfer
"""
super().__init__({"multi_agent_handoff": True, "from_node": from_node, "to_node": to_node, "message": message})


class MultiAgentNodeStreamEvent(TypedEvent):
"""Event emitted during node execution - forwards agent events with node context."""

def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None:
"""Initialize with node context and agent event.

Args:
node_id: Unique identifier for the node generating the event
agent_event: The original agent event data
"""
super().__init__(
{
"multi_agent_node_stream": True,
"node_id": node_id,
**agent_event, # Forward all original agent event data
Review comment (Member): Let's nest this instead of combining. Specifically, ToolStreamEvent has all data here as a sub-field, and I think that's what we should do whenever we wrap things.

}
)
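Taken together, these events let a caller observe swarm progress while invoke_async keeps its existing return value. A rough end-to-end consumer sketch, assuming the public Swarm and Agent constructors (the agent names, task string, and default model configuration are placeholders):

import asyncio

from strands import Agent
from strands.multiagent import Swarm

async def main() -> None:
    swarm = Swarm([Agent(name="researcher"), Agent(name="writer")])
    async for event in swarm.stream_async("Draft a short summary of topic X"):
        if event.get("multi_agent_node_start"):
            print("node started:", event["node_id"])
        elif event.get("multi_agent_handoff"):
            print("handoff:", event["from_node"], "->", event["to_node"])
        elif event.get("multi_agent_node_complete"):
            print("node finished:", event["node_id"], f"({event['execution_time']} ms)")
        elif "result" in event:
            print("final status:", event["result"].status)

asyncio.run(main())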