Skip to content

feat: expose user-defined state in MultiAgent Graph #703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
"""

import copy
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
Expand All @@ -22,6 +24,105 @@ class Status(Enum):
FAILED = "failed"


@dataclass
class MultiAgentNode:
"""Base class for nodes in multi-agent systems."""

node_id: str

def __hash__(self) -> int:
"""Return hash for MultiAgentNode based on node_id."""
return hash(self.node_id)

def __eq__(self, other: Any) -> bool:
"""Return equality for MultiAgentNode based on node_id."""
if not isinstance(other, MultiAgentNode):
return False
return self.node_id == other.node_id


@dataclass
class SharedContext:
"""Shared context between multi-agent nodes.

This class provides a key-value store for sharing information across nodes
in multi-agent systems like Graph and Swarm. It validates that all values
are JSON serializable to ensure compatibility.
"""

context: dict[str, dict[str, Any]] = field(default_factory=dict)

def add_context(self, node: MultiAgentNode, key: str, value: Any) -> None:
"""Add context for a specific node.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also looks like unit tests are failing. May need to rebase or address them if they are still failing


Args:
node: The node object to add context for
key: The key to store the value under
value: The value to store (must be JSON serializable)

Raises:
ValueError: If key is invalid or value is not JSON serializable
"""
self._validate_key(key)
self._validate_json_serializable(value)

if node.node_id not in self.context:
self.context[node.node_id] = {}
self.context[node.node_id][key] = value

def get_context(self, node: MultiAgentNode, key: str | None = None) -> Any:
"""Get context for a specific node.

Args:
node: The node object to get context for
key: The specific key to retrieve (if None, returns all context for the node)

Returns:
The stored value, entire context dict for the node, or None if not found
"""
if node.node_id not in self.context:
return None if key else {}

if key is None:
return copy.deepcopy(self.context[node.node_id])
else:
value = self.context[node.node_id].get(key)
return copy.deepcopy(value) if value is not None else None

def _validate_key(self, key: str) -> None:
"""Validate that a key is valid.

Args:
key: The key to validate

Raises:
ValueError: If key is invalid
"""
if key is None:
raise ValueError("Key cannot be None")
if not isinstance(key, str):
raise ValueError("Key must be a string")
if not key.strip():
raise ValueError("Key cannot be empty")

def _validate_json_serializable(self, value: Any) -> None:
"""Validate that a value is JSON serializable.

Args:
value: The value to validate

Raises:
ValueError: If value is not JSON serializable
"""
try:
json.dumps(value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Value is not JSON serializable: {type(value).__name__}. "
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
) from e


@dataclass
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
Expand Down
28 changes: 25 additions & 3 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..telemetry import get_tracer
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
from .base import MultiAgentBase, MultiAgentNode, MultiAgentResult, NodeResult, SharedContext, Status

logger = logging.getLogger(__name__)

Expand All @@ -46,6 +46,7 @@ class GraphState:
task: The original input prompt/query provided to the graph execution.
This represents the actual work to be performed by the graph as a whole.
Entry point nodes receive this task as their input if they have no dependencies.
shared_context: Context shared between graph nodes for storing user-defined state.
"""

# Task (with default empty string)
Expand All @@ -61,6 +62,9 @@ class GraphState:
# Results
results: dict[str, NodeResult] = field(default_factory=dict)

# User-defined state shared across nodes
shared_context: "SharedContext" = field(default_factory=lambda: SharedContext())

# Accumulated metrics
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
Expand Down Expand Up @@ -126,7 +130,7 @@ def should_traverse(self, state: GraphState) -> bool:


@dataclass
class GraphNode:
class GraphNode(MultiAgentNode):
"""Represents a node in the graph.

The execution_status tracks the node's lifecycle within graph orchestration:
Expand All @@ -135,7 +139,6 @@ class GraphNode:
- COMPLETED/FAILED: Node finished executing (regardless of result quality)
"""

node_id: str
executor: Agent | MultiAgentBase
dependencies: set["GraphNode"] = field(default_factory=set)
execution_status: Status = Status.PENDING
Expand Down Expand Up @@ -389,6 +392,25 @@ def __init__(
self.state = GraphState()
self.tracer = get_tracer()

@property
def shared_context(self) -> SharedContext:
"""Access to the shared context for storing user-defined state across graph nodes.

Returns:
The SharedContext instance that can be used to store and retrieve
information that should be accessible to all nodes in the graph.

Example:
```python
graph = Graph(...)
node1 = graph.nodes["node1"]
node2 = graph.nodes["node2"]
graph.shared_context.add_context(node1, "file_reference", "/path/to/file")
graph.shared_context.get_context(node2, "file_reference")
```
"""
return self.state.shared_context

def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
"""Invoke the graph synchronously."""

Expand Down
60 changes: 7 additions & 53 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import asyncio
import copy
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -29,16 +28,15 @@
from ..tools.decorator import tool
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
from .base import MultiAgentBase, MultiAgentNode, MultiAgentResult, NodeResult, SharedContext, Status

logger = logging.getLogger(__name__)


@dataclass
class SwarmNode:
class SwarmNode(MultiAgentNode):
"""Represents a node (e.g. Agent) in the swarm."""

node_id: str
executor: Agent
_initial_messages: Messages = field(default_factory=list, init=False)
_initial_state: AgentState = field(default_factory=AgentState, init=False)
Expand Down Expand Up @@ -73,55 +71,6 @@ def reset_executor_state(self) -> None:
self.executor.state = AgentState(self._initial_state.get())


@dataclass
class SharedContext:
"""Shared context between swarm nodes."""

context: dict[str, dict[str, Any]] = field(default_factory=dict)

def add_context(self, node: SwarmNode, key: str, value: Any) -> None:
"""Add context."""
self._validate_key(key)
self._validate_json_serializable(value)

if node.node_id not in self.context:
self.context[node.node_id] = {}
self.context[node.node_id][key] = value

def _validate_key(self, key: str) -> None:
"""Validate that a key is valid.

Args:
key: The key to validate

Raises:
ValueError: If key is invalid
"""
if key is None:
raise ValueError("Key cannot be None")
if not isinstance(key, str):
raise ValueError("Key must be a string")
if not key.strip():
raise ValueError("Key cannot be empty")

def _validate_json_serializable(self, value: Any) -> None:
"""Validate that a value is JSON serializable.

Args:
value: The value to validate

Raises:
ValueError: If value is not JSON serializable
"""
try:
json.dumps(value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Value is not JSON serializable: {type(value).__name__}. "
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
) from e


@dataclass
class SwarmState:
"""Current state of swarm execution."""
Expand Down Expand Up @@ -654,3 +603,8 @@ def _build_result(self) -> SwarmResult:
execution_time=self.state.execution_time,
node_history=self.state.node_history,
)


# Backward compatibility aliases
# These ensure that existing imports continue to work
__all__ = ["SwarmNode", "SharedContext", "Status"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing, thanks! I will pull this down and test a little bit today but looks great!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @dbschmigelski ,

please test and merge this if you find any issue to fix please let me know

Loading
Loading