Skip to content

feat: expose user-defined state in MultiAgent Graph #703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
"""

import copy
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
Expand All @@ -22,6 +24,88 @@ class Status(Enum):
FAILED = "failed"


@dataclass
class SharedContext:
"""Shared context between multi-agent nodes.

This class provides a key-value store for sharing information across nodes
in multi-agent systems like Graph and Swarm. It validates that all values
are JSON serializable to ensure compatibility.
"""

context: dict[str, dict[str, Any]] = field(default_factory=dict)

def add_context(self, node_id: str, key: str, value: Any) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for raising this!

I have a couple of concerns relating to backwards compatibility.

It looks like we switched from SwarmNode to node_id. Can we instead retain the Node object. Refactor SwarmNode into base.py as some MultiAgentNode.

Then we need to maintain backwards compatibility via aliases in swarm. meaning we do not want to break imports as right now it will be broken if a user has an import like from strands.multiagent.swarm import SharedContext so we need to avoid breaking consumers for SharedContext and Node.

"""Add context for a specific node.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also looks like unit tests are failing. May need to rebase or address them if they are still failing


Args:
node_id: The ID of the node adding the context
key: The key to store the value under
value: The value to store (must be JSON serializable)

Raises:
ValueError: If key is invalid or value is not JSON serializable
"""
self._validate_key(key)
self._validate_json_serializable(value)

if node_id not in self.context:
self.context[node_id] = {}
self.context[node_id][key] = value

def get_context(self, node_id: str, key: str | None = None) -> Any:
"""Get context for a specific node.

Args:
node_id: The ID of the node to get context for
key: The specific key to retrieve (if None, returns all context for the node)

Returns:
The stored value, entire context dict for the node, or None if not found
"""
if node_id not in self.context:
return None if key else {}

if key is None:
return copy.deepcopy(self.context[node_id])
else:
value = self.context[node_id].get(key)
return copy.deepcopy(value) if value is not None else None

def _validate_key(self, key: str) -> None:
"""Validate that a key is valid.

Args:
key: The key to validate

Raises:
ValueError: If key is invalid
"""
if key is None:
raise ValueError("Key cannot be None")
if not isinstance(key, str):
raise ValueError("Key must be a string")
if not key.strip():
raise ValueError("Key cannot be empty")

def _validate_json_serializable(self, value: Any) -> None:
"""Validate that a value is JSON serializable.

Args:
value: The value to validate

Raises:
ValueError: If value is not JSON serializable
"""
try:
json.dumps(value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Value is not JSON serializable: {type(value).__name__}. "
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
) from e


@dataclass
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
Expand Down
23 changes: 22 additions & 1 deletion src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..telemetry import get_tracer
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
from .base import MultiAgentBase, MultiAgentResult, NodeResult, SharedContext, Status

logger = logging.getLogger(__name__)

Expand All @@ -46,6 +46,7 @@ class GraphState:
task: The original input prompt/query provided to the graph execution.
This represents the actual work to be performed by the graph as a whole.
Entry point nodes receive this task as their input if they have no dependencies.
shared_context: Context shared between graph nodes for storing user-defined state.
"""

# Task (with default empty string)
Expand All @@ -61,6 +62,9 @@ class GraphState:
# Results
results: dict[str, NodeResult] = field(default_factory=dict)

# User-defined state shared across nodes
shared_context: "SharedContext" = field(default_factory=lambda: SharedContext())

# Accumulated metrics
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
Expand Down Expand Up @@ -389,6 +393,23 @@ def __init__(
self.state = GraphState()
self.tracer = get_tracer()

@property
def shared_context(self) -> SharedContext:
"""Access to the shared context for storing user-defined state across graph nodes.

Returns:
The SharedContext instance that can be used to store and retrieve
information that should be accessible to all nodes in the graph.

Example:
```python
graph = Graph(...)
graph.shared_context.add_context("node1", "file_reference", "/path/to/file")
graph.shared_context.get_context("node2", "file_reference")
```
"""
return self.state.shared_context

def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
"""Invoke the graph synchronously."""

Expand Down
54 changes: 2 additions & 52 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import asyncio
import copy
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -29,7 +28,7 @@
from ..tools.decorator import tool
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
from .base import MultiAgentBase, MultiAgentResult, NodeResult, SharedContext, Status

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,55 +72,6 @@ def reset_executor_state(self) -> None:
self.executor.state = AgentState(self._initial_state.get())


@dataclass
class SharedContext:
"""Shared context between swarm nodes."""

context: dict[str, dict[str, Any]] = field(default_factory=dict)

def add_context(self, node: SwarmNode, key: str, value: Any) -> None:
"""Add context."""
self._validate_key(key)
self._validate_json_serializable(value)

if node.node_id not in self.context:
self.context[node.node_id] = {}
self.context[node.node_id][key] = value

def _validate_key(self, key: str) -> None:
"""Validate that a key is valid.

Args:
key: The key to validate

Raises:
ValueError: If key is invalid
"""
if key is None:
raise ValueError("Key cannot be None")
if not isinstance(key, str):
raise ValueError("Key must be a string")
if not key.strip():
raise ValueError("Key cannot be empty")

def _validate_json_serializable(self, value: Any) -> None:
"""Validate that a value is JSON serializable.

Args:
value: The value to validate

Raises:
ValueError: If value is not JSON serializable
"""
try:
json.dumps(value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Value is not JSON serializable: {type(value).__name__}. "
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
) from e


@dataclass
class SwarmState:
"""Current state of swarm execution."""
Expand Down Expand Up @@ -405,7 +355,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
# Store handoff context as shared context
if context:
for key, value in context.items():
self.shared_context.add_context(previous_agent, key, value)
self.shared_context.add_context(previous_agent.node_id, key, value)

logger.debug(
"from_node=<%s>, to_node=<%s> | handed off from agent to agent",
Expand Down
Loading
Loading