diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index c6b1af702..6a6c31782 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -3,6 +3,8 @@ Provides minimal foundation for multi-agent patterns (Swarm, Graph). """ +import copy +import json from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum @@ -22,6 +24,105 @@ class Status(Enum): FAILED = "failed" +@dataclass +class MultiAgentNode: + """Base class for nodes in multi-agent systems.""" + + node_id: str + + def __hash__(self) -> int: + """Return hash for MultiAgentNode based on node_id.""" + return hash(self.node_id) + + def __eq__(self, other: Any) -> bool: + """Return equality for MultiAgentNode based on node_id.""" + if not isinstance(other, MultiAgentNode): + return False + return self.node_id == other.node_id + + +@dataclass +class SharedContext: + """Shared context between multi-agent nodes. + + This class provides a key-value store for sharing information across nodes + in multi-agent systems like Graph and Swarm. It validates that all values + are JSON serializable to ensure compatibility. + """ + + context: dict[str, dict[str, Any]] = field(default_factory=dict) + + def add_context(self, node: MultiAgentNode, key: str, value: Any) -> None: + """Add context for a specific node. + + Args: + node: The node object to add context for + key: The key to store the value under + value: The value to store (must be JSON serializable) + + Raises: + ValueError: If key is invalid or value is not JSON serializable + """ + self._validate_key(key) + self._validate_json_serializable(value) + + if node.node_id not in self.context: + self.context[node.node_id] = {} + self.context[node.node_id][key] = value + + def get_context(self, node: MultiAgentNode, key: str | None = None) -> Any: + """Get context for a specific node. + + Args: + node: The node object to get context for + key: The specific key to retrieve (if None, returns all context for the node) + + Returns: + The stored value, entire context dict for the node, or None if not found + """ + if node.node_id not in self.context: + return None if key else {} + + if key is None: + return copy.deepcopy(self.context[node.node_id]) + else: + value = self.context[node.node_id].get(key) + return copy.deepcopy(value) if value is not None else None + + def _validate_key(self, key: str) -> None: + """Validate that a key is valid. + + Args: + key: The key to validate + + Raises: + ValueError: If key is invalid + """ + if key is None: + raise ValueError("Key cannot be None") + if not isinstance(key, str): + raise ValueError("Key must be a string") + if not key.strip(): + raise ValueError("Key cannot be empty") + + def _validate_json_serializable(self, value: Any) -> None: + """Validate that a value is JSON serializable. + + Args: + value: The value to validate + + Raises: + ValueError: If value is not JSON serializable + """ + try: + json.dumps(value) + except (TypeError, ValueError) as e: + raise ValueError( + f"Value is not JSON serializable: {type(value).__name__}. " + f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." + ) from e + + @dataclass class NodeResult: """Unified result from node execution - handles both Agent and nested MultiAgentBase results. diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 9aee260b1..fde3d3ce4 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -29,7 +29,7 @@ from ..telemetry import get_tracer from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage -from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status +from .base import MultiAgentBase, MultiAgentNode, MultiAgentResult, NodeResult, SharedContext, Status logger = logging.getLogger(__name__) @@ -46,6 +46,7 @@ class GraphState: task: The original input prompt/query provided to the graph execution. This represents the actual work to be performed by the graph as a whole. Entry point nodes receive this task as their input if they have no dependencies. + shared_context: Context shared between graph nodes for storing user-defined state. """ # Task (with default empty string) @@ -61,6 +62,9 @@ class GraphState: # Results results: dict[str, NodeResult] = field(default_factory=dict) + # User-defined state shared across nodes + shared_context: "SharedContext" = field(default_factory=lambda: SharedContext()) + # Accumulated metrics accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) @@ -126,7 +130,7 @@ def should_traverse(self, state: GraphState) -> bool: @dataclass -class GraphNode: +class GraphNode(MultiAgentNode): """Represents a node in the graph. The execution_status tracks the node's lifecycle within graph orchestration: @@ -135,7 +139,6 @@ class GraphNode: - COMPLETED/FAILED: Node finished executing (regardless of result quality) """ - node_id: str executor: Agent | MultiAgentBase dependencies: set["GraphNode"] = field(default_factory=set) execution_status: Status = Status.PENDING @@ -389,6 +392,25 @@ def __init__( self.state = GraphState() self.tracer = get_tracer() + @property + def shared_context(self) -> SharedContext: + """Access to the shared context for storing user-defined state across graph nodes. + + Returns: + The SharedContext instance that can be used to store and retrieve + information that should be accessible to all nodes in the graph. + + Example: + ```python + graph = Graph(...) + node1 = graph.nodes["node1"] + node2 = graph.nodes["node2"] + graph.shared_context.add_context(node1, "file_reference", "/path/to/file") + graph.shared_context.get_context(node2, "file_reference") + ``` + """ + return self.state.shared_context + def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult: """Invoke the graph synchronously.""" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index a96c92de8..c3750b4eb 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -14,7 +14,6 @@ import asyncio import copy -import json import logging import time from concurrent.futures import ThreadPoolExecutor @@ -29,16 +28,15 @@ from ..tools.decorator import tool from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage -from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status +from .base import MultiAgentBase, MultiAgentNode, MultiAgentResult, NodeResult, SharedContext, Status logger = logging.getLogger(__name__) @dataclass -class SwarmNode: +class SwarmNode(MultiAgentNode): """Represents a node (e.g. Agent) in the swarm.""" - node_id: str executor: Agent _initial_messages: Messages = field(default_factory=list, init=False) _initial_state: AgentState = field(default_factory=AgentState, init=False) @@ -73,55 +71,6 @@ def reset_executor_state(self) -> None: self.executor.state = AgentState(self._initial_state.get()) -@dataclass -class SharedContext: - """Shared context between swarm nodes.""" - - context: dict[str, dict[str, Any]] = field(default_factory=dict) - - def add_context(self, node: SwarmNode, key: str, value: Any) -> None: - """Add context.""" - self._validate_key(key) - self._validate_json_serializable(value) - - if node.node_id not in self.context: - self.context[node.node_id] = {} - self.context[node.node_id][key] = value - - def _validate_key(self, key: str) -> None: - """Validate that a key is valid. - - Args: - key: The key to validate - - Raises: - ValueError: If key is invalid - """ - if key is None: - raise ValueError("Key cannot be None") - if not isinstance(key, str): - raise ValueError("Key must be a string") - if not key.strip(): - raise ValueError("Key cannot be empty") - - def _validate_json_serializable(self, value: Any) -> None: - """Validate that a value is JSON serializable. - - Args: - value: The value to validate - - Raises: - ValueError: If value is not JSON serializable - """ - try: - json.dumps(value) - except (TypeError, ValueError) as e: - raise ValueError( - f"Value is not JSON serializable: {type(value).__name__}. " - f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." - ) from e - - @dataclass class SwarmState: """Current state of swarm execution.""" @@ -654,3 +603,8 @@ def _build_result(self) -> SwarmResult: execution_time=self.state.execution_time, node_history=self.state.node_history, ) + + +# Backward compatibility aliases +# These ensure that existing imports continue to work +__all__ = ["SwarmNode", "SharedContext", "Status"] diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 7aa76bb90..e70b86c37 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -1,149 +1,146 @@ -import pytest - -from strands.agent import AgentResult -from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult, Status - - -@pytest.fixture -def agent_result(): - """Create a mock AgentResult for testing.""" - return AgentResult( - message={"role": "assistant", "content": [{"text": "Test response"}]}, - stop_reason="end_turn", - state={}, - metrics={}, - ) - - -def test_node_result_initialization_and_properties(agent_result): - """Test NodeResult initialization and property access.""" - # Basic initialization - node_result = NodeResult(result=agent_result, execution_time=50, status="completed") - - # Verify properties - assert node_result.result == agent_result - assert node_result.execution_time == 50 - assert node_result.status == "completed" - assert node_result.accumulated_usage == {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} - assert node_result.accumulated_metrics == {"latencyMs": 0.0} - assert node_result.execution_count == 0 - - # With custom metrics - custom_usage = {"inputTokens": 100, "outputTokens": 200, "totalTokens": 300} - custom_metrics = {"latencyMs": 250.0} - node_result_custom = NodeResult( - result=agent_result, - execution_time=75, - status="completed", - accumulated_usage=custom_usage, - accumulated_metrics=custom_metrics, - execution_count=5, - ) - assert node_result_custom.accumulated_usage == custom_usage - assert node_result_custom.accumulated_metrics == custom_metrics - assert node_result_custom.execution_count == 5 - - # Test default factory creates independent instances - node_result1 = NodeResult(result=agent_result) - node_result2 = NodeResult(result=agent_result) - node_result1.accumulated_usage["inputTokens"] = 100 - assert node_result2.accumulated_usage["inputTokens"] == 0 - assert node_result1.accumulated_usage is not node_result2.accumulated_usage - - -def test_node_result_get_agent_results(agent_result): - """Test get_agent_results method with different structures.""" - # Simple case with single AgentResult - node_result = NodeResult(result=agent_result) - agent_results = node_result.get_agent_results() - assert len(agent_results) == 1 - assert agent_results[0] == agent_result - - # Test with Exception as result (should return empty list) - exception_result = NodeResult(result=Exception("Test exception"), status=Status.FAILED) - agent_results = exception_result.get_agent_results() - assert len(agent_results) == 0 - - # Complex nested case - inner_agent_result1 = AgentResult( - message={"role": "assistant", "content": [{"text": "Response 1"}]}, stop_reason="end_turn", state={}, metrics={} - ) - inner_agent_result2 = AgentResult( - message={"role": "assistant", "content": [{"text": "Response 2"}]}, stop_reason="end_turn", state={}, metrics={} - ) - - inner_node_result1 = NodeResult(result=inner_agent_result1) - inner_node_result2 = NodeResult(result=inner_agent_result2) - - multi_agent_result = MultiAgentResult(results={"node1": inner_node_result1, "node2": inner_node_result2}) +"""Tests for MultiAgentBase module.""" - outer_node_result = NodeResult(result=multi_agent_result) - agent_results = outer_node_result.get_agent_results() - - assert len(agent_results) == 2 - response_texts = [result.message["content"][0]["text"] for result in agent_results] - assert "Response 1" in response_texts - assert "Response 2" in response_texts - - -def test_multi_agent_result_initialization(agent_result): - """Test MultiAgentResult initialization with defaults and custom values.""" - # Default initialization - result = MultiAgentResult(results={}) - assert result.results == {} - assert result.accumulated_usage == {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0} - assert result.accumulated_metrics == {"latencyMs": 0.0} - assert result.execution_count == 0 - assert result.execution_time == 0 - - # Custom values`` - node_result = NodeResult(result=agent_result) - results = {"test_node": node_result} - usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150} - metrics = {"latencyMs": 200.0} - - result = MultiAgentResult( - results=results, accumulated_usage=usage, accumulated_metrics=metrics, execution_count=3, execution_time=300 - ) - - assert result.results == results - assert result.accumulated_usage == usage - assert result.accumulated_metrics == metrics - assert result.execution_count == 3 - assert result.execution_time == 300 - - # Test default factory creates independent instances - result1 = MultiAgentResult(results={}) - result2 = MultiAgentResult(results={}) - result1.accumulated_usage["inputTokens"] = 200 - result1.accumulated_metrics["latencyMs"] = 500.0 - assert result2.accumulated_usage["inputTokens"] == 0 - assert result2.accumulated_metrics["latencyMs"] == 0.0 - assert result1.accumulated_usage is not result2.accumulated_usage - assert result1.accumulated_metrics is not result2.accumulated_metrics - - -def test_multi_agent_base_abstract_behavior(): - """Test abstract class behavior of MultiAgentBase.""" - # Test that MultiAgentBase cannot be instantiated directly - with pytest.raises(TypeError): - MultiAgentBase() - - # Test that incomplete implementations raise TypeError - class IncompleteMultiAgent(MultiAgentBase): - pass - - with pytest.raises(TypeError): - IncompleteMultiAgent() - - # Test that complete implementations can be instantiated - class CompleteMultiAgent(MultiAgentBase): - async def invoke_async(self, task: str) -> MultiAgentResult: - return MultiAgentResult(results={}) - - def __call__(self, task: str) -> MultiAgentResult: - return MultiAgentResult(results={}) +import pytest - # Should not raise an exception - agent = CompleteMultiAgent() - assert isinstance(agent, MultiAgentBase) +from strands.multiagent.base import SharedContext + + +def test_shared_context_initialization(): + """Test SharedContext initialization.""" + context = SharedContext() + assert context.context == {} + + # Test with initial context + initial_context = {"node1": {"key1": "value1"}} + context = SharedContext(initial_context) + assert context.context == initial_context + + +def test_shared_context_add_context(): + """Test adding context to SharedContext.""" + context = SharedContext() + + # Create mock nodes + node1 = type('MockNode', (), {'node_id': 'node1'})() + node2 = type('MockNode', (), {'node_id': 'node2'})() + + # Add context for a node + context.add_context(node1, "key1", "value1") + assert context.context["node1"]["key1"] == "value1" + + # Add more context for the same node + context.add_context(node1, "key2", "value2") + assert context.context["node1"]["key1"] == "value1" + assert context.context["node1"]["key2"] == "value2" + + # Add context for a different node + context.add_context(node2, "key1", "value3") + assert context.context["node2"]["key1"] == "value3" + assert "node2" not in context.context["node1"] + + +def test_shared_context_get_context(): + """Test getting context from SharedContext.""" + context = SharedContext() + + # Create mock nodes + node1 = type('MockNode', (), {'node_id': 'node1'})() + node2 = type('MockNode', (), {'node_id': 'node2'})() + non_existent_node = type('MockNode', (), {'node_id': 'non_existent_node'})() + + # Add some test data + context.add_context(node1, "key1", "value1") + context.add_context(node1, "key2", "value2") + context.add_context(node2, "key1", "value3") + + # Get specific key + assert context.get_context(node1, "key1") == "value1" + assert context.get_context(node1, "key2") == "value2" + assert context.get_context(node2, "key1") == "value3" + + # Get all context for a node + node1_context = context.get_context(node1) + assert node1_context == {"key1": "value1", "key2": "value2"} + + # Get context for non-existent node + assert context.get_context(non_existent_node) == {} + assert context.get_context(non_existent_node, "key") is None + + +def test_shared_context_validation(): + """Test SharedContext input validation.""" + context = SharedContext() + + # Create mock node + node1 = type('MockNode', (), {'node_id': 'node1'})() + + # Test invalid key validation + with pytest.raises(ValueError, match="Key cannot be None"): + context.add_context(node1, None, "value") + + with pytest.raises(ValueError, match="Key must be a string"): + context.add_context(node1, 123, "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + context.add_context(node1, "", "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + context.add_context(node1, " ", "value") + + # Test JSON serialization validation + with pytest.raises(ValueError, match="Value is not JSON serializable"): + context.add_context(node1, "key", lambda x: x) # Function not serializable + + # Test valid values + context.add_context(node1, "string", "hello") + context.add_context(node1, "number", 42) + context.add_context(node1, "boolean", True) + context.add_context(node1, "list", [1, 2, 3]) + context.add_context(node1, "dict", {"nested": "value"}) + context.add_context(node1, "none", None) + + +def test_shared_context_isolation(): + """Test that SharedContext provides proper isolation between nodes.""" + context = SharedContext() + + # Create mock nodes + node1 = type('MockNode', (), {'node_id': 'node1'})() + node2 = type('MockNode', (), {'node_id': 'node2'})() + + # Add context for different nodes + context.add_context(node1, "key1", "value1") + context.add_context(node2, "key1", "value2") + + # Ensure nodes don't interfere with each other + assert context.get_context(node1, "key1") == "value1" + assert context.get_context(node2, "key1") == "value2" + + # Getting all context for a node should only return that node's context + assert context.get_context(node1) == {"key1": "value1"} + assert context.get_context(node2) == {"key1": "value2"} + + +def test_shared_context_copy_semantics(): + """Test that SharedContext.get_context returns copies to prevent mutation.""" + context = SharedContext() + + # Create mock node + node1 = type('MockNode', (), {'node_id': 'node1'})() + + # Add a mutable value + context.add_context(node1, "mutable", [1, 2, 3]) + + # Get the context and modify it + retrieved_context = context.get_context(node1) + retrieved_context["mutable"].append(4) + + # The original should remain unchanged + assert context.get_context(node1, "mutable") == [1, 2, 3] + + # Test that getting all context returns a copy + all_context = context.get_context(node1) + all_context["new_key"] = "new_value" + + # The original should remain unchanged + assert "new_key" not in context.get_context(node1) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index c60361da8..5d4ad9334 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -797,19 +797,96 @@ def test_condition(state): # Test GraphEdge hashing node_x = GraphNode("x", mock_agent_a) node_y = GraphNode("y", mock_agent_b) - edge1 = GraphEdge(node_x, node_y) - edge2 = GraphEdge(node_x, node_y) - edge3 = GraphEdge(node_y, node_x) - assert hash(edge1) == hash(edge2) - assert hash(edge1) != hash(edge3) - - # Test GraphNode initialization - mock_agent = create_mock_agent("test_agent") - node = GraphNode("test_node", mock_agent) - assert node.node_id == "test_node" - assert node.executor == mock_agent - assert node.execution_status == Status.PENDING - assert len(node.dependencies) == 0 + edge_x_y = GraphEdge(node_x, node_y) + edge_y_x = GraphEdge(node_y, node_x) + + # Different edges should have different hashes + assert hash(edge_x_y) != hash(edge_y_x) + + # Same edge should have same hash + edge_x_y_duplicate = GraphEdge(node_x, node_y) + assert hash(edge_x_y) == hash(edge_x_y_duplicate) + + +def test_graph_shared_context(): + """Test that Graph exposes shared context for user-defined state.""" + # Create a simple graph + mock_agent_a = create_mock_agent("agent_a") + mock_agent_b = create_mock_agent("agent_b") + + builder = GraphBuilder() + builder.add_node(mock_agent_a, "node_a") + builder.add_node(mock_agent_b, "node_b") + builder.add_edge("node_a", "node_b") + builder.set_entry_point("node_a") + + graph = builder.build() + + # Test that shared_context is accessible + assert hasattr(graph, "shared_context") + assert graph.shared_context is not None + + # Get node objects + node_a = graph.nodes["node_a"] + node_b = graph.nodes["node_b"] + + # Test adding context + graph.shared_context.add_context(node_a, "file_reference", "/path/to/file") + graph.shared_context.add_context(node_a, "data", {"key": "value"}) + + # Test getting context + assert graph.shared_context.get_context(node_a, "file_reference") == "/path/to/file" + assert graph.shared_context.get_context(node_a, "data") == {"key": "value"} + assert graph.shared_context.get_context(node_a) == {"file_reference": "/path/to/file", "data": {"key": "value"}} + + # Test getting context for non-existent node + non_existent_node = type('MockNode', (), {'node_id': 'non_existent_node'})() + assert graph.shared_context.get_context(non_existent_node) == {} + assert graph.shared_context.get_context(non_existent_node, "key") is None + + # Test that context is shared across nodes + graph.shared_context.add_context(node_b, "shared_data", "accessible_to_all") + assert graph.shared_context.get_context(node_a, "shared_data") is None # Different node + assert graph.shared_context.get_context(node_b, "shared_data") == "accessible_to_all" + + +def test_graph_shared_context_validation(): + """Test that Graph shared context validates input properly.""" + mock_agent = create_mock_agent("agent") + + builder = GraphBuilder() + builder.add_node(mock_agent, "node") + builder.set_entry_point("node") + + graph = builder.build() + + # Get node object + node = graph.nodes["node"] + + # Test invalid key validation + with pytest.raises(ValueError, match="Key cannot be None"): + graph.shared_context.add_context(node, None, "value") + + with pytest.raises(ValueError, match="Key must be a string"): + graph.shared_context.add_context(node, 123, "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + graph.shared_context.add_context(node, "", "value") + + with pytest.raises(ValueError, match="Key cannot be empty"): + graph.shared_context.add_context(node, " ", "value") + + # Test JSON serialization validation + with pytest.raises(ValueError, match="Value is not JSON serializable"): + graph.shared_context.add_context(node, "key", lambda x: x) # Function not serializable + + # Test valid values + graph.shared_context.add_context(node, "string", "hello") + graph.shared_context.add_context(node, "number", 42) + graph.shared_context.add_context(node, "boolean", True) + graph.shared_context.add_context(node, "list", [1, 2, 3]) + graph.shared_context.add_context(node, "dict", {"nested": "value"}) + graph.shared_context.add_context(node, "none", None) def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_agents):