-
Notifications
You must be signed in to change notification settings - Fork 425
feat(multiagent): Add stream_async #961
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
6c00bbe
b09b539
08141a0
d4f5571
fc0a272
ca59221
60f16b9
a307f37
24502fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,14 +19,20 @@ | |
import time | ||
from concurrent.futures import ThreadPoolExecutor | ||
from dataclasses import dataclass, field | ||
from typing import Any, Callable, Tuple | ||
from typing import Any, AsyncIterator, Callable, Tuple, cast | ||
|
||
from opentelemetry import trace as trace_api | ||
|
||
from ..agent import Agent, AgentResult | ||
from ..agent import Agent | ||
from ..agent.state import AgentState | ||
from ..telemetry import get_tracer | ||
from ..tools.decorator import tool | ||
from ..types._events import ( | ||
MultiAgentHandoffEvent, | ||
MultiAgentNodeCompleteEvent, | ||
MultiAgentNodeStartEvent, | ||
MultiAgentNodeStreamEvent, | ||
) | ||
from ..types.content import ContentBlock, Messages | ||
from ..types.event_loop import Metrics, Usage | ||
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status | ||
|
@@ -266,12 +272,39 @@ async def invoke_async( | |
) -> SwarmResult: | ||
"""Invoke the swarm asynchronously. | ||
|
||
This method uses stream_async internally and consumes all events until completion, | ||
following the same pattern as the Agent class. | ||
|
||
Args: | ||
task: The task to execute | ||
invocation_state: Additional state/context passed to underlying agents. | ||
Defaults to None to avoid mutable default argument issues. | ||
**kwargs: Keyword arguments allowing backward compatible future changes. | ||
""" | ||
events = self.stream_async(task, invocation_state, **kwargs) | ||
async for event in events: | ||
_ = event | ||
|
||
return cast(SwarmResult, event["result"]) | ||
|
||
async def stream_async( | ||
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any | ||
) -> AsyncIterator[dict[str, Any]]: | ||
"""Stream events during swarm execution. | ||
|
||
Args: | ||
task: The task to execute | ||
invocation_state: Additional state/context passed to underlying agents. | ||
Defaults to None to avoid mutable default argument issues - a new empty dict | ||
is created if None is provided. | ||
Defaults to None to avoid mutable default argument issues. | ||
**kwargs: Keyword arguments allowing backward compatible future changes. | ||
|
||
Yields: | ||
Dictionary events containing swarm execution information including: | ||
- MultiAgentNodeStartEvent: When an agent begins execution | ||
- MultiAgentNodeStreamEvent: Forwarded agent events with node context | ||
- MultiAgentHandoffEvent: When control is handed off between agents | ||
- MultiAgentNodeCompleteEvent: When an agent completes execution | ||
- Final result event | ||
""" | ||
if invocation_state is None: | ||
invocation_state = {} | ||
|
@@ -282,7 +315,7 @@ async def invoke_async( | |
if self.entry_point: | ||
initial_node = self.nodes[str(self.entry_point.name)] | ||
else: | ||
initial_node = next(iter(self.nodes.values())) # First SwarmNode | ||
initial_node = next(iter(self.nodes.values())) | ||
|
||
self.state = SwarmState( | ||
current_node=initial_node, | ||
|
@@ -303,15 +336,32 @@ async def invoke_async( | |
self.execution_timeout, | ||
) | ||
|
||
await self._execute_swarm(invocation_state) | ||
async for event in self._execute_swarm(invocation_state): | ||
yield event | ||
|
||
# Yield final result (consistent with Agent's AgentResultEvent format) | ||
result = self._build_result() | ||
yield {"result": result} | ||
|
||
except Exception: | ||
logger.exception("swarm execution failed") | ||
self.state.completion_status = Status.FAILED | ||
raise | ||
finally: | ||
self.state.execution_time = round((time.time() - start_time) * 1000) | ||
|
||
return self._build_result() | ||
async def _stream_with_timeout( | ||
self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str | ||
) -> AsyncIterator[dict[str, Any]]: | ||
"""Wrap an async generator with timeout functionality.""" | ||
while True: | ||
|
||
try: | ||
event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout) | ||
yield event | ||
except StopAsyncIteration: | ||
break | ||
except asyncio.TimeoutError: | ||
raise Exception(timeout_message) from None | ||
|
||
def _setup_swarm(self, nodes: list[Agent]) -> None: | ||
"""Initialize swarm configuration.""" | ||
|
@@ -533,14 +583,14 @@ def _build_node_input(self, target_node: SwarmNode) -> str: | |
|
||
return context_text | ||
|
||
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: | ||
"""Shared execution logic used by execute_async.""" | ||
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]: | ||
"""Execute swarm and yield events.""" | ||
try: | ||
# Main execution loop | ||
while True: | ||
if self.state.completion_status != Status.EXECUTING: | ||
reason = f"Completion status is: {self.state.completion_status}" | ||
logger.debug("reason=<%s> | stopping execution", reason) | ||
logger.debug("reason=<%s> | stopping streaming execution", reason) | ||
break | ||
|
||
should_continue, reason = self.state.should_continue( | ||
|
@@ -568,36 +618,58 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: | |
len(self.state.node_history) + 1, | ||
) | ||
|
||
# Store the current node before execution to detect handoffs | ||
previous_node = current_node | ||
|
||
# Execute node with timeout protection | ||
# TODO: Implement cancellation token to stop _execute_node from continuing | ||
try: | ||
await asyncio.wait_for( | ||
self._execute_node(current_node, self.state.task, invocation_state), | ||
timeout=self.node_timeout, | ||
# Execute with timeout wrapper for async generator streaming | ||
node_stream = ( | ||
self._stream_with_timeout( | ||
self._execute_node(current_node, self.state.task, invocation_state), | ||
self.node_timeout, | ||
f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s", | ||
) | ||
if self.node_timeout is not None | ||
else self._execute_node(current_node, self.state.task, invocation_state) | ||
) | ||
async for event in node_stream: | ||
yield event | ||
|
||
self.state.node_history.append(current_node) | ||
|
||
logger.debug("node=<%s> | node execution completed", current_node.node_id) | ||
|
||
# Check if the current node is still the same after execution | ||
# If it is, then no handoff occurred and we consider the swarm complete | ||
if self.state.current_node == current_node: | ||
# Check if handoff occurred during execution | ||
if self.state.current_node != previous_node: | ||
# Emit handoff event | ||
handoff_event = MultiAgentHandoffEvent( | ||
from_node=previous_node.node_id, | ||
to_node=self.state.current_node.node_id, | ||
message=self.state.handoff_message or "Agent handoff occurred", | ||
) | ||
yield handoff_event.as_dict() | ||
mkmeral marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
logger.debug( | ||
"from_node=<%s>, to_node=<%s> | handoff detected", | ||
previous_node.node_id, | ||
self.state.current_node.node_id, | ||
) | ||
else: | ||
# No handoff occurred, mark swarm as complete | ||
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) | ||
self.state.completion_status = Status.COMPLETED | ||
break | ||
|
||
except asyncio.TimeoutError: | ||
logger.exception( | ||
"node=<%s>, timeout=<%s>s | node execution timed out after timeout", | ||
current_node.node_id, | ||
self.node_timeout, | ||
) | ||
self.state.completion_status = Status.FAILED | ||
break | ||
|
||
except Exception: | ||
logger.exception("node=<%s> | node execution failed", current_node.node_id) | ||
except Exception as e: | ||
|
||
# Check if this is a timeout exception | ||
if "timed out after" in str(e): | ||
|
||
logger.exception( | ||
"node=<%s>, timeout=<%s>s | node execution timed out", | ||
current_node.node_id, | ||
self.node_timeout, | ||
) | ||
else: | ||
logger.exception("node=<%s> | node execution failed", current_node.node_id) | ||
self.state.completion_status = Status.FAILED | ||
break | ||
|
||
|
@@ -615,11 +687,15 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: | |
|
||
async def _execute_node( | ||
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] | ||
) -> AgentResult: | ||
) -> AsyncIterator[dict[str, Any]]: | ||
"""Execute swarm node.""" | ||
start_time = time.time() | ||
node_name = node.node_id | ||
|
||
# Emit node start event | ||
start_event = MultiAgentNodeStartEvent(node_id=node_name, node_type="agent") | ||
yield start_event.as_dict() | ||
|
||
try: | ||
# Prepare context for node | ||
context_text = self._build_node_input(node) | ||
|
@@ -632,11 +708,22 @@ async def _execute_node( | |
# Include additional ContentBlocks in node input | ||
node_input = node_input + task | ||
|
||
# Execute node | ||
result = None | ||
# Execute node with streaming | ||
node.reset_executor_state() | ||
# Unpacking since this is the agent class. Other executors should not unpack | ||
result = await node.executor.invoke_async(node_input, **invocation_state) | ||
|
||
# Stream agent events with node context and capture final result | ||
result = None | ||
async for event in node.executor.stream_async(node_input, **invocation_state): | ||
# Forward agent events with node context | ||
wrapped_event = MultiAgentNodeStreamEvent(node_name, event) | ||
yield wrapped_event.as_dict() | ||
# Capture the final result event | ||
if "result" in event: | ||
result = event["result"] | ||
|
||
# Use the captured result from streaming to avoid double execution | ||
if result is None: | ||
raise ValueError(f"Node '{node_name}' did not produce a result event") | ||
|
||
execution_time = round((time.time() - start_time) * 1000) | ||
|
||
|
@@ -664,15 +751,17 @@ async def _execute_node( | |
# Accumulate metrics | ||
self._accumulate_metrics(node_result) | ||
|
||
return result | ||
# Emit node complete event | ||
complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time) | ||
yield complete_event.as_dict() | ||
|
||
except Exception as e: | ||
execution_time = round((time.time() - start_time) * 1000) | ||
logger.exception("node=<%s> | node execution failed", node_name) | ||
|
||
# Create a NodeResult for the failed node | ||
node_result = NodeResult( | ||
result=e, # Store exception as result | ||
result=e, | ||
execution_time=execution_time, | ||
status=Status.FAILED, | ||
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), | ||
|
@@ -683,6 +772,10 @@ async def _execute_node( | |
# Store result in state | ||
self.state.results[node_name] = node_result | ||
|
||
# Still emit complete event for failures | ||
complete_event = MultiAgentNodeCompleteEvent(node_id=node_name, execution_time=execution_time) | ||
yield complete_event.as_dict() | ||
|
||
raise | ||
|
||
def _accumulate_metrics(self, node_result: NodeResult) -> None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -351,3 +351,62 @@ def __init__(self, reason: str | Exception) -> None: | |
class AgentResultEvent(TypedEvent): | ||
def __init__(self, result: "AgentResult"): | ||
super().__init__({"result": result}) | ||
|
||
|
||
class MultiAgentNodeStartEvent(TypedEvent): | ||
"""Event emitted when a node begins execution in multi-agent context.""" | ||
|
||
def __init__(self, node_id: str, node_type: str) -> None: | ||
"""Initialize with node information. | ||
|
||
Args: | ||
node_id: Unique identifier for the node | ||
node_type: Type of node ("agent", "swarm", "graph") | ||
""" | ||
super().__init__({"multi_agent_node_start": True, "node_id": node_id, "node_type": node_type}) | ||
|
||
|
||
class MultiAgentNodeCompleteEvent(TypedEvent): | ||
"""Event emitted when a node completes execution.""" | ||
|
||
def __init__(self, node_id: str, execution_time: int) -> None: | ||
"""Initialize with completion information. | ||
|
||
Args: | ||
node_id: Unique identifier for the node | ||
execution_time: Execution time in milliseconds | ||
""" | ||
super().__init__({"multi_agent_node_complete": True, "node_id": node_id, "execution_time": execution_time}) | ||
|
||
|
||
class MultiAgentHandoffEvent(TypedEvent): | ||
"""Event emitted during agent handoffs in Swarm.""" | ||
|
||
def __init__(self, from_node: str, to_node: str, message: str) -> None: | ||
"""Initialize with handoff information. | ||
|
||
Args: | ||
from_node: Node ID handing off control | ||
to_node: Node ID receiving control | ||
message: Handoff message explaining the transfer | ||
""" | ||
super().__init__({"multi_agent_handoff": True, "from_node": from_node, "to_node": to_node, "message": message}) | ||
|
||
|
||
class MultiAgentNodeStreamEvent(TypedEvent): | ||
"""Event emitted during node execution - forwards agent events with node context.""" | ||
|
||
def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None: | ||
"""Initialize with node context and agent event. | ||
|
||
Args: | ||
node_id: Unique identifier for the node generating the event | ||
agent_event: The original agent event data | ||
""" | ||
super().__init__( | ||
{ | ||
"multi_agent_node_stream": True, | ||
"node_id": node_id, | ||
**agent_event, # Forward all original agent event data | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's nest this instead of combining. Specifically |
||
} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this an AgentResult as it is for single-agent streaming? If not, can we rename it so that it doesn't conflict with different types?