diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py index e251e9318..7d93f0cec 100644 --- a/src/strands/multiagent/__init__.py +++ b/src/strands/multiagent/__init__.py @@ -9,10 +9,12 @@ """ from .base import MultiAgentBase, MultiAgentResult +from .function_node import FunctionNode from .graph import GraphBuilder, GraphResult from .swarm import Swarm, SwarmResult __all__ = [ + "FunctionNode", "GraphBuilder", "GraphResult", "MultiAgentBase", diff --git a/src/strands/multiagent/function_node.py b/src/strands/multiagent/function_node.py new file mode 100644 index 000000000..039ad09de --- /dev/null +++ b/src/strands/multiagent/function_node.py @@ -0,0 +1,157 @@ +"""FunctionNode implementation for executing deterministic Python functions as graph nodes. + +This module provides the FunctionNode class that extends MultiAgentBase to execute +regular Python functions while maintaining compatibility with the existing graph +execution framework, proper error handling, metrics collection, and result formatting. +""" + +import logging +import time +from typing import Any, Protocol, Union + +from opentelemetry import trace as trace_api + +from ..agent import AgentResult +from ..telemetry import get_tracer +from ..telemetry.metrics import EventLoopMetrics +from ..types.content import ContentBlock, Message +from ..types.event_loop import Metrics, Usage +from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status + +logger = logging.getLogger(__name__) + + +class FunctionNodeCallable(Protocol): + """Protocol for functions that can be executed within FunctionNode.""" + + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> str | list[ContentBlock] | Message: + """Execute deterministic logic within the multiagent system.""" + ... + + +class FunctionNode(MultiAgentBase): + """Execute deterministic Python functions as graph nodes. + + FunctionNode wraps any callable Python function and executes it within the + established multiagent framework, handling input conversion, error management, + metrics collection, and result formatting automatically. + + Args: + func: The callable function to wrap and execute + name: Required name for the node + """ + + def __init__(self, func: FunctionNodeCallable, name: str): + """Initialize FunctionNode with a callable function and required name. + + Args: + func: The callable function to wrap and execute + name: Required name for the node + """ + self.func = func + self.name = name + self.tracer = get_tracer() + + async def invoke_async( + self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + """Execute the wrapped function and return formatted results. + + Args: + task: The task input (string or ContentBlock list) to pass to the function + invocation_state: Additional state/context (preserved for interface compatibility) + **kwargs: Additional keyword arguments (preserved for future extensibility) + + Returns: + MultiAgentResult containing the function execution results and metadata + """ + if invocation_state is None: + invocation_state = {} + + logger.debug("task=<%s> | starting function node execution", task) + logger.debug("function_name=<%s> | executing function", self.name) + + span = self.tracer.start_multiagent_span(task, "function_node") + with trace_api.use_span(span, end_on_exit=True): + try: + start_time = time.time() + # Execute the wrapped function with proper parameters + function_result = self.func(task, invocation_state, **kwargs) + # Calculate execution time + execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds + logger.debug( + "function_result=<%s>, execution_time=<%dms> | function executed successfully", + function_result, + execution_time, + ) + + # Convert function result to Message based on type + if isinstance(function_result, dict) and "role" in function_result and "content" in function_result: + # Already a Message + message = function_result + elif isinstance(function_result, list): + # List of ContentBlocks + message = Message(role="assistant", content=function_result) + else: + # String or other type - convert to string + message = Message(role="assistant", content=[ContentBlock(text=str(function_result))]) + agent_result = AgentResult( + stop_reason="end_turn", # "Normal completion of the response" - function executed successfully + message=message, + metrics=EventLoopMetrics(), + state={}, + ) + + # Create NodeResult for this function execution + node_result = NodeResult( + result=agent_result, # type is AgentResult + execution_time=execution_time, + status=Status.COMPLETED, + execution_count=1, + ) + + # Create MultiAgentResult with the NodeResult + multi_agent_result = MultiAgentResult( + status=Status.COMPLETED, + results={self.name: node_result}, + execution_count=1, + execution_time=execution_time, + ) + + logger.debug( + "function_name=<%s>, execution_time=<%dms> | function node completed successfully", + self.name, + execution_time, + ) + + return multi_agent_result + + except Exception as e: + # Calculate execution time even for failed executions + execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds + + logger.error("function_name=<%s>, error=<%s> | function node failed", self.name, e) + + # Create failed NodeResult with exception + node_result = NodeResult( + result=e, + execution_time=execution_time, + status=Status.FAILED, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=execution_time), + execution_count=1, + ) + + # Create failed MultiAgentResult + multi_agent_result = MultiAgentResult( + status=Status.FAILED, + results={self.name: node_result}, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=execution_time), + execution_count=1, + execution_time=execution_time, + ) + + return multi_agent_result diff --git a/tests/strands/multiagent/test_function_node.py b/tests/strands/multiagent/test_function_node.py new file mode 100644 index 000000000..b46ad8818 --- /dev/null +++ b/tests/strands/multiagent/test_function_node.py @@ -0,0 +1,145 @@ +"""Unit tests for FunctionNode implementation.""" + +from unittest.mock import Mock, patch + +import pytest + +from strands.multiagent.base import Status +from strands.multiagent.function_node import FunctionNode +from strands.types.content import ContentBlock, Message + + +@pytest.fixture +def mock_tracer(): + """Create a mock tracer for testing.""" + tracer = Mock() + span = Mock() + span.__enter__ = Mock(return_value=span) + span.__exit__ = Mock(return_value=None) + tracer.start_multiagent_span.return_value = span + return tracer + + +@pytest.mark.asyncio +async def test_invoke_async_string_input_success(mock_tracer): + """Test successful function execution with string input.""" + + def test_function(task, invocation_state=None, **kwargs): + return f"Processed: {task}" + + node = FunctionNode(test_function, "string_test") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test input") + + assert result.status == Status.COMPLETED + assert "string_test" in result.results + assert result.results["string_test"].status == Status.COMPLETED + assert result.accumulated_usage["inputTokens"] == 0 + assert result.accumulated_usage["outputTokens"] == 0 + + +@pytest.mark.asyncio +async def test_invoke_async_content_block_input_success(mock_tracer): + """Test successful function execution with ContentBlock input.""" + + def test_function(task, invocation_state=None, **kwargs): + return "ContentBlock processed" + + node = FunctionNode(test_function, "content_block_test") + content_blocks = [ContentBlock(text="First block"), ContentBlock(text="Second block")] + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async(content_blocks) + + assert result.status == Status.COMPLETED + assert "content_block_test" in result.results + node_result = result.results["content_block_test"] + assert node_result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_invoke_async_with_kwargs(mock_tracer): + """Test function execution with additional kwargs.""" + + def test_function(task, invocation_state=None, **kwargs): + extra_param = kwargs.get("extra_param", "none") + return f"Extra: {extra_param}" + + node = FunctionNode(test_function, "kwargs_test") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test", None, extra_param="test_value") + + assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_invoke_async_function_exception(mock_tracer): + """Test proper exception handling when function raises an error.""" + + def failing_function(task, invocation_state=None, **kwargs): + raise ValueError("Test exception") + + node = FunctionNode(failing_function, "exception_test") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test input") + + assert result.status == Status.FAILED + assert "exception_test" in result.results + node_result = result.results["exception_test"] + assert node_result.status == Status.FAILED + assert isinstance(node_result.result, ValueError) + assert str(node_result.result) == "Test exception" + + +@pytest.mark.asyncio +async def test_function_returns_string(mock_tracer): + """Test function returning string.""" + + def string_function(task, invocation_state=None, **kwargs): + return "Hello World" + + node = FunctionNode(string_function, "string_node") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test") + + agent_result = result.results["string_node"].result + assert agent_result.message["content"][0]["text"] == "Hello World" + + +@pytest.mark.asyncio +async def test_function_returns_content_blocks(mock_tracer): + """Test function returning list of ContentBlocks.""" + + def content_block_function(task, invocation_state=None, **kwargs): + return [ContentBlock(text="Block 1"), ContentBlock(text="Block 2")] + + node = FunctionNode(content_block_function, "content_node") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test") + + agent_result = result.results["content_node"].result + assert len(agent_result.message["content"]) == 2 + assert agent_result.message["content"][0]["text"] == "Block 1" + assert agent_result.message["content"][1]["text"] == "Block 2" + + +@pytest.mark.asyncio +async def test_function_returns_message(mock_tracer): + """Test function returning Message.""" + + def message_function(task, invocation_state=None, **kwargs): + return Message(role="user", content=[ContentBlock(text="Custom message")]) + + node = FunctionNode(message_function, "message_node") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test") + + agent_result = result.results["message_node"].result + assert agent_result.message["role"] == "user" + assert agent_result.message["content"][0]["text"] == "Custom message" diff --git a/tests_integ/test_multiagent_function_node.py b/tests_integ/test_multiagent_function_node.py new file mode 100644 index 000000000..11be9e3fa --- /dev/null +++ b/tests_integ/test_multiagent_function_node.py @@ -0,0 +1,44 @@ +"""Integration tests for FunctionNode with multiagent systems.""" + +import pytest + +from strands import Agent +from strands.multiagent.base import Status +from strands.multiagent.function_node import FunctionNode +from strands.multiagent.graph import GraphBuilder + +# Global variable to test function execution +test_global_var = None + + +def set_global_var(task, invocation_state=None, **kwargs): + """Simple function that sets a global variable.""" + global test_global_var + test_global_var = f"Function executed with: {task}" + return "Global variable set" + + +@pytest.mark.asyncio +async def test_agent_with_function_node(): + """Test graph with agent and function node.""" + global test_global_var + test_global_var = None + + # Create nodes + agent = Agent() + function_node = FunctionNode(set_global_var, "setter") + + # Build graph + builder = GraphBuilder() + builder.add_node(agent, "agent") + builder.add_node(function_node, "setter") + builder.add_edge("agent", "setter") + builder.set_entry_point("agent") + graph = builder.build() + + # Execute + result = await graph.invoke_async("Say hello") + + # Verify function was called + assert "Function executed with:" in test_global_var + assert result.status == Status.COMPLETED