From 6c00bbe85bd2b6506cd9c32c404c81377250e80b Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 2 Oct 2025 12:45:35 +0200 Subject: [PATCH 1/7] feat(multiagent): Add stream async --- src/strands/multiagent/base.py | 28 +- src/strands/multiagent/graph.py | 227 +++++++++++--- src/strands/multiagent/swarm.py | 121 ++++++-- src/strands/types/_events.py | 59 ++++ tests/strands/multiagent/test_graph.py | 408 +++++++++++++++++++++++++ tests/strands/multiagent/test_swarm.py | 270 ++++++++++++++++ 6 files changed, 1046 insertions(+), 67 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 03d7de9b4..ddc57cb68 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import Any, Union +from typing import Any, AsyncIterator, Union from ..agent import AgentResult from ..types.content import ContentBlock @@ -97,6 +97,32 @@ async def invoke_async( """ raise NotImplementedError("invoke_async not implemented") + async def stream_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + """Stream events during multi-agent execution. + + This default implementation provides backward compatibility by executing + invoke_async and yielding a single result event. Subclasses can override + this method to provide true streaming capabilities. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Additional keyword arguments passed to underlying agents. + + Yields: + Dictionary events containing multi-agent execution information including: + - Multi-agent coordination events (node start/complete, handoffs) + - Forwarded single-agent events with node context + - Final result event + """ + # Default implementation for backward compatibility + # Execute invoke_async and yield the result as a single event + result = await self.invoke_async(task, invocation_state, **kwargs) + yield {"result": result} + def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> MultiAgentResult: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 738dc4d4c..457fb1fd2 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -20,13 +20,18 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Optional, Tuple +from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast from opentelemetry import trace as trace_api from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer +from ..types._events import ( + MultiAgentNodeCompleteEvent, + MultiAgentNodeStartEvent, + MultiAgentNodeStreamEvent, +) from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -411,13 +416,39 @@ async def invoke_async( ) -> GraphResult: """Invoke the graph asynchronously. + This method uses stream_async internally and consumes all events until completion, + following the same pattern as the Agent class. + Args: task: The task to execute invocation_state: Additional state/context passed to underlying agents. - Defaults to None to avoid mutable default argument issues - a new empty dict - is created if None is provided. + Defaults to None to avoid mutable default argument issues. **kwargs: Keyword arguments allowing backward compatible future changes. """ + events = self.stream_async(task, invocation_state, **kwargs) + async for event in events: + _ = event + + return cast(GraphResult, event["result"]) + + async def stream_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + """Stream events during graph execution. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Keyword arguments allowing backward compatible future changes. + + Yields: + Dictionary events containing graph execution information including: + - MultiAgentNodeStartEvent: When a node begins execution + - MultiAgentNodeStreamEvent: Forwarded agent events with node context + - MultiAgentNodeCompleteEvent: When a node completes execution + - Final result event + """ if invocation_state is None: invocation_state = {} @@ -444,23 +475,53 @@ async def invoke_async( self.node_timeout or "None", ) - await self._execute_graph(invocation_state) + async for event in self._execute_graph(invocation_state): + yield event # Set final status based on execution results if self.state.failed_nodes: self.state.status = Status.FAILED - elif self.state.status == Status.EXECUTING: # Only set to COMPLETED if still executing and no failures + elif self.state.status == Status.EXECUTING: self.state.status = Status.COMPLETED logger.debug("status=<%s> | graph execution completed", self.state.status) + # Yield final result (consistent with Agent's AgentResultEvent format) + result = self._build_result() + + # Use the same event format as Agent for consistency + yield {"result": result} + except Exception: logger.exception("graph execution failed") self.state.status = Status.FAILED raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - return self._build_result() + + async def _stream_with_timeout( + self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str + ) -> AsyncIterator[dict[str, Any]]: + """Wrap an async generator with timeout functionality.""" + + # Create a task for the entire generator with timeout + async def generator_with_timeout() -> AsyncIterator[dict[str, Any]]: + try: + async for event in async_generator: + yield event + except asyncio.TimeoutError: + raise Exception(timeout_message) from None + + # Use asyncio.wait_for on each individual event + generator = generator_with_timeout() + while True: + try: + event = await asyncio.wait_for(generator.__anext__(), timeout=timeout) + yield event + except StopAsyncIteration: + break + except asyncio.TimeoutError: + raise Exception(timeout_message) from None def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: """Validate graph nodes for duplicate instances.""" @@ -474,8 +535,8 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: - """Unified execution flow with conditional routing.""" + async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]: + """Execute graph and yield events.""" ready_nodes = list(self.entry_points) while ready_nodes: @@ -487,22 +548,76 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: if not should_continue: self.state.status = Status.FAILED logger.debug("reason=<%s> | stopping execution", reason) - return # Let the top-level exception handler deal with it + return current_batch = ready_nodes.copy() ready_nodes.clear() - # Execute current batch of ready nodes concurrently - tasks = [asyncio.create_task(self._execute_node(node, invocation_state)) for node in current_batch] - - for task in tasks: - await task + # Execute current batch of ready nodes in parallel + if len(current_batch) == 1: + # Single node - execute directly to avoid overhead + async for event in self._execute_node(current_batch[0], invocation_state): + yield event + else: + # Multiple nodes - execute in parallel and merge events + async for event in self._execute_nodes_parallel(current_batch, invocation_state): + yield event # Find newly ready nodes after batch execution - # We add all nodes in current batch as completed batch, - # because a failure would throw exception and code would not make it here ready_nodes.extend(self._find_newly_ready_nodes(current_batch)) + async def _execute_nodes_parallel( + self, nodes: list["GraphNode"], invocation_state: dict[str, Any] + ) -> AsyncIterator[dict[str, Any]]: + """Execute multiple nodes in parallel and merge their event streams.""" + # Create async generators for each node + node_generators = [self._execute_node(node, invocation_state) for node in nodes] + + # Track which generators are still active + active_generators = {i: gen for i, gen in enumerate(node_generators)} + + # Use asyncio.as_completed to process events as they arrive + while active_generators: + # Create tasks for the next event from each active generator + tasks: list[asyncio.Task[dict[str, Any]]] = [] + task_to_index: dict[asyncio.Task[dict[str, Any]], int] = {} + for i, gen in active_generators.items(): + task: asyncio.Task[dict[str, Any]] = asyncio.create_task(cast(Any, gen.__anext__())) + task_to_index[task] = i # Store the node index mapping + tasks.append(task) + + if not tasks: + break + + # Wait for the first event to arrive + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + # Process completed tasks + for task in done: + node_index = task_to_index[task] + try: + event = await task + yield event + except StopAsyncIteration: + # This generator is exhausted, remove it + if node_index in active_generators: + del active_generators[node_index] + except Exception: + # Node execution failed, remove it but continue with other nodes + # The _execute_node method has already recorded the failure + if node_index in active_generators: + del active_generators[node_index] + # Don't re-raise here - let other nodes complete + # The graph-level logic will determine if the overall execution should fail + + # Cancel pending tasks + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: """Find nodes that became ready after the last execution.""" newly_ready = [] @@ -530,38 +645,52 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None: - """Execute a single node with error handling and timeout protection.""" + async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]: + """Execute a single node.""" # Reset the node's state if reset_on_revisit is enabled and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) node.reset_executor_state() - # Remove from completed nodes since we're re-executing it self.state.completed_nodes.remove(node) node.execution_status = Status.EXECUTING logger.debug("node_id=<%s> | executing node", node.node_id) + # Emit node start event + start_event = MultiAgentNodeStartEvent( + node_id=node.node_id, node_type="agent" if isinstance(node.executor, Agent) else "multiagent" + ) + yield start_event.as_dict() + start_time = time.time() try: # Build node input from satisfied dependencies node_input = self._build_node_input(node) - # Execute with timeout protection (only if node_timeout is set) + # Execute with timeout protection and stream events try: - # Execute based on node type and create unified NodeResult if isinstance(node.executor, MultiAgentBase): + # For nested multi-agent systems, stream their events if self.node_timeout is not None: - multi_agent_result = await asyncio.wait_for( - node.executor.invoke_async(node_input, invocation_state), - timeout=self.node_timeout, - ) + # Implement timeout for async generator streaming + async for event in self._stream_with_timeout( + node.executor.stream_async(node_input, invocation_state), + self.node_timeout, + f"Node '{node.node_id}' execution timed out after {self.node_timeout}s", + ): + # Forward nested multi-agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) + yield wrapped_event.as_dict() else: - multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) + async for event in node.executor.stream_async(node_input, invocation_state): + # Forward nested multi-agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) + yield wrapped_event.as_dict() - # Create NodeResult with MultiAgentResult directly + # Get the final result for metrics + multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) node_result = NodeResult( - result=multi_agent_result, # type is MultiAgentResult + result=multi_agent_result, execution_time=multi_agent_result.execution_time, status=Status.COMPLETED, accumulated_usage=multi_agent_result.accumulated_usage, @@ -570,15 +699,25 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) elif isinstance(node.executor, Agent): + # For agents, stream their events if self.node_timeout is not None: - agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input, **invocation_state), - timeout=self.node_timeout, - ) + # Implement timeout for async generator streaming + async for event in self._stream_with_timeout( + node.executor.stream_async(node_input, **invocation_state), + self.node_timeout, + f"Node '{node.node_id}' execution timed out after {self.node_timeout}s", + ): + # Forward agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) + yield wrapped_event.as_dict() else: - agent_response = await node.executor.invoke_async(node_input, **invocation_state) + async for event in node.executor.stream_async(node_input, **invocation_state): + # Forward agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) + yield wrapped_event.as_dict() - # Extract metrics from agent response + # Get the final result for metrics + agent_response = await node.executor.invoke_async(node_input, **invocation_state) usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics = Metrics(latencyMs=0) if hasattr(agent_response, "metrics") and agent_response.metrics: @@ -588,7 +727,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) metrics = agent_response.metrics.accumulated_metrics node_result = NodeResult( - result=agent_response, # type is AgentResult + result=agent_response, execution_time=round((time.time() - start_time) * 1000), status=Status.COMPLETED, accumulated_usage=usage, @@ -601,7 +740,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) except asyncio.TimeoutError: timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out after timeout", + "node=<%s>, timeout=<%s>s | node execution timed out", node.node_id, self.node_timeout, ) @@ -618,8 +757,14 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Accumulate metrics self._accumulate_metrics(node_result) + # Emit node complete event + complete_event = MultiAgentNodeCompleteEvent(node_id=node.node_id, execution_time=node.execution_time) + yield complete_event.as_dict() + logger.debug( - "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time + "node_id=<%s>, execution_time=<%dms> | node completed successfully", + node.node_id, + node.execution_time, ) except Exception as e: @@ -628,7 +773,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Create a NodeResult for the failed node node_result = NodeResult( - result=e, # Store exception as result + result=e, execution_time=execution_time, status=Status.FAILED, accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), @@ -640,7 +785,11 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) node.result = node_result node.execution_time = execution_time self.state.failed_nodes.add(node) - self.state.results[node.node_id] = node_result # Store in results for consistency + self.state.results[node.node_id] = node_result + + # Still emit complete event even for failures + complete_event = MultiAgentNodeCompleteEvent(node_id=node.node_id, execution_time=execution_time) + yield complete_event.as_dict() raise diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 620fa5e24..c28f9899d 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -19,14 +19,20 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Tuple +from typing import Any, AsyncIterator, Callable, Tuple, cast from opentelemetry import trace as trace_api -from ..agent import Agent, AgentResult +from ..agent import Agent from ..agent.state import AgentState from ..telemetry import get_tracer from ..tools.decorator import tool +from ..types._events import ( + MultiAgentHandoffEvent, + MultiAgentNodeCompleteEvent, + MultiAgentNodeStartEvent, + MultiAgentNodeStreamEvent, +) from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -266,12 +272,39 @@ async def invoke_async( ) -> SwarmResult: """Invoke the swarm asynchronously. + This method uses stream_async internally and consumes all events until completion, + following the same pattern as the Agent class. + + Args: + task: The task to execute + invocation_state: Additional state/context passed to underlying agents. + Defaults to None to avoid mutable default argument issues. + **kwargs: Keyword arguments allowing backward compatible future changes. + """ + events = self.stream_async(task, invocation_state, **kwargs) + async for event in events: + _ = event + + return cast(SwarmResult, event["result"]) + + async def stream_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + """Stream events during swarm execution. + Args: task: The task to execute invocation_state: Additional state/context passed to underlying agents. - Defaults to None to avoid mutable default argument issues - a new empty dict - is created if None is provided. + Defaults to None to avoid mutable default argument issues. **kwargs: Keyword arguments allowing backward compatible future changes. + + Yields: + Dictionary events containing swarm execution information including: + - MultiAgentNodeStartEvent: When an agent begins execution + - MultiAgentNodeStreamEvent: Forwarded agent events with node context + - MultiAgentHandoffEvent: When control is handed off between agents + - MultiAgentNodeCompleteEvent: When an agent completes execution + - Final result event """ if invocation_state is None: invocation_state = {} @@ -282,7 +315,7 @@ async def invoke_async( if self.entry_point: initial_node = self.nodes[str(self.entry_point.name)] else: - initial_node = next(iter(self.nodes.values())) # First SwarmNode + initial_node = next(iter(self.nodes.values())) self.state = SwarmState( current_node=initial_node, @@ -303,7 +336,13 @@ async def invoke_async( self.execution_timeout, ) - await self._execute_swarm(invocation_state) + async for event in self._execute_swarm(invocation_state): + yield event + + # Yield final result (consistent with Agent's AgentResultEvent format) + result = self._build_result() + yield {"result": result} + except Exception: logger.exception("swarm execution failed") self.state.completion_status = Status.FAILED @@ -311,8 +350,6 @@ async def invoke_async( finally: self.state.execution_time = round((time.time() - start_time) * 1000) - return self._build_result() - def _setup_swarm(self, nodes: list[Agent]) -> None: """Initialize swarm configuration.""" # Validate nodes before setup @@ -533,14 +570,14 @@ def _build_node_input(self, target_node: SwarmNode) -> str: return context_text - async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: - """Shared execution logic used by execute_async.""" + async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]: + """Execute swarm and yield events.""" try: # Main execution loop while True: if self.state.completion_status != Status.EXECUTING: reason = f"Completion status is: {self.state.completion_status}" - logger.debug("reason=<%s> | stopping execution", reason) + logger.debug("reason=<%s> | stopping streaming execution", reason) break should_continue, reason = self.state.should_continue( @@ -568,28 +605,43 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: len(self.state.node_history) + 1, ) + # Store the current node before execution to detect handoffs + previous_node = current_node + # Execute node with timeout protection - # TODO: Implement cancellation token to stop _execute_node from continuing try: - await asyncio.wait_for( - self._execute_node(current_node, self.state.task, invocation_state), - timeout=self.node_timeout, - ) + # For now, execute without timeout for async generators + # TODO: Implement proper timeout for async generators if needed + async for event in self._execute_node(current_node, self.state.task, invocation_state): + yield event self.state.node_history.append(current_node) logger.debug("node=<%s> | node execution completed", current_node.node_id) - # Check if the current node is still the same after execution - # If it is, then no handoff occurred and we consider the swarm complete - if self.state.current_node == current_node: + # Check if handoff occurred during execution + if self.state.current_node != previous_node: + # Emit handoff event + handoff_event = MultiAgentHandoffEvent( + from_node=previous_node.node_id, + to_node=self.state.current_node.node_id, + message=self.state.handoff_message or "Agent handoff occurred", + ) + yield handoff_event.as_dict() + logger.debug( + "from_node=<%s>, to_node=<%s> | handoff detected", + previous_node.node_id, + self.state.current_node.node_id, + ) + else: + # No handoff occurred, mark swarm as complete logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) self.state.completion_status = Status.COMPLETED break except asyncio.TimeoutError: logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out after timeout", + "node=<%s>, timeout=<%s>s | node execution timed out", current_node.node_id, self.node_timeout, ) @@ -615,11 +667,15 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: async def _execute_node( self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] - ) -> AgentResult: + ) -> AsyncIterator[dict[str, Any]]: """Execute swarm node.""" start_time = time.time() node_name = node.node_id + # Emit node start event + start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent") + yield start_event.as_dict() + try: # Prepare context for node context_text = self._build_node_input(node) @@ -632,12 +688,17 @@ async def _execute_node( # Include additional ContentBlocks in node input node_input = node_input + task - # Execute node - result = None + # Execute node with streaming node.reset_executor_state() - # Unpacking since this is the agent class. Other executors should not unpack - result = await node.executor.invoke_async(node_input, **invocation_state) + # Stream agent events with node context + async for event in node.executor.stream_async(node_input, **invocation_state): + # Forward agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node_name, event) + yield wrapped_event.as_dict() + + # Get the final result for metrics (we need to call invoke_async again for the result) + result = await node.executor.invoke_async(node_input, **invocation_state) execution_time = round((time.time() - start_time) * 1000) # Create NodeResult @@ -664,7 +725,9 @@ async def _execute_node( # Accumulate metrics self._accumulate_metrics(node_result) - return result + # Emit node complete event + complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time) + yield complete_event.as_dict() except Exception as e: execution_time = round((time.time() - start_time) * 1000) @@ -672,7 +735,7 @@ async def _execute_node( # Create a NodeResult for the failed node node_result = NodeResult( - result=e, # Store exception as result + result=e, execution_time=execution_time, status=Status.FAILED, accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), @@ -683,6 +746,10 @@ async def _execute_node( # Store result in state self.state.results[node_name] = node_result + # Still emit complete event for failures + complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time) + yield complete_event.as_dict() + raise def _accumulate_metrics(self, node_result: NodeResult) -> None: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 3d0f1d0f0..3332444a7 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -351,3 +351,62 @@ def __init__(self, reason: str | Exception) -> None: class AgentResultEvent(TypedEvent): def __init__(self, result: "AgentResult"): super().__init__({"result": result}) + + +class MultiAgentNodeStartEvent(TypedEvent): + """Event emitted when a node begins execution in multi-agent context.""" + + def __init__(self, node_id: str, node_type: str) -> None: + """Initialize with node information. + + Args: + node_id: Unique identifier for the node + node_type: Type of node ("agent", "swarm", "graph") + """ + super().__init__({"multi_agent_node_start": True, "node_id": node_id, "node_type": node_type}) + + +class MultiAgentNodeCompleteEvent(TypedEvent): + """Event emitted when a node completes execution.""" + + def __init__(self, node_id: str, execution_time: int) -> None: + """Initialize with completion information. + + Args: + node_id: Unique identifier for the node + execution_time: Execution time in milliseconds + """ + super().__init__({"multi_agent_node_complete": True, "node_id": node_id, "execution_time": execution_time}) + + +class MultiAgentHandoffEvent(TypedEvent): + """Event emitted during agent handoffs in Swarm.""" + + def __init__(self, from_node: str, to_node: str, message: str) -> None: + """Initialize with handoff information. + + Args: + from_node: Node ID handing off control + to_node: Node ID receiving control + message: Handoff message explaining the transfer + """ + super().__init__({"multi_agent_handoff": True, "from_node": from_node, "to_node": to_node, "message": message}) + + +class MultiAgentNodeStreamEvent(TypedEvent): + """Event emitted during node execution - forwards agent events with node context.""" + + def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None: + """Initialize with node context and agent event. + + Args: + node_id: Unique identifier for the node generating the event + agent_event: The original agent event data + """ + super().__init__( + { + "multi_agent_node_stream": True, + "node_id": node_id, + **agent_event, # Forward all original agent event data + } + ) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 8097d944e..e1ac87be5 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -40,7 +40,13 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen async def mock_invoke_async(*args, **kwargs): return mock_result + async def mock_stream_async(*args, **kwargs): + # Simple mock stream that yields a start event and then the result + yield {"agent_start": True} + yield {"result": mock_result} + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) + agent.stream_async = Mock(side_effect=mock_stream_async) return agent @@ -66,7 +72,14 @@ def create_mock_multi_agent(name, response_text="Multi-agent response"): execution_count=1, execution_time=150, ) + + async def mock_multi_stream_async(*args, **kwargs): + # Simple mock stream that yields a start event and then the result + yield {"multi_agent_start": True} + yield {"result": mock_result} + multi_agent.invoke_async = AsyncMock(return_value=mock_result) + multi_agent.stream_async = Mock(side_effect=mock_multi_stream_async) multi_agent.execute = Mock(return_value=mock_result) return multi_agent @@ -277,7 +290,13 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span) async def mock_invoke_failure(*args, **kwargs): raise Exception("Simulated failure") + async def mock_stream_failure(*args, **kwargs): + # Simple mock stream that fails + yield {"agent_start": True} + raise Exception("Simulated failure") + failing_agent.invoke_async = mock_invoke_failure + failing_agent.stream_async = Mock(side_effect=mock_stream_failure) success_agent = create_mock_agent("success_agent", "Success") @@ -623,7 +642,13 @@ async def timeout_invoke(*args, **kwargs): await asyncio.sleep(0.2) # Longer than node timeout return timeout_agent.return_value + async def timeout_stream(*args, **kwargs): + yield {"agent_start": True} + await asyncio.sleep(0.2) # Longer than node timeout + yield {"result": timeout_agent.return_value} + timeout_agent.invoke_async = AsyncMock(side_effect=timeout_invoke) + timeout_agent.stream_async = Mock(side_effect=timeout_stream) builder = GraphBuilder() builder.add_node(timeout_agent, "timeout_node") @@ -1337,3 +1362,386 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state) assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_graph_streaming_events(mock_strands_tracer, mock_use_span): + """Test that graph streaming emits proper events during execution.""" + # Create agents with custom streaming behavior + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + + # Track events from agent streams + agent_a_events = [ + {"agent_thinking": True, "thought": "Processing task A"}, + {"agent_progress": True, "step": "analyzing"}, + {"result": agent_a.return_value}, + ] + + agent_b_events = [ + {"agent_thinking": True, "thought": "Processing task B"}, + {"agent_progress": True, "step": "computing"}, + {"result": agent_b.return_value}, + ] + + async def stream_a(*args, **kwargs): + for event in agent_a_events: + yield event + + async def stream_b(*args, **kwargs): + for event in agent_b_events: + yield event + + agent_a.stream_async = Mock(side_effect=stream_a) + agent_b.stream_async = Mock(side_effect=stream_b) + + # Build graph: A -> B + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + graph = builder.build() + + # Collect all streaming events + events = [] + async for event in graph.stream_async("Test streaming"): + events.append(event) + + # Verify event structure and order + assert len(events) > 0 + + # Should have node start/complete events and forwarded agent events + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + node_complete_events = [e for e in events if e.get("multi_agent_node_complete")] + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + result_events = [e for e in events if "result" in e and not e.get("multi_agent_node_stream")] + + # Should have start/complete events for both nodes + assert len(node_start_events) == 2 + assert len(node_complete_events) == 2 + + # Should have forwarded agent events + assert len(node_stream_events) >= 4 # At least 2 events per agent + + # Should have final result + assert len(result_events) == 1 + + # Verify node start events have correct structure + for event in node_start_events: + assert "node_id" in event + assert "node_type" in event + assert event["node_type"] == "agent" + + # Verify node complete events have execution time + for event in node_complete_events: + assert "node_id" in event + assert "execution_time" in event + assert isinstance(event["execution_time"], int) + + # Verify forwarded events maintain node context + for event in node_stream_events: + assert "node_id" in event + assert event["node_id"] in ["a", "b"] + + # Verify final result + final_result = result_events[0]["result"] + assert final_result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_graph_streaming_parallel_events(mock_strands_tracer, mock_use_span): + """Test that parallel graph execution properly streams events from concurrent nodes.""" + # Create agents that execute in parallel + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + agent_c = create_mock_agent("agent_c", "Response C") + + # Track timing and events + execution_order = [] + + async def stream_with_timing(node_id, delay=0.05): + execution_order.append(f"{node_id}_start") + yield {"node_start": True, "node": node_id} + await asyncio.sleep(delay) + yield {"node_progress": True, "node": node_id} + execution_order.append(f"{node_id}_end") + yield {"result": create_mock_agent(node_id, f"Response {node_id}").return_value} + + agent_a.stream_async = Mock(side_effect=lambda *args, **kwargs: stream_with_timing("A", 0.05)) + agent_b.stream_async = Mock(side_effect=lambda *args, **kwargs: stream_with_timing("B", 0.05)) + agent_c.stream_async = Mock(side_effect=lambda *args, **kwargs: stream_with_timing("C", 0.05)) + + # Build graph with parallel nodes + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + # All are entry points (parallel execution) + builder.set_entry_point("a") + builder.set_entry_point("b") + builder.set_entry_point("c") + graph = builder.build() + + # Collect streaming events + events = [] + start_time = time.time() + async for event in graph.stream_async("Test parallel streaming"): + events.append(event) + total_time = time.time() - start_time + + # Verify parallel execution timing + assert total_time < 0.2, f"Expected parallel execution, took {total_time}s" + + # Verify we get events from all nodes + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + nodes_with_events = set(e["node_id"] for e in node_stream_events) + assert nodes_with_events == {"a", "b", "c"} + + # Verify start events for all nodes + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + start_node_ids = set(e["node_id"] for e in node_start_events) + assert start_node_ids == {"a", "b", "c"} + + +@pytest.mark.asyncio +async def test_graph_streaming_with_failures(mock_strands_tracer, mock_use_span): + """Test graph streaming behavior when nodes fail.""" + # Create a failing agent + failing_agent = Mock(spec=Agent) + failing_agent.name = "failing_agent" + failing_agent.id = "fail_node" + failing_agent._session_manager = None + failing_agent.hooks = HookRegistry() + + async def failing_stream(*args, **kwargs): + yield {"agent_start": True} + yield {"agent_thinking": True, "thought": "About to fail"} + await asyncio.sleep(0.01) + raise Exception("Simulated streaming failure") + + async def failing_invoke(*args, **kwargs): + raise Exception("Simulated failure") + + failing_agent.stream_async = Mock(side_effect=failing_stream) + failing_agent.invoke_async = failing_invoke + + # Create successful agent + success_agent = create_mock_agent("success_agent", "Success") + + # Build graph + builder = GraphBuilder() + builder.add_node(failing_agent, "fail") + builder.add_node(success_agent, "success") + builder.set_entry_point("fail") + builder.set_entry_point("success") + graph = builder.build() + + # Collect events until failure + events = [] + try: + async for event in graph.stream_async("Test streaming with failure"): + events.append(event) + raise AssertionError("Expected an exception") + except Exception: + # Should get some events before failure + assert len(events) > 0 + + # Should have node start events + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + assert len(node_start_events) >= 1 + + # Should have some forwarded events before failure + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + assert len(node_stream_events) >= 1 + + +@pytest.mark.asyncio +async def test_graph_parallel_execution(mock_strands_tracer, mock_use_span): + """Test that nodes without dependencies execute in parallel.""" + + # Create agents that track execution timing + execution_times = {} + + async def create_timed_agent(name, delay=0.1): + agent = create_mock_agent(name, f"{name} response") + + async def timed_invoke(*args, **kwargs): + start_time = time.time() + execution_times[name] = {"start": start_time} + await asyncio.sleep(delay) # Simulate work + end_time = time.time() + execution_times[name]["end"] = end_time + return agent.return_value + + async def timed_stream(*args, **kwargs): + # Simulate streaming by yielding some events then the final result + start_time = time.time() + execution_times[name] = {"start": start_time} + + # Yield a start event + yield {"agent_start": True, "node": name} + + await asyncio.sleep(delay) # Simulate work + + end_time = time.time() + execution_times[name]["end"] = end_time + + # Yield final result event + yield {"result": agent.return_value} + + agent.invoke_async = AsyncMock(side_effect=timed_invoke) + # Create a mock that returns the async generator directly + agent.stream_async = Mock(side_effect=timed_stream) + return agent + + # Create agents that should execute in parallel + agent_a = await create_timed_agent("agent_a", 0.1) + agent_b = await create_timed_agent("agent_b", 0.1) + agent_c = await create_timed_agent("agent_c", 0.1) + + # Create a dependent agent that should execute after the parallel ones + agent_d = await create_timed_agent("agent_d", 0.05) + + # Build graph: A, B, C execute in parallel, then D depends on all of them + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + builder.add_node(agent_d, "d") + + # D depends on A, B, and C + builder.add_edge("a", "d") + builder.add_edge("b", "d") + builder.add_edge("c", "d") + + # A, B, C are entry points (no dependencies) + builder.set_entry_point("a") + builder.set_entry_point("b") + builder.set_entry_point("c") + + graph = builder.build() + + # Execute the graph + start_time = time.time() + result = await graph.invoke_async("Test parallel execution") + total_time = time.time() - start_time + + # Verify successful execution + assert result.status == Status.COMPLETED + assert result.completed_nodes == 4 + assert len(result.execution_order) == 4 + + # Verify all agents were called + agent_a.invoke_async.assert_called_once() + agent_b.invoke_async.assert_called_once() + agent_c.invoke_async.assert_called_once() + agent_d.invoke_async.assert_called_once() + + # Verify parallel execution: A, B, C should have overlapping execution times + # If they were sequential, total time would be ~0.35s (3 * 0.1 + 0.05) + # If parallel, total time should be ~0.15s (max(0.1, 0.1, 0.1) + 0.05) + assert total_time < 0.4, f"Expected parallel execution to be faster, took {total_time}s" + + # Verify timing overlap for parallel nodes + a_start = execution_times["agent_a"]["start"] + b_start = execution_times["agent_b"]["start"] + c_start = execution_times["agent_c"]["start"] + + # All parallel nodes should start within a small time window + max_start_diff = max(a_start, b_start, c_start) - min(a_start, b_start, c_start) + assert max_start_diff < 0.1, f"Parallel nodes should start nearly simultaneously, diff: {max_start_diff}s" + + # D should start after A, B, C have finished + d_start = execution_times["agent_d"]["start"] + a_end = execution_times["agent_a"]["end"] + b_end = execution_times["agent_b"]["end"] + c_end = execution_times["agent_c"]["end"] + + latest_parallel_end = max(a_end, b_end, c_end) + assert d_start >= latest_parallel_end - 0.02, "Dependent node should start after parallel nodes complete" + + +@pytest.mark.asyncio +async def test_graph_single_node_optimization(mock_strands_tracer, mock_use_span): + """Test that single node execution uses direct path (optimization).""" + agent = create_mock_agent("single_agent", "Single response") + + builder = GraphBuilder() + builder.add_node(agent, "single") + graph = builder.build() + + result = await graph.invoke_async("Test single node") + + assert result.status == Status.COMPLETED + assert result.completed_nodes == 1 + agent.invoke_async.assert_called_once() + + +@pytest.mark.asyncio +async def test_graph_parallel_with_failures(mock_strands_tracer, mock_use_span): + """Test parallel execution with some nodes failing.""" + # Create a failing agent + failing_agent = Mock(spec=Agent) + failing_agent.name = "failing_agent" + failing_agent.id = "fail_node" + failing_agent._session_manager = None + failing_agent.hooks = HookRegistry() + + async def mock_invoke_failure(*args, **kwargs): + await asyncio.sleep(0.05) # Small delay + raise Exception("Simulated failure") + + async def mock_stream_failure_parallel(*args, **kwargs): + # Simple mock stream that fails + yield {"agent_start": True} + await asyncio.sleep(0.05) # Small delay + raise Exception("Simulated failure") + + failing_agent.invoke_async = mock_invoke_failure + failing_agent.stream_async = Mock(side_effect=mock_stream_failure_parallel) + + # Create successful agents that take longer than the failing agent + success_agent_a = create_mock_agent("success_a", "Success A") + success_agent_b = create_mock_agent("success_b", "Success B") + + # Override their stream methods to take longer + async def slow_stream_a(*args, **kwargs): + yield {"agent_start": True, "node": "success_a"} + await asyncio.sleep(0.1) # Longer than failing agent + yield {"result": success_agent_a.return_value} + + async def slow_stream_b(*args, **kwargs): + yield {"agent_start": True, "node": "success_b"} + await asyncio.sleep(0.1) # Longer than failing agent + yield {"result": success_agent_b.return_value} + + success_agent_a.stream_async = Mock(side_effect=slow_stream_a) + success_agent_b.stream_async = Mock(side_effect=slow_stream_b) + + # Build graph with parallel execution where one fails + builder = GraphBuilder() + builder.add_node(failing_agent, "fail") + builder.add_node(success_agent_a, "success_a") + builder.add_node(success_agent_b, "success_b") + + # All are entry points (parallel) + builder.set_entry_point("fail") + builder.set_entry_point("success_a") + builder.set_entry_point("success_b") + + graph = builder.build() + + # Execute should succeed with partial failures - some nodes succeed, some fail + result = await graph.invoke_async("Test parallel with failure") + + # The graph should fail since one node failed (current behavior) + assert result.status == Status.FAILED + assert result.failed_nodes == 1 # One failed node + + # Verify failed node is tracked + assert "fail" in result.results + assert result.results["fail"].status == Status.FAILED + + # Note: The successful nodes may not complete if the failure happens early + # This is expected behavior in the current implementation diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 7d3e69695..aed3563a4 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1,3 +1,4 @@ +import asyncio import time from unittest.mock import MagicMock, Mock, patch @@ -53,7 +54,14 @@ def create_mock_result(): async def mock_invoke_async(*args, **kwargs): return create_mock_result() + async def mock_stream_async(*args, **kwargs): + # Simple mock stream that yields a start event and then the result + yield {"agent_start": True, "node": name} + yield {"agent_thinking": True, "thought": f"Processing with {name}"} + yield {"result": create_mock_result()} + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) + agent.stream_async = Mock(side_effect=mock_stream_async) return agent @@ -574,3 +582,265 @@ def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_swarm_streaming_events(mock_strands_tracer, mock_use_span): + """Test that swarm streaming emits proper events during execution.""" + + # Create agents with custom streaming behavior + coordinator = create_mock_agent("coordinator", "Coordinating task") + specialist = create_mock_agent("specialist", "Specialized response") + + # Track events and execution order + execution_events = [] + + async def coordinator_stream(*args, **kwargs): + execution_events.append("coordinator_start") + yield {"agent_start": True, "node": "coordinator"} + yield {"agent_thinking": True, "thought": "Analyzing task"} + await asyncio.sleep(0.01) # Small delay + execution_events.append("coordinator_end") + yield {"result": coordinator.return_value} + + async def specialist_stream(*args, **kwargs): + execution_events.append("specialist_start") + yield {"agent_start": True, "node": "specialist"} + yield {"agent_thinking": True, "thought": "Applying expertise"} + await asyncio.sleep(0.01) # Small delay + execution_events.append("specialist_end") + yield {"result": specialist.return_value} + + coordinator.stream_async = Mock(side_effect=coordinator_stream) + specialist.stream_async = Mock(side_effect=specialist_stream) + + # Create swarm with handoff logic + swarm = Swarm(nodes=[coordinator, specialist], max_handoffs=2, max_iterations=3, execution_timeout=30.0) + + # Add handoff tool to coordinator to trigger specialist + def handoff_to_specialist(): + """Hand off to specialist for detailed analysis.""" + return specialist + + coordinator.tool_registry.registry = {"handoff_to_specialist": handoff_to_specialist} + + # Collect all streaming events + events = [] + async for event in swarm.stream_async("Test swarm streaming"): + events.append(event) + + # Verify event structure + assert len(events) > 0 + + # Should have node start/complete events + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + node_complete_events = [e for e in events if e.get("multi_agent_node_complete")] + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + result_events = [e for e in events if "result" in e and not e.get("multi_agent_node_stream")] + + # Should have at least one node execution + assert len(node_start_events) >= 1 + assert len(node_complete_events) >= 1 + + # Should have forwarded agent events + assert len(node_stream_events) >= 2 # At least some events per agent + + # Should have final result + assert len(result_events) == 1 + + # Verify node start events have correct structure + for event in node_start_events: + assert "node_id" in event + assert "node_type" in event + assert event["node_type"] == "agent" + + # Verify node complete events have execution time + for event in node_complete_events: + assert "node_id" in event + assert "execution_time" in event + assert isinstance(event["execution_time"], int) + + # Verify forwarded events maintain node context + for event in node_stream_events: + assert "node_id" in event + + # Verify final result + final_result = result_events[0]["result"] + assert final_result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_swarm_streaming_with_handoffs(mock_strands_tracer, mock_use_span): + """Test swarm streaming with agent handoffs.""" + + # Create agents + coordinator = create_mock_agent("coordinator", "Coordinating") + specialist = create_mock_agent("specialist", "Specialized work") + reviewer = create_mock_agent("reviewer", "Review complete") + + # Track handoff sequence + handoff_sequence = [] + + async def coordinator_stream(*args, **kwargs): + yield {"agent_start": True, "node": "coordinator"} + yield {"agent_thinking": True, "thought": "Need specialist help"} + handoff_sequence.append("coordinator_to_specialist") + yield {"result": coordinator.return_value} + + async def specialist_stream(*args, **kwargs): + yield {"agent_start": True, "node": "specialist"} + yield {"agent_thinking": True, "thought": "Doing specialized work"} + handoff_sequence.append("specialist_to_reviewer") + yield {"result": specialist.return_value} + + async def reviewer_stream(*args, **kwargs): + yield {"agent_start": True, "node": "reviewer"} + yield {"agent_thinking": True, "thought": "Reviewing work"} + handoff_sequence.append("reviewer_complete") + yield {"result": reviewer.return_value} + + coordinator.stream_async = Mock(side_effect=coordinator_stream) + specialist.stream_async = Mock(side_effect=specialist_stream) + reviewer.stream_async = Mock(side_effect=reviewer_stream) + + # Set up handoff tools + def handoff_to_specialist(): + return specialist + + def handoff_to_reviewer(): + return reviewer + + coordinator.tool_registry.registry = {"handoff_to_specialist": handoff_to_specialist} + specialist.tool_registry.registry = {"handoff_to_reviewer": handoff_to_reviewer} + reviewer.tool_registry.registry = {} + + # Create swarm + swarm = Swarm(nodes=[coordinator, specialist, reviewer], max_handoffs=5, max_iterations=5, execution_timeout=30.0) + + # Collect streaming events + events = [] + async for event in swarm.stream_async("Test handoff streaming"): + events.append(event) + + # Should have multiple node executions due to handoffs + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + handoff_events = [e for e in events if e.get("multi_agent_handoff")] + + # Should have executed at least one agent (handoffs are complex to mock) + assert len(node_start_events) >= 1 + + # Verify handoff events have proper structure if any occurred + for event in handoff_events: + assert "from_node" in event + assert "to_node" in event + assert "message" in event + + +@pytest.mark.asyncio +async def test_swarm_streaming_with_failures(mock_strands_tracer, mock_use_span): + """Test swarm streaming behavior when agents fail.""" + + # Create a failing agent (don't fail during creation, fail during execution) + failing_agent = create_mock_agent("failing_agent", "Should fail") + success_agent = create_mock_agent("success_agent", "Success") + + async def failing_stream(*args, **kwargs): + yield {"agent_start": True, "node": "failing_agent"} + yield {"agent_thinking": True, "thought": "About to fail"} + await asyncio.sleep(0.01) + raise Exception("Simulated streaming failure") + + async def success_stream(*args, **kwargs): + yield {"agent_start": True, "node": "success_agent"} + yield {"agent_thinking": True, "thought": "Working successfully"} + yield {"result": success_agent.return_value} + + failing_agent.stream_async = Mock(side_effect=failing_stream) + success_agent.stream_async = Mock(side_effect=success_stream) + + # Create swarm starting with failing agent + swarm = Swarm(nodes=[failing_agent, success_agent], max_handoffs=2, max_iterations=3, execution_timeout=30.0) + + # Collect events until failure + events = [] + try: + async for event in swarm.stream_async("Test streaming with failure"): + events.append(event) + # If we get here, the swarm might have handled the failure gracefully + except Exception: + # Should get some events before failure + assert len(events) > 0 + + # Should have node start events + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + assert len(node_start_events) >= 1 + + # Should have some forwarded events before failure + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + assert len(node_stream_events) >= 1 + + +@pytest.mark.asyncio +async def test_swarm_streaming_timeout_behavior(mock_strands_tracer, mock_use_span): + """Test swarm streaming with execution timeout.""" + + # Create a slow agent + slow_agent = create_mock_agent("slow_agent", "Slow response") + + async def slow_stream(*args, **kwargs): + yield {"agent_start": True, "node": "slow_agent"} + yield {"agent_thinking": True, "thought": "Taking my time"} + await asyncio.sleep(0.2) # Longer than timeout + yield {"result": slow_agent.return_value} + + slow_agent.stream_async = Mock(side_effect=slow_stream) + + # Create swarm with short timeout + swarm = Swarm( + nodes=[slow_agent], + max_handoffs=1, + max_iterations=1, + execution_timeout=0.1, # Very short timeout + ) + + # Should timeout during streaming or complete + events = [] + try: + async for event in swarm.stream_async("Test timeout streaming"): + events.append(event) + # If no timeout, that's also acceptable for this test + # Just verify we got some events + assert len(events) >= 1 + except Exception: + # Timeout is expected but not required for this test + # Should still get some initial events + assert len(events) >= 1 + + +@pytest.mark.asyncio +async def test_swarm_streaming_backward_compatibility(mock_strands_tracer, mock_use_span): + """Test that swarm streaming maintains backward compatibility.""" + # Create simple agent + agent = create_mock_agent("test_agent", "Test response") + + # Create swarm + swarm = Swarm(nodes=[agent]) + + # Test that invoke_async still works + result = await swarm.invoke_async("Test backward compatibility") + assert result.status == Status.COMPLETED + + # Test that streaming also works and produces same result + events = [] + async for event in swarm.stream_async("Test backward compatibility"): + events.append(event) + + # Should have final result event + result_events = [e for e in events if "result" in e and not e.get("multi_agent_node_stream")] + assert len(result_events) == 1 + + streaming_result = result_events[0]["result"] + assert streaming_result.status == Status.COMPLETED + + # Results should be equivalent + assert result.status == streaming_result.status From 08141a04466270dcb785aae81e750b01a42a6abe Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 2 Oct 2025 13:11:04 +0200 Subject: [PATCH 2/7] fix(graph): improve parallel node calling --- src/strands/multiagent/graph.py | 95 +++++++++++++++++---------------- 1 file changed, 49 insertions(+), 46 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 457fb1fd2..f55724a95 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -569,54 +569,57 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato async def _execute_nodes_parallel( self, nodes: list["GraphNode"], invocation_state: dict[str, Any] ) -> AsyncIterator[dict[str, Any]]: - """Execute multiple nodes in parallel and merge their event streams.""" - # Create async generators for each node - node_generators = [self._execute_node(node, invocation_state) for node in nodes] - - # Track which generators are still active - active_generators = {i: gen for i, gen in enumerate(node_generators)} - - # Use asyncio.as_completed to process events as they arrive - while active_generators: - # Create tasks for the next event from each active generator - tasks: list[asyncio.Task[dict[str, Any]]] = [] - task_to_index: dict[asyncio.Task[dict[str, Any]], int] = {} - for i, gen in active_generators.items(): - task: asyncio.Task[dict[str, Any]] = asyncio.create_task(cast(Any, gen.__anext__())) - task_to_index[task] = i # Store the node index mapping - tasks.append(task) - - if not tasks: - break + """Execute multiple nodes in parallel and merge their event streams in real-time. + + Uses a shared queue where each node's stream runs independently and pushes events + as they occur, enabling true real-time event propagation without round-robin delays. + """ + # Create a shared queue for all events + event_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue() - # Wait for the first event to arrive - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + # Track active node tasks + active_tasks: set[asyncio.Task[None]] = set() - # Process completed tasks - for task in done: - node_index = task_to_index[task] - try: - event = await task - yield event - except StopAsyncIteration: - # This generator is exhausted, remove it - if node_index in active_generators: - del active_generators[node_index] - except Exception: - # Node execution failed, remove it but continue with other nodes - # The _execute_node method has already recorded the failure - if node_index in active_generators: - del active_generators[node_index] - # Don't re-raise here - let other nodes complete - # The graph-level logic will determine if the overall execution should fail - - # Cancel pending tasks - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + async def stream_node_to_queue(node: GraphNode, node_index: int) -> None: + """Stream events from a node to the shared queue.""" + try: + async for event in self._execute_node(node, invocation_state): + await event_queue.put(event) + except Exception as e: + # Node execution failed - the _execute_node method has already recorded the failure + # Log and continue to allow other nodes to complete + logger.debug( + "node_id=<%s>, error=<%s> | node streaming failed", + node.node_id, + str(e), + ) + finally: + # Signal that this node is done by putting None + await event_queue.put(None) + + # Start all node streams as independent tasks + for i, node in enumerate(nodes): + task = asyncio.create_task(stream_node_to_queue(node, i)) + active_tasks.add(task) + + # Track how many nodes have completed + completed_count = 0 + total_nodes = len(nodes) + + # Consume events from the queue as they arrive + while completed_count < total_nodes: + event = await event_queue.get() + + if event is None: + # A node has completed + completed_count += 1 + else: + # Forward the event immediately + yield event + + # Wait for all tasks to complete (should be immediate since they've all signaled completion) + if active_tasks: + await asyncio.gather(*active_tasks, return_exceptions=True) def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: """Find nodes that became ready after the last execution.""" From d4f55717f64a8d70075dbdae947b71ce02cac27b Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 2 Oct 2025 14:59:46 +0200 Subject: [PATCH 3/7] fix: Fix double execution --- src/strands/multiagent/graph.py | 30 +++- src/strands/multiagent/swarm.py | 60 +++++--- tests/strands/multiagent/test_graph.py | 190 +++++++++++++++++++------ tests/strands/multiagent/test_swarm.py | 151 ++++++++++++++++++-- 4 files changed, 355 insertions(+), 76 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index f55724a95..451d4a6e9 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -673,7 +673,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Execute with timeout protection and stream events try: if isinstance(node.executor, MultiAgentBase): - # For nested multi-agent systems, stream their events + # For nested multi-agent systems, stream their events and collect result + multi_agent_result = None if self.node_timeout is not None: # Implement timeout for async generator streaming async for event in self._stream_with_timeout( @@ -684,14 +685,22 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Forward nested multi-agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event.as_dict() + # Capture the final result event + if "result" in event: + multi_agent_result = event["result"] else: async for event in node.executor.stream_async(node_input, invocation_state): # Forward nested multi-agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event.as_dict() + # Capture the final result event + if "result" in event: + multi_agent_result = event["result"] + + # Use the captured result from streaming (no double execution) + if multi_agent_result is None: + raise ValueError(f"Node '{node.node_id}' did not produce a result event") - # Get the final result for metrics - multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) node_result = NodeResult( result=multi_agent_result, execution_time=multi_agent_result.execution_time, @@ -702,7 +711,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) elif isinstance(node.executor, Agent): - # For agents, stream their events + # For agents, stream their events and collect result + agent_response = None if self.node_timeout is not None: # Implement timeout for async generator streaming async for event in self._stream_with_timeout( @@ -713,14 +723,22 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event.as_dict() + # Capture the final result event + if "result" in event: + agent_response = event["result"] else: async for event in node.executor.stream_async(node_input, **invocation_state): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) yield wrapped_event.as_dict() + # Capture the final result event + if "result" in event: + agent_response = event["result"] + + # Use the captured result from streaming (no double execution) + if agent_response is None: + raise ValueError(f"Node '{node.node_id}' did not produce a result event") - # Get the final result for metrics - agent_response = await node.executor.invoke_async(node_input, **invocation_state) usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics = Metrics(latencyMs=0) if hasattr(agent_response, "metrics") and agent_response.metrics: diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index c28f9899d..66f3cf428 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -350,6 +350,19 @@ async def stream_async( finally: self.state.execution_time = round((time.time() - start_time) * 1000) + async def _stream_with_timeout( + self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str + ) -> AsyncIterator[dict[str, Any]]: + """Wrap an async generator with timeout functionality.""" + while True: + try: + event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout) + yield event + except StopAsyncIteration: + break + except asyncio.TimeoutError: + raise Exception(timeout_message) from None + def _setup_swarm(self, nodes: list[Agent]) -> None: """Initialize swarm configuration.""" # Validate nodes before setup @@ -610,9 +623,17 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato # Execute node with timeout protection try: - # For now, execute without timeout for async generators - # TODO: Implement proper timeout for async generators if needed - async for event in self._execute_node(current_node, self.state.task, invocation_state): + # Execute with timeout wrapper for async generator streaming + node_stream = ( + self._stream_with_timeout( + self._execute_node(current_node, self.state.task, invocation_state), + self.node_timeout, + f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s", + ) + if self.node_timeout is not None + else self._execute_node(current_node, self.state.task, invocation_state) + ) + async for event in node_stream: yield event self.state.node_history.append(current_node) @@ -639,17 +660,16 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato self.state.completion_status = Status.COMPLETED break - except asyncio.TimeoutError: - logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out", - current_node.node_id, - self.node_timeout, - ) - self.state.completion_status = Status.FAILED - break - - except Exception: - logger.exception("node=<%s> | node execution failed", current_node.node_id) + except Exception as e: + # Check if this is a timeout exception + if "timed out after" in str(e): + logger.exception( + "node=<%s>, timeout=<%s>s | node execution timed out", + current_node.node_id, + self.node_timeout, + ) + else: + logger.exception("node=<%s> | node execution failed", current_node.node_id) self.state.completion_status = Status.FAILED break @@ -691,14 +711,20 @@ async def _execute_node( # Execute node with streaming node.reset_executor_state() - # Stream agent events with node context + # Stream agent events with node context and capture final result + result = None async for event in node.executor.stream_async(node_input, **invocation_state): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node_name, event) yield wrapped_event.as_dict() + # Capture the final result event + if "result" in event: + result = event["result"] + + # Use the captured result from streaming to avoid double execution + if result is None: + raise ValueError(f"Node '{node_name}' did not produce a result event") - # Get the final result for metrics (we need to call invoke_async again for the result) - result = await node.executor.invoke_async(node_input, **invocation_state) execution_time = round((time.time() - start_time) * 1000) # Create NodeResult diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index e1ac87be5..af50b2cc2 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -214,15 +214,15 @@ async def test_graph_execution(mock_strands_tracer, mock_use_span, mock_graph, m assert len(result.execution_order) == 7 assert result.execution_order[0].node_id == "start_agent" - # Verify agent calls - mock_agents["start_agent"].invoke_async.assert_called_once() - mock_agents["multi_agent"].invoke_async.assert_called_once() - mock_agents["conditional_agent"].invoke_async.assert_called_once() - mock_agents["final_agent"].invoke_async.assert_called_once() - mock_agents["no_metrics_agent"].invoke_async.assert_called_once() - mock_agents["partial_metrics_agent"].invoke_async.assert_called_once() - string_content_agent.invoke_async.assert_called_once() - mock_agents["blocked_agent"].invoke_async.assert_not_called() + # Verify agent calls (now using stream_async internally) + assert mock_agents["start_agent"].stream_async.call_count == 1 + assert mock_agents["multi_agent"].stream_async.call_count == 1 + assert mock_agents["conditional_agent"].stream_async.call_count == 1 + assert mock_agents["final_agent"].stream_async.call_count == 1 + assert mock_agents["no_metrics_agent"].stream_async.call_count == 1 + assert mock_agents["partial_metrics_agent"].stream_async.call_count == 1 + assert string_content_agent.stream_async.call_count == 1 + assert mock_agents["blocked_agent"].stream_async.call_count == 0 # Verify metrics aggregation assert result.accumulated_usage["totalTokens"] > 0 @@ -328,8 +328,8 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span): result = await graph.invoke_async([{"text": "Original task"}]) - # Verify entry node was called with original task - entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}]) + # Verify entry node was called with original task (via stream_async) + assert entry_agent.stream_async.call_count == 1 assert result.status == Status.COMPLETED mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called_once() @@ -403,10 +403,10 @@ def spy_reset(self): execution_ids = [node.node_id for node in result.execution_order] assert execution_ids == ["a", "b", "c", "a"] - # Verify that each agent was called the expected number of times - assert agent_a.invoke_async.call_count == 2 # A executes twice - assert agent_b.invoke_async.call_count == 1 # B executes once - assert agent_c.invoke_async.call_count == 1 # C executes once + # Verify that each agent was called the expected number of times (via stream_async) + assert agent_a.stream_async.call_count == 2 # A executes twice + assert agent_b.stream_async.call_count == 1 # B executes once + assert agent_c.stream_async.call_count == 1 # C executes once # Verify that node state was reset for the revisited node (A) assert reset_spy.call_args_list == [call("a")] # Only A should be reset (when revisited) @@ -866,9 +866,9 @@ def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag assert result.execution_order[0].node_id == "start_agent" assert result.execution_order[1].node_id == "final_agent" - # Verify agent calls - mock_agents["start_agent"].invoke_async.assert_called_once() - mock_agents["final_agent"].invoke_async.assert_called_once() + # Verify agent calls (via stream_async) + assert mock_agents["start_agent"].stream_async.call_count == 1 + assert mock_agents["final_agent"].stream_async.call_count == 1 # Verify return type is GraphResult assert isinstance(result, GraphResult) @@ -946,6 +946,12 @@ async def invoke_async(self, input_data): ), ) + async def stream_async(self, input_data, **kwargs): + # Stream implementation that yields events and final result + yield {"agent_start": True} + result = await self.invoke_async(input_data) + yield {"result": result} + # Create agents agent_a = StatefulAgent("agent_a") agent_b = StatefulAgent("agent_b") @@ -1066,9 +1072,9 @@ async def test_linear_graph_behavior(): assert result.execution_order[0].node_id == "a" assert result.execution_order[1].node_id == "b" - # Verify agents were called once each (no state reset) - agent_a.invoke_async.assert_called_once() - agent_b.invoke_async.assert_called_once() + # Verify agents were called once each (no state reset, via stream_async) + assert agent_a.stream_async.call_count == 1 + assert agent_b.stream_async.call_count == 1 @pytest.mark.asyncio @@ -1140,9 +1146,9 @@ def loop_condition(state: GraphState) -> bool: graph = builder.build() result = await graph.invoke_async("Test self loop") - # Verify basic self-loop functionality + # Verify basic self-loop functionality (via stream_async) assert result.status == Status.COMPLETED - assert self_loop_agent.invoke_async.call_count == 3 + assert self_loop_agent.stream_async.call_count == 3 assert len(result.execution_order) == 3 assert all(node.node_id == "self_loop" for node in result.execution_order) @@ -1202,9 +1208,9 @@ def end_condition(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) == 4 # start -> loop -> loop -> end assert [node.node_id for node in result.execution_order] == ["start_node", "loop_node", "loop_node", "end_node"] - assert start_agent.invoke_async.call_count == 1 - assert loop_agent.invoke_async.call_count == 2 - assert end_agent.invoke_async.call_count == 1 + assert start_agent.stream_async.call_count == 1 + assert loop_agent.stream_async.call_count == 2 + assert end_agent.stream_async.call_count == 1 @pytest.mark.asyncio @@ -1233,8 +1239,8 @@ def condition_b(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) == 4 # a -> a -> b -> b - assert agent_a.invoke_async.call_count == 2 - assert agent_b.invoke_async.call_count == 2 + assert agent_a.stream_async.call_count == 2 + assert agent_b.stream_async.call_count == 2 mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called() @@ -1309,7 +1315,7 @@ def multi_loop_condition(state: GraphState) -> bool: assert result.status == Status.COMPLETED assert len(result.execution_order) >= 2 - assert multi_agent.invoke_async.call_count >= 2 + assert multi_agent.stream_async.call_count >= 2 @pytest.mark.asyncio @@ -1325,7 +1331,8 @@ async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span): test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state) + # Verify stream_async was called (kwargs are passed through) + assert kwargs_agent.stream_async.call_count == 1 assert result.status == Status.COMPLETED @@ -1342,9 +1349,8 @@ async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_spa test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = await graph.invoke_async("Test kwargs passing to multiagent", test_invocation_state) - kwargs_multiagent.invoke_async.assert_called_once_with( - [{"text": "Test kwargs passing to multiagent"}], test_invocation_state - ) + # Verify stream_async was called (kwargs are passed through) + assert kwargs_multiagent.stream_async.call_count == 1 assert result.status == Status.COMPLETED @@ -1360,7 +1366,8 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): test_invocation_state = {"custom_param": "test_value", "another_param": 42} result = graph("Test kwargs passing sync", test_invocation_state) - kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state) + # Verify stream_async was called (kwargs are passed through) + assert kwargs_agent.stream_async.call_count == 1 assert result.status == Status.COMPLETED @@ -1632,11 +1639,11 @@ async def timed_stream(*args, **kwargs): assert result.completed_nodes == 4 assert len(result.execution_order) == 4 - # Verify all agents were called - agent_a.invoke_async.assert_called_once() - agent_b.invoke_async.assert_called_once() - agent_c.invoke_async.assert_called_once() - agent_d.invoke_async.assert_called_once() + # Verify all agents were called (via stream_async) + assert agent_a.stream_async.call_count == 1 + assert agent_b.stream_async.call_count == 1 + assert agent_c.stream_async.call_count == 1 + assert agent_d.stream_async.call_count == 1 # Verify parallel execution: A, B, C should have overlapping execution times # If they were sequential, total time would be ~0.35s (3 * 0.1 + 0.05) @@ -1675,7 +1682,7 @@ async def test_graph_single_node_optimization(mock_strands_tracer, mock_use_span assert result.status == Status.COMPLETED assert result.completed_nodes == 1 - agent.invoke_async.assert_called_once() + assert agent.stream_async.call_count == 1 @pytest.mark.asyncio @@ -1745,3 +1752,106 @@ async def slow_stream_b(*args, **kwargs): # Note: The successful nodes may not complete if the failure happens early # This is expected behavior in the current implementation + + +@pytest.mark.asyncio +async def test_graph_single_invocation_no_double_execution(mock_strands_tracer, mock_use_span): + """Test that nodes are only invoked once (no double execution from streaming).""" + # Create agents with invocation counters + agent_a = create_mock_agent("agent_a", "Response A") + agent_b = create_mock_agent("agent_b", "Response B") + + # Track invocation counts + invocation_counts = {"agent_a": 0, "agent_b": 0} + + async def counted_stream_a(*args, **kwargs): + invocation_counts["agent_a"] += 1 + yield {"agent_start": True} + yield {"agent_thinking": True, "thought": "Processing A"} + yield {"result": agent_a.return_value} + + async def counted_stream_b(*args, **kwargs): + invocation_counts["agent_b"] += 1 + yield {"agent_start": True} + yield {"agent_thinking": True, "thought": "Processing B"} + yield {"result": agent_b.return_value} + + agent_a.stream_async = Mock(side_effect=counted_stream_a) + agent_b.stream_async = Mock(side_effect=counted_stream_b) + + # Build graph: A -> B + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_edge("a", "b") + builder.set_entry_point("a") + graph = builder.build() + + # Execute the graph + result = await graph.invoke_async("Test single invocation") + + # Verify successful execution + assert result.status == Status.COMPLETED + + # CRITICAL: Each agent should be invoked exactly once + assert invocation_counts["agent_a"] == 1, f"Agent A invoked {invocation_counts['agent_a']} times, expected 1" + assert invocation_counts["agent_b"] == 1, f"Agent B invoked {invocation_counts['agent_b']} times, expected 1" + + # Verify stream_async was called but invoke_async was NOT called + assert agent_a.stream_async.call_count == 1 + assert agent_b.stream_async.call_count == 1 + # invoke_async should not be called at all since we're using streaming + agent_a.invoke_async.assert_not_called() + agent_b.invoke_async.assert_not_called() + + +@pytest.mark.asyncio +async def test_graph_parallel_single_invocation(mock_strands_tracer, mock_use_span): + """Test that parallel nodes are only invoked once each.""" + # Create parallel agents with invocation counters + invocation_counts = {"a": 0, "b": 0, "c": 0} + + async def create_counted_agent(name): + agent = create_mock_agent(name, f"Response {name}") + + async def counted_stream(*args, **kwargs): + invocation_counts[name] += 1 + yield {"agent_start": True, "node": name} + await asyncio.sleep(0.01) # Small delay + yield {"result": agent.return_value} + + agent.stream_async = Mock(side_effect=counted_stream) + return agent + + agent_a = await create_counted_agent("a") + agent_b = await create_counted_agent("b") + agent_c = await create_counted_agent("c") + + # Build graph with parallel nodes + builder = GraphBuilder() + builder.add_node(agent_a, "a") + builder.add_node(agent_b, "b") + builder.add_node(agent_c, "c") + builder.set_entry_point("a") + builder.set_entry_point("b") + builder.set_entry_point("c") + graph = builder.build() + + # Execute the graph + result = await graph.invoke_async("Test parallel single invocation") + + # Verify successful execution + assert result.status == Status.COMPLETED + + # CRITICAL: Each agent should be invoked exactly once + assert invocation_counts["a"] == 1, f"Agent A invoked {invocation_counts['a']} times, expected 1" + assert invocation_counts["b"] == 1, f"Agent B invoked {invocation_counts['b']} times, expected 1" + assert invocation_counts["c"] == 1, f"Agent C invoked {invocation_counts['c']} times, expected 1" + + # Verify stream_async was called but invoke_async was NOT called + assert agent_a.stream_async.call_count == 1 + assert agent_b.stream_async.call_count == 1 + assert agent_c.stream_async.call_count == 1 + agent_a.invoke_async.assert_not_called() + agent_b.invoke_async.assert_not_called() + agent_c.invoke_async.assert_not_called() diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index aed3563a4..946668e19 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -239,8 +239,8 @@ async def test_swarm_execution_async(mock_strands_tracer, mock_use_span, mock_sw assert result.execution_count == 1 assert len(result.results) == 1 - # Verify agent was called - mock_agents["coordinator"].invoke_async.assert_called() + # Verify agent was called (via stream_async) + assert mock_agents["coordinator"].stream_async.call_count >= 1 # Verify metrics aggregation assert result.accumulated_usage["totalTokens"] >= 0 @@ -275,8 +275,8 @@ def test_swarm_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag assert len(result.results) == 1 assert result.execution_time >= 0 - # Verify agent was called - mock_agents["coordinator"].invoke_async.assert_called() + # Verify agent was called (via stream_async) + assert mock_agents["coordinator"].stream_async.call_count >= 1 # Verify return type is SwarmResult assert isinstance(result, SwarmResult) @@ -366,7 +366,13 @@ def create_handoff_result(): async def mock_invoke_async(*args, **kwargs): return create_handoff_result() + async def mock_stream_async(*args, **kwargs): + yield {"agent_start": True} + result = create_handoff_result() + yield {"result": result} + agent.invoke_async = MagicMock(side_effect=mock_invoke_async) + agent.stream_async = Mock(side_effect=mock_stream_async) return agent # Create agents - first one hands off, second one completes by not handing off @@ -392,9 +398,9 @@ async def mock_invoke_async(*args, **kwargs): # Verify the completion agent executed after handoff assert result.node_history[1].node_id == "completion_agent" - # Verify both agents were called - handoff_agent.invoke_async.assert_called() - completion_agent.invoke_async.assert_called() + # Verify both agents were called (via stream_async) + assert handoff_agent.stream_async.call_count >= 1 + assert completion_agent.stream_async.call_count >= 1 # Test handoff when task is already completed completed_swarm = Swarm(nodes=[handoff_agent, completion_agent]) @@ -455,8 +461,8 @@ def test_swarm_auto_completion_without_handoff(): assert len(result.node_history) == 1 assert result.node_history[0].node_id == "no_handoff_agent" - # Verify the agent was called - no_handoff_agent.invoke_async.assert_called() + # Verify the agent was called (via stream_async) + assert no_handoff_agent.stream_async.call_count >= 1 def test_swarm_configurable_entry_point(): @@ -559,28 +565,28 @@ def test_swarm_validate_unsupported_features(): async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying agents.""" kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") - kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) swarm = Swarm(nodes=[kwargs_agent]) test_kwargs = {"custom_param": "test_value", "another_param": 42} result = await swarm.invoke_async("Test kwargs passing", test_kwargs) - assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + # Verify stream_async was called (kwargs are passed through) + assert kwargs_agent.stream_async.call_count >= 1 assert result.status == Status.COMPLETED def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): """Test that kwargs are passed through to underlying agents in sync execution.""" kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs") - kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async) swarm = Swarm(nodes=[kwargs_agent]) test_kwargs = {"custom_param": "test_value", "another_param": 42} result = swarm("Test kwargs passing sync", test_kwargs) - assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs + # Verify stream_async was called (kwargs are passed through) + assert kwargs_agent.stream_async.call_count >= 1 assert result.status == Status.COMPLETED @@ -844,3 +850,122 @@ async def test_swarm_streaming_backward_compatibility(mock_strands_tracer, mock_ # Results should be equivalent assert result.status == streaming_result.status + + +@pytest.mark.asyncio +async def test_swarm_single_invocation_no_double_execution(mock_strands_tracer, mock_use_span): + """Test that swarm nodes are only invoked once (no double execution from streaming).""" + # Create agent with invocation counter + agent = create_mock_agent("test_agent", "Test response") + + # Track invocation count + invocation_count = {"count": 0} + + async def counted_stream(*args, **kwargs): + invocation_count["count"] += 1 + yield {"agent_start": True, "node": "test_agent"} + yield {"agent_thinking": True, "thought": "Processing"} + yield {"result": agent.return_value} + + agent.stream_async = Mock(side_effect=counted_stream) + + # Create swarm + swarm = Swarm(nodes=[agent]) + + # Execute the swarm + result = await swarm.invoke_async("Test single invocation") + + # Verify successful execution + assert result.status == Status.COMPLETED + + # CRITICAL: Agent should be invoked exactly once + assert invocation_count["count"] == 1, f"Agent invoked {invocation_count['count']} times, expected 1" + + # Verify stream_async was called but invoke_async was NOT called + assert agent.stream_async.call_count == 1 + # invoke_async should not be called at all since we're using streaming + agent.invoke_async.assert_not_called() + + +@pytest.mark.asyncio +async def test_swarm_handoff_single_invocation_per_node(mock_strands_tracer, mock_use_span): + """Test that each node in a swarm handoff chain is invoked exactly once.""" + # Create agents with invocation counters + invocation_counts = {"coordinator": 0, "specialist": 0} + + coordinator = create_mock_agent("coordinator", "Coordinating") + specialist = create_mock_agent("specialist", "Specialized work") + + async def coordinator_stream(*args, **kwargs): + invocation_counts["coordinator"] += 1 + yield {"agent_start": True, "node": "coordinator"} + yield {"agent_thinking": True, "thought": "Need specialist"} + yield {"result": coordinator.return_value} + + async def specialist_stream(*args, **kwargs): + invocation_counts["specialist"] += 1 + yield {"agent_start": True, "node": "specialist"} + yield {"agent_thinking": True, "thought": "Doing specialized work"} + yield {"result": specialist.return_value} + + coordinator.stream_async = Mock(side_effect=coordinator_stream) + specialist.stream_async = Mock(side_effect=specialist_stream) + + # Set up handoff tool + def handoff_to_specialist(): + return specialist + + coordinator.tool_registry.registry = {"handoff_to_specialist": handoff_to_specialist} + specialist.tool_registry.registry = {} + + # Create swarm + swarm = Swarm(nodes=[coordinator, specialist], max_handoffs=2, max_iterations=3) + + # Execute the swarm + result = await swarm.invoke_async("Test handoff single invocation") + + # Verify successful execution + assert result.status == Status.COMPLETED + + # CRITICAL: Each agent should be invoked exactly once + # Note: Actual invocation depends on whether handoff occurs, but no double execution + assert invocation_counts["coordinator"] == 1, f"Coordinator invoked {invocation_counts['coordinator']} times" + # Specialist may or may not be invoked depending on handoff logic, but if invoked, only once + assert invocation_counts["specialist"] <= 1, f"Specialist invoked {invocation_counts['specialist']} times" + + # Verify stream_async was called but invoke_async was NOT called + assert coordinator.stream_async.call_count == 1 + coordinator.invoke_async.assert_not_called() + if invocation_counts["specialist"] > 0: + specialist.invoke_async.assert_not_called() + + +@pytest.mark.asyncio +async def test_swarm_timeout_with_streaming(mock_strands_tracer, mock_use_span): + """Test that swarm node timeout works correctly with streaming.""" + # Create a slow agent + slow_agent = create_mock_agent("slow_agent", "Slow response") + + async def slow_stream(*args, **kwargs): + yield {"agent_start": True, "node": "slow_agent"} + await asyncio.sleep(0.3) # Longer than timeout + yield {"result": slow_agent.return_value} + + slow_agent.stream_async = Mock(side_effect=slow_stream) + + # Create swarm with short node timeout + swarm = Swarm( + nodes=[slow_agent], + max_handoffs=1, + max_iterations=1, + node_timeout=0.1, # Short timeout + ) + + # Execute - should complete with FAILED status due to timeout + result = await swarm.invoke_async("Test timeout") + + # Verify the swarm failed due to timeout + assert result.status == Status.FAILED + + # Verify the agent started streaming + assert slow_agent.stream_async.call_count == 1 From fc0a2729525cfe999fc8e2e441a8ae7b269269c1 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 3 Oct 2025 11:58:11 +0200 Subject: [PATCH 4/7] fix: improve graph timeout --- src/strands/multiagent/graph.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 451d4a6e9..c670d1417 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -503,20 +503,9 @@ async def _stream_with_timeout( self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str ) -> AsyncIterator[dict[str, Any]]: """Wrap an async generator with timeout functionality.""" - - # Create a task for the entire generator with timeout - async def generator_with_timeout() -> AsyncIterator[dict[str, Any]]: - try: - async for event in async_generator: - yield event - except asyncio.TimeoutError: - raise Exception(timeout_message) from None - - # Use asyncio.wait_for on each individual event - generator = generator_with_timeout() while True: try: - event = await asyncio.wait_for(generator.__anext__(), timeout=timeout) + event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout) yield event except StopAsyncIteration: break From 60f16b9581dff90f9ef715ac77f765557f216ee3 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 3 Oct 2025 12:27:31 +0200 Subject: [PATCH 5/7] fix: Add integ tests --- tests_integ/test_multiagent_graph.py | 182 +++++++++++++++++++++++++++ tests_integ/test_multiagent_swarm.py | 38 ++++++ 2 files changed, 220 insertions(+) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index c2c13c443..bebef054a 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,3 +1,5 @@ +from typing import Any, AsyncIterator + import pytest from strands import Agent, tool @@ -9,6 +11,7 @@ BeforeModelCallEvent, MessageAddedEvent, ) +from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult, Status from strands.multiagent.graph import GraphBuilder from strands.types.content import ContentBlock from tests.fixtures.mock_hook_provider import MockHookProvider @@ -218,3 +221,182 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events + + +class CustomStreamingNode(MultiAgentBase): + """Custom node that wraps an agent and adds custom streaming events.""" + + def __init__(self, agent: Agent, name: str): + self.agent = agent + self.name = name + + async def invoke_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + result = await self.agent.invoke_async(task, **kwargs) + node_result = NodeResult(result=result, status=Status.COMPLETED) + return MultiAgentResult(status=Status.COMPLETED, results={self.name: node_result}) + + async def stream_async( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + yield {"custom_event": "start", "node": self.name} + result = await self.agent.invoke_async(task, **kwargs) + yield {"custom_event": "agent_complete", "node": self.name} + node_result = NodeResult(result=result, status=Status.COMPLETED) + yield {"result": MultiAgentResult(status=Status.COMPLETED, results={self.name: node_result})} + + +@pytest.mark.asyncio +async def test_graph_streaming_with_agents(): + """Test that Graph properly streams events from agent nodes.""" + math_agent = Agent( + name="math", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a math assistant.", + tools=[calculate_sum], + ) + summary_agent = Agent( + name="summary", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summary assistant.", + ) + + builder = GraphBuilder() + builder.add_node(math_agent, "math") + builder.add_node(summary_agent, "summary") + builder.add_edge("math", "summary") + builder.set_entry_point("math") + graph = builder.build() + + # Collect events + events = [] + async for event in graph.stream_async("Calculate 5 + 3 and summarize the result"): + events.append(event) + + # Count event categories + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + node_complete_events = [e for e in events if e.get("multi_agent_node_complete")] + result_events = [e for e in events if "result" in e and "multi_agent_node_start" not in e] + + # Verify we got multiple events of each type + assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}" + assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}" + assert len(node_complete_events) >= 2, f"Expected at least 2 node_complete events, got {len(node_complete_events)}" + assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}" + + # Verify we have events for both nodes + math_events = [e for e in events if e.get("node_id") == "math"] + summary_events = [e for e in events if e.get("node_id") == "summary"] + assert len(math_events) > 0, "Expected events from math node" + assert len(summary_events) > 0, "Expected events from summary node" + + +@pytest.mark.asyncio +async def test_graph_streaming_with_custom_node(): + """Test that Graph properly streams events from custom MultiAgentBase nodes.""" + math_agent = Agent( + name="math", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a math assistant.", + tools=[calculate_sum], + ) + summary_agent = Agent( + name="summary", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summary assistant.", + ) + + # Create a custom node + custom_node = CustomStreamingNode(summary_agent, "custom_summary") + + builder = GraphBuilder() + builder.add_node(math_agent, "math") + builder.add_node(custom_node, "custom_summary") + builder.add_edge("math", "custom_summary") + builder.set_entry_point("math") + graph = builder.build() + + # Collect events + events = [] + async for event in graph.stream_async("Calculate 5 + 3 and summarize the result"): + events.append(event) + + # Count event categories + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + custom_events = [e for e in events if e.get("custom_event")] + result_events = [e for e in events if "result" in e and "multi_agent_node_start" not in e] + + # Verify we got multiple events of each type + assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}" + assert len(node_stream_events) > 5, f"Expected many node_stream events, got {len(node_stream_events)}" + assert len(custom_events) >= 2, f"Expected at least 2 custom events (start, complete), got {len(custom_events)}" + assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}" + + # Verify custom events are properly structured + custom_start = [e for e in custom_events if e.get("custom_event") == "start"] + custom_complete = [e for e in custom_events if e.get("custom_event") == "agent_complete"] + + assert len(custom_start) >= 1, "Expected at least 1 custom start event" + assert len(custom_complete) >= 1, "Expected at least 1 custom complete event" + + +@pytest.mark.asyncio +async def test_nested_graph_streaming(): + """Test that nested graphs properly propagate streaming events.""" + math_agent = Agent( + name="math", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a math assistant.", + tools=[calculate_sum], + ) + analysis_agent = Agent( + name="analysis", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are an analysis assistant.", + ) + + # Create nested graph + nested_builder = GraphBuilder() + nested_builder.add_node(math_agent, "calculator") + nested_builder.add_node(analysis_agent, "analyzer") + nested_builder.add_edge("calculator", "analyzer") + nested_builder.set_entry_point("calculator") + nested_graph = nested_builder.build() + + # Create outer graph with nested graph + summary_agent = Agent( + name="summary", + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summary assistant.", + ) + + outer_builder = GraphBuilder() + outer_builder.add_node(nested_graph, "computation") + outer_builder.add_node(summary_agent, "summary") + outer_builder.add_edge("computation", "summary") + outer_builder.set_entry_point("computation") + outer_graph = outer_builder.build() + + # Collect events + events = [] + async for event in outer_graph.stream_async("Calculate 7 + 8 and provide a summary"): + events.append(event) + + # Count event categories + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + result_events = [e for e in events if "result" in e and "multi_agent_node_start" not in e] + + # Verify we got multiple events + assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}" + assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}" + assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}" + + # Verify we have events from nested nodes + computation_events = [e for e in events if e.get("node_id") == "computation"] + summary_events = [e for e in events if e.get("node_id") == "summary"] + assert len(computation_events) > 0, "Expected events from computation (nested graph) node" + assert len(summary_events) > 0, "Expected events from summary node" diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index 9a8c79bf8..9c0ecad76 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -134,3 +134,41 @@ async def test_swarm_execution_with_image(researcher_agent, analyst_agent, write # Verify agent history - at least one agent should have been used assert len(result.node_history) > 0 + + +@pytest.mark.asyncio +async def test_swarm_streaming(): + """Test that Swarm properly streams events during execution.""" + researcher = Agent( + name="researcher", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a researcher. When you need calculations, hand off to the analyst.", + ) + analyst = Agent( + name="analyst", + model="us.amazon.nova-pro-v1:0", + system_prompt="You are an analyst. Use tools to perform calculations.", + tools=[calculate], + ) + + swarm = Swarm([researcher, analyst]) + + # Collect events + events = [] + async for event in swarm.stream_async("Calculate 10 + 5 and explain the result"): + events.append(event) + + # Count event categories + node_start_events = [e for e in events if e.get("multi_agent_node_start")] + node_stream_events = [e for e in events if e.get("multi_agent_node_stream")] + result_events = [e for e in events if "result" in e and "multi_agent_node_start" not in e] + + # Verify we got multiple events of each type + assert len(node_start_events) >= 1, f"Expected at least 1 node_start event, got {len(node_start_events)}" + assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}" + assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}" + + # Verify we have events from at least one agent + researcher_events = [e for e in events if e.get("node_id") == "researcher"] + analyst_events = [e for e in events if e.get("node_id") == "analyst"] + assert len(researcher_events) > 0 or len(analyst_events) > 0, "Expected events from at least one agent" From a307f37eb7516c6e75fc90af1d9c1baa1ae9e9b8 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 10 Oct 2025 13:07:45 +0200 Subject: [PATCH 6/7] refactor(multiagent): improve streaming event handling and documentation - Update docstrings to match Agent's minimal style (use dict keys instead of class names) - Add isinstance checks for result event detection for type safety - Improve _stream_with_timeout to handle None timeout case - Add MultiAgentResultEvent for consistency with Agent pattern - Yield TypedEvent objects internally, convert to dict at API boundary - All 154 tests passing --- src/strands/multiagent/graph.py | 225 +++++++++++++++----------------- src/strands/multiagent/swarm.py | 99 +++++++------- src/strands/types/_events.py | 12 ++ 3 files changed, 172 insertions(+), 164 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index c670d1417..39b182589 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -31,6 +31,7 @@ MultiAgentNodeCompleteEvent, MultiAgentNodeStartEvent, MultiAgentNodeStreamEvent, + MultiAgentResultEvent, ) from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage @@ -443,11 +444,11 @@ async def stream_async( **kwargs: Keyword arguments allowing backward compatible future changes. Yields: - Dictionary events containing graph execution information including: - - MultiAgentNodeStartEvent: When a node begins execution - - MultiAgentNodeStreamEvent: Forwarded agent events with node context - - MultiAgentNodeCompleteEvent: When a node completes execution - - Final result event + Dictionary events during graph execution, such as: + - multi_agent_node_start: When a node begins execution + - multi_agent_node_stream: Forwarded agent/multi-agent events with node context + - multi_agent_node_complete: When a node completes execution + - result: Final graph result """ if invocation_state is None: invocation_state = {} @@ -476,7 +477,7 @@ async def stream_async( ) async for event in self._execute_graph(invocation_state): - yield event + yield event.as_dict() # Set final status based on execution results if self.state.failed_nodes: @@ -490,7 +491,7 @@ async def stream_async( result = self._build_result() # Use the same event format as Agent for consistency - yield {"result": result} + yield MultiAgentResultEvent(result=result).as_dict() except Exception: logger.exception("graph execution failed") @@ -500,17 +501,35 @@ async def stream_async( self.state.execution_time = round((time.time() - start_time) * 1000) async def _stream_with_timeout( - self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str - ) -> AsyncIterator[dict[str, Any]]: - """Wrap an async generator with timeout functionality.""" - while True: - try: - event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout) + self, async_generator: AsyncIterator[Any], timeout: float | None, timeout_message: str + ) -> AsyncIterator[Any]: + """Wrap an async generator with timeout functionality. + + Args: + async_generator: The generator to wrap + timeout: Timeout in seconds, or None for no timeout + timeout_message: Message to include in timeout exception + + Yields: + Events from the wrapped generator + + Raises: + Exception: If timeout is exceeded (same as original behavior) + """ + if timeout is None: + # No timeout - just pass through + async for event in async_generator: yield event - except StopAsyncIteration: - break - except asyncio.TimeoutError: - raise Exception(timeout_message) from None + else: + # Apply timeout to each event + while True: + try: + event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout) + yield event + except StopAsyncIteration: + break + except asyncio.TimeoutError: + raise Exception(timeout_message) from None def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: """Validate graph nodes for duplicate instances.""" @@ -524,8 +543,8 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]: - """Execute graph and yield events.""" + async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + """Execute graph and yield TypedEvent objects.""" ready_nodes = list(self.entry_points) while ready_nodes: @@ -557,14 +576,14 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato async def _execute_nodes_parallel( self, nodes: list["GraphNode"], invocation_state: dict[str, Any] - ) -> AsyncIterator[dict[str, Any]]: + ) -> AsyncIterator[Any]: """Execute multiple nodes in parallel and merge their event streams in real-time. Uses a shared queue where each node's stream runs independently and pushes events as they occur, enabling true real-time event propagation without round-robin delays. """ # Create a shared queue for all events - event_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue() + event_queue: asyncio.Queue[Any | None] = asyncio.Queue() # Track active node tasks active_tasks: set[asyncio.Task[None]] = set() @@ -637,8 +656,8 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ ) return False - async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]: - """Execute a single node.""" + async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + """Execute a single node and yield TypedEvent objects.""" # Reset the node's state if reset_on_revisit is enabled and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) @@ -652,7 +671,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) start_event = MultiAgentNodeStartEvent( node_id=node.node_id, node_type="agent" if isinstance(node.executor, Agent) else "multiagent" ) - yield start_event.as_dict() + yield start_event start_time = time.time() try: @@ -660,101 +679,71 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) node_input = self._build_node_input(node) # Execute with timeout protection and stream events - try: - if isinstance(node.executor, MultiAgentBase): - # For nested multi-agent systems, stream their events and collect result - multi_agent_result = None - if self.node_timeout is not None: - # Implement timeout for async generator streaming - async for event in self._stream_with_timeout( - node.executor.stream_async(node_input, invocation_state), - self.node_timeout, - f"Node '{node.node_id}' execution timed out after {self.node_timeout}s", - ): - # Forward nested multi-agent events with node context - wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) - yield wrapped_event.as_dict() - # Capture the final result event - if "result" in event: - multi_agent_result = event["result"] - else: - async for event in node.executor.stream_async(node_input, invocation_state): - # Forward nested multi-agent events with node context - wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) - yield wrapped_event.as_dict() - # Capture the final result event - if "result" in event: - multi_agent_result = event["result"] - - # Use the captured result from streaming (no double execution) - if multi_agent_result is None: - raise ValueError(f"Node '{node.node_id}' did not produce a result event") - - node_result = NodeResult( - result=multi_agent_result, - execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, - accumulated_usage=multi_agent_result.accumulated_usage, - accumulated_metrics=multi_agent_result.accumulated_metrics, - execution_count=multi_agent_result.execution_count, - ) - - elif isinstance(node.executor, Agent): - # For agents, stream their events and collect result - agent_response = None - if self.node_timeout is not None: - # Implement timeout for async generator streaming - async for event in self._stream_with_timeout( - node.executor.stream_async(node_input, **invocation_state), - self.node_timeout, - f"Node '{node.node_id}' execution timed out after {self.node_timeout}s", - ): - # Forward agent events with node context - wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) - yield wrapped_event.as_dict() - # Capture the final result event - if "result" in event: - agent_response = event["result"] - else: - async for event in node.executor.stream_async(node_input, **invocation_state): - # Forward agent events with node context - wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) - yield wrapped_event.as_dict() - # Capture the final result event - if "result" in event: - agent_response = event["result"] - - # Use the captured result from streaming (no double execution) - if agent_response is None: - raise ValueError(f"Node '{node.node_id}' did not produce a result event") - - usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) - metrics = Metrics(latencyMs=0) - if hasattr(agent_response, "metrics") and agent_response.metrics: - if hasattr(agent_response.metrics, "accumulated_usage"): - usage = agent_response.metrics.accumulated_usage - if hasattr(agent_response.metrics, "accumulated_metrics"): - metrics = agent_response.metrics.accumulated_metrics - - node_result = NodeResult( - result=agent_response, - execution_time=round((time.time() - start_time) * 1000), - status=Status.COMPLETED, - accumulated_usage=usage, - accumulated_metrics=metrics, - execution_count=1, - ) - else: - raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") + if isinstance(node.executor, MultiAgentBase): + # For nested multi-agent systems, stream their events and collect result + multi_agent_result = None + async for event in self._stream_with_timeout( + node.executor.stream_async(node_input, invocation_state), + self.node_timeout, + f"Node '{node.node_id}' execution timed out after {self.node_timeout}s", + ): + # Forward nested multi-agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) + yield wrapped_event + # Capture the final result event + if isinstance(event, dict) and "result" in event: + multi_agent_result = event["result"] + + # Use the captured result from streaming (no double execution) + if multi_agent_result is None: + raise ValueError(f"Node '{node.node_id}' did not produce a result event") + + node_result = NodeResult( + result=multi_agent_result, + execution_time=multi_agent_result.execution_time, + status=Status.COMPLETED, + accumulated_usage=multi_agent_result.accumulated_usage, + accumulated_metrics=multi_agent_result.accumulated_metrics, + execution_count=multi_agent_result.execution_count, + ) - except asyncio.TimeoutError: - timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" - logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out", - node.node_id, + elif isinstance(node.executor, Agent): + # For agents, stream their events and collect result + agent_response = None + async for event in self._stream_with_timeout( + node.executor.stream_async(node_input, **invocation_state), self.node_timeout, + f"Node '{node.node_id}' execution timed out after {self.node_timeout}s", + ): + # Forward agent events with node context + wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event) + yield wrapped_event + # Capture the final result event + if isinstance(event, dict) and "result" in event: + agent_response = event["result"] + + # Use the captured result from streaming (no double execution) + if agent_response is None: + raise ValueError(f"Node '{node.node_id}' did not produce a result event") + + usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) + metrics = Metrics(latencyMs=0) + if hasattr(agent_response, "metrics") and agent_response.metrics: + if hasattr(agent_response.metrics, "accumulated_usage"): + usage = agent_response.metrics.accumulated_usage + if hasattr(agent_response.metrics, "accumulated_metrics"): + metrics = agent_response.metrics.accumulated_metrics + + node_result = NodeResult( + result=agent_response, + execution_time=round((time.time() - start_time) * 1000), + status=Status.COMPLETED, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=1, ) - raise Exception(timeout_msg) from None + else: + raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") # Mark as completed node.execution_status = Status.COMPLETED @@ -769,7 +758,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Emit node complete event complete_event = MultiAgentNodeCompleteEvent(node_id=node.node_id, execution_time=node.execution_time) - yield complete_event.as_dict() + yield complete_event logger.debug( "node_id=<%s>, execution_time=<%dms> | node completed successfully", @@ -799,7 +788,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Still emit complete event even for failures complete_event = MultiAgentNodeCompleteEvent(node_id=node.node_id, execution_time=execution_time) - yield complete_event.as_dict() + yield complete_event raise diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 66f3cf428..2930a4d09 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -32,6 +32,7 @@ MultiAgentNodeCompleteEvent, MultiAgentNodeStartEvent, MultiAgentNodeStreamEvent, + MultiAgentResultEvent, ) from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage @@ -299,12 +300,12 @@ async def stream_async( **kwargs: Keyword arguments allowing backward compatible future changes. Yields: - Dictionary events containing swarm execution information including: - - MultiAgentNodeStartEvent: When an agent begins execution - - MultiAgentNodeStreamEvent: Forwarded agent events with node context - - MultiAgentHandoffEvent: When control is handed off between agents - - MultiAgentNodeCompleteEvent: When an agent completes execution - - Final result event + Dictionary events during swarm execution, such as: + - multi_agent_node_start: When a node begins execution + - multi_agent_node_stream: Forwarded agent events with node context + - multi_agent_handoff: When control is handed off between agents + - multi_agent_node_complete: When a node completes execution + - result: Final swarm result """ if invocation_state is None: invocation_state = {} @@ -337,11 +338,11 @@ async def stream_async( ) async for event in self._execute_swarm(invocation_state): - yield event + yield event.as_dict() # Yield final result (consistent with Agent's AgentResultEvent format) result = self._build_result() - yield {"result": result} + yield MultiAgentResultEvent(result=result).as_dict() except Exception: logger.exception("swarm execution failed") @@ -351,17 +352,35 @@ async def stream_async( self.state.execution_time = round((time.time() - start_time) * 1000) async def _stream_with_timeout( - self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str - ) -> AsyncIterator[dict[str, Any]]: - """Wrap an async generator with timeout functionality.""" - while True: - try: - event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout) + self, async_generator: AsyncIterator[Any], timeout: float | None, timeout_message: str + ) -> AsyncIterator[Any]: + """Wrap an async generator with timeout functionality. + + Args: + async_generator: The generator to wrap + timeout: Timeout in seconds, or None for no timeout + timeout_message: Message to include in timeout exception + + Yields: + Events from the wrapped generator + + Raises: + Exception: If timeout is exceeded (same as original behavior) + """ + if timeout is None: + # No timeout - just pass through + async for event in async_generator: yield event - except StopAsyncIteration: - break - except asyncio.TimeoutError: - raise Exception(timeout_message) from None + else: + # Apply timeout to each event + while True: + try: + event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout) + yield event + except StopAsyncIteration: + break + except asyncio.TimeoutError: + raise Exception(timeout_message) from None def _setup_swarm(self, nodes: list[Agent]) -> None: """Initialize swarm configuration.""" @@ -583,8 +602,8 @@ def _build_node_input(self, target_node: SwarmNode) -> str: return context_text - async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]: - """Execute swarm and yield events.""" + async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: + """Execute swarm and yield TypedEvent objects.""" try: # Main execution loop while True: @@ -624,14 +643,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato # Execute node with timeout protection try: # Execute with timeout wrapper for async generator streaming - node_stream = ( - self._stream_with_timeout( - self._execute_node(current_node, self.state.task, invocation_state), - self.node_timeout, - f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s", - ) - if self.node_timeout is not None - else self._execute_node(current_node, self.state.task, invocation_state) + node_stream = self._stream_with_timeout( + self._execute_node(current_node, self.state.task, invocation_state), + self.node_timeout, + f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s", ) async for event in node_stream: yield event @@ -648,7 +663,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato to_node=self.state.current_node.node_id, message=self.state.handoff_message or "Agent handoff occurred", ) - yield handoff_event.as_dict() + yield handoff_event logger.debug( "from_node=<%s>, to_node=<%s> | handoff detected", previous_node.node_id, @@ -660,16 +675,8 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato self.state.completion_status = Status.COMPLETED break - except Exception as e: - # Check if this is a timeout exception - if "timed out after" in str(e): - logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out", - current_node.node_id, - self.node_timeout, - ) - else: - logger.exception("node=<%s> | node execution failed", current_node.node_id) + except Exception: + logger.exception("node=<%s> | node execution failed", current_node.node_id) self.state.completion_status = Status.FAILED break @@ -687,14 +694,14 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato async def _execute_node( self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] - ) -> AsyncIterator[dict[str, Any]]: - """Execute swarm node.""" + ) -> AsyncIterator[Any]: + """Execute swarm node and yield TypedEvent objects.""" start_time = time.time() node_name = node.node_id # Emit node start event start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent") - yield start_event.as_dict() + yield start_event try: # Prepare context for node @@ -716,9 +723,9 @@ async def _execute_node( async for event in node.executor.stream_async(node_input, **invocation_state): # Forward agent events with node context wrapped_event = MultiAgentNodeStreamEvent(node_name, event) - yield wrapped_event.as_dict() + yield wrapped_event # Capture the final result event - if "result" in event: + if isinstance(event, dict) and "result" in event: result = event["result"] # Use the captured result from streaming to avoid double execution @@ -753,7 +760,7 @@ async def _execute_node( # Emit node complete event complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time) - yield complete_event.as_dict() + yield complete_event except Exception as e: execution_time = round((time.time() - start_time) * 1000) @@ -774,7 +781,7 @@ async def _execute_node( # Still emit complete event for failures complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time) - yield complete_event.as_dict() + yield complete_event raise diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 3332444a7..26ec31b0d 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -353,6 +353,18 @@ def __init__(self, result: "AgentResult"): super().__init__({"result": result}) +class MultiAgentResultEvent(TypedEvent): + """Event emitted when multi-agent execution completes with final result.""" + + def __init__(self, result: Any) -> None: + """Initialize with multi-agent result. + + Args: + result: The final result from multi-agent execution (SwarmResult, GraphResult, etc.) + """ + super().__init__({"result": result}) + + class MultiAgentNodeStartEvent(TypedEvent): """Event emitted when a node begins execution in multi-agent context.""" From 24502fc02d03267be9cb2364e349c6fc9ffa7c63 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Fri, 10 Oct 2025 13:18:04 +0200 Subject: [PATCH 7/7] fix(multiagent): remove no-op asyncio.gather in parallel execution - Remove unnecessary asyncio.gather() after event loop completion - Same issue as tool executor PR #954 - By the time loop exits, all tasks have already completed - Gather was waiting for already-finished tasks (no-op) - All 154 tests passing --- src/strands/multiagent/graph.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 39b182589..ca8c83ffa 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -625,10 +625,6 @@ async def stream_node_to_queue(node: GraphNode, node_index: int) -> None: # Forward the event immediately yield event - # Wait for all tasks to complete (should be immediate since they've all signaled completion) - if active_tasks: - await asyncio.gather(*active_tasks, return_exceptions=True) - def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: """Find nodes that became ready after the last execution.""" newly_ready = []