From 838c23665772267cdc00290abc7952fffbd1a08b Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 6 Oct 2025 18:34:11 -0400 Subject: [PATCH 1/3] feat(multiagent): add FunctionNode to improve DX for deterministic cases --- src/strands/multiagent/__init__.py | 2 + src/strands/multiagent/function_node.py | 155 ++++++++++++++++++ .../strands/multiagent/test_function_node.py | 94 +++++++++++ tests_integ/test_multiagent_function_node.py | 44 +++++ 4 files changed, 295 insertions(+) create mode 100644 src/strands/multiagent/function_node.py create mode 100644 tests/strands/multiagent/test_function_node.py create mode 100644 tests_integ/test_multiagent_function_node.py diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py index e251e9318..7d93f0cec 100644 --- a/src/strands/multiagent/__init__.py +++ b/src/strands/multiagent/__init__.py @@ -9,10 +9,12 @@ """ from .base import MultiAgentBase, MultiAgentResult +from .function_node import FunctionNode from .graph import GraphBuilder, GraphResult from .swarm import Swarm, SwarmResult __all__ = [ + "FunctionNode", "GraphBuilder", "GraphResult", "MultiAgentBase", diff --git a/src/strands/multiagent/function_node.py b/src/strands/multiagent/function_node.py new file mode 100644 index 000000000..56774283d --- /dev/null +++ b/src/strands/multiagent/function_node.py @@ -0,0 +1,155 @@ +"""FunctionNode implementation for executing deterministic Python functions as graph nodes. + +This module provides the FunctionNode class that extends MultiAgentBase to execute +regular Python functions while maintaining compatibility with the existing graph +execution framework, proper error handling, metrics collection, and result formatting. +""" + +import logging +import time +from typing import Any, Protocol, Union + +from opentelemetry import trace as trace_api + +from ..agent import AgentResult +from ..telemetry import get_tracer +from ..telemetry.metrics import EventLoopMetrics +from ..types.content import ContentBlock, Message +from ..types.event_loop import Metrics, Usage +from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status + +logger = logging.getLogger(__name__) + + +class FunctionNodeCallable(Protocol): + """Protocol defining the required signature for functions used in FunctionNode. + + Functions must accept: + - task: The input task (string or ContentBlock list) + - invocation_state: Additional state/context from the calling environment + - **kwargs: Additional keyword arguments for future extensibility + + Functions must return: + - A string result that will be converted to a Message + """ + + def __call__( + self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> str: + """Execute the node with the given task.""" + ... + + +class FunctionNode(MultiAgentBase): + """Execute deterministic Python functions as graph nodes. + + FunctionNode wraps any callable Python function and executes it within the + established multiagent framework, handling input conversion, error management, + metrics collection, and result formatting automatically. + + Args: + func: The callable function to wrap and execute + name: Required name for the node + """ + + def __init__(self, func: FunctionNodeCallable, name: str): + """Initialize FunctionNode with a callable function and required name. + + Args: + func: The callable function to wrap and execute + name: Required name for the node + """ + self.func = func + self.name = name + self.tracer = get_tracer() + + async def invoke_async( + self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> MultiAgentResult: + """Execute the wrapped function and return formatted results. + + Args: + task: The task input (string or ContentBlock list) to pass to the function + invocation_state: Additional state/context (preserved for interface compatibility) + **kwargs: Additional keyword arguments (preserved for future extensibility) + + Returns: + MultiAgentResult containing the function execution results and metadata + """ + if invocation_state is None: + invocation_state = {} + + logger.debug("task=<%s> | starting function node execution", task) + logger.debug("function_name=<%s> | executing function", self.name) + + start_time = time.time() + span = self.tracer.start_multiagent_span(task, "function_node") + with trace_api.use_span(span, end_on_exit=True): + try: + # Execute the wrapped function with proper parameters + function_result = self.func(task, invocation_state, **kwargs) + logger.debug("function_result=<%s> | function executed successfully", function_result) + + # Calculate execution time + execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds + + # Convert function result to Message + message = Message(role="assistant", content=[ContentBlock(text=str(function_result))]) + agent_result = AgentResult( + stop_reason="end_turn", # "Normal completion of the response" - function executed successfully + message=message, + metrics=EventLoopMetrics(), + state={}, + ) + + # Create NodeResult for this function execution + node_result = NodeResult( + result=agent_result, # type is AgentResult + execution_time=execution_time, + status=Status.COMPLETED, + execution_count=1, + ) + + # Create MultiAgentResult with the NodeResult + multi_agent_result = MultiAgentResult( + status=Status.COMPLETED, + results={self.name: node_result}, + execution_count=1, + execution_time=execution_time, + ) + + logger.debug( + "function_name=<%s>, execution_time=<%dms> | function node completed successfully", + self.name, + execution_time, + ) + + return multi_agent_result + + except Exception as e: + # Calculate execution time even for failed executions + execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds + + logger.error("function_name=<%s>, error=<%s> | function node failed", self.name, e) + + # Create failed NodeResult with exception + node_result = NodeResult( + result=e, + execution_time=execution_time, + status=Status.FAILED, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=execution_time), + execution_count=1, + ) + + # Create failed MultiAgentResult + multi_agent_result = MultiAgentResult( + status=Status.FAILED, + results={self.name: node_result}, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=execution_time), + execution_count=1, + execution_time=execution_time, + ) + + return multi_agent_result diff --git a/tests/strands/multiagent/test_function_node.py b/tests/strands/multiagent/test_function_node.py new file mode 100644 index 000000000..aee48827e --- /dev/null +++ b/tests/strands/multiagent/test_function_node.py @@ -0,0 +1,94 @@ +"""Unit tests for FunctionNode implementation.""" + +from unittest.mock import Mock, patch + +import pytest + +from strands.multiagent.base import Status +from strands.multiagent.function_node import FunctionNode +from strands.types.content import ContentBlock + + +@pytest.fixture +def mock_tracer(): + """Create a mock tracer for testing.""" + tracer = Mock() + span = Mock() + span.__enter__ = Mock(return_value=span) + span.__exit__ = Mock(return_value=None) + tracer.start_multiagent_span.return_value = span + return tracer + + +@pytest.mark.asyncio +async def test_invoke_async_string_input_success(mock_tracer): + """Test successful function execution with string input.""" + + def test_function(task, invocation_state=None, **kwargs): + return f"Processed: {task}" + + node = FunctionNode(test_function, "string_test") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test input") + + assert result.status == Status.COMPLETED + assert "string_test" in result.results + assert result.results["string_test"].status == Status.COMPLETED + assert result.accumulated_usage["inputTokens"] == 0 + assert result.accumulated_usage["outputTokens"] == 0 + + +@pytest.mark.asyncio +async def test_invoke_async_content_block_input_success(mock_tracer): + """Test successful function execution with ContentBlock input.""" + + def test_function(task, invocation_state=None, **kwargs): + return "ContentBlock processed" + + node = FunctionNode(test_function, "content_block_test") + content_blocks = [ContentBlock(text="First block"), ContentBlock(text="Second block")] + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async(content_blocks) + + assert result.status == Status.COMPLETED + assert "content_block_test" in result.results + node_result = result.results["content_block_test"] + assert node_result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_invoke_async_with_kwargs(mock_tracer): + """Test function execution with additional kwargs.""" + + def test_function(task, invocation_state=None, **kwargs): + extra_param = kwargs.get("extra_param", "none") + return f"Extra: {extra_param}" + + node = FunctionNode(test_function, "kwargs_test") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test", None, extra_param="test_value") + + assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_invoke_async_function_exception(mock_tracer): + """Test proper exception handling when function raises an error.""" + + def failing_function(task, invocation_state=None, **kwargs): + raise ValueError("Test exception") + + node = FunctionNode(failing_function, "exception_test") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test input") + + assert result.status == Status.FAILED + assert "exception_test" in result.results + node_result = result.results["exception_test"] + assert node_result.status == Status.FAILED + assert isinstance(node_result.result, ValueError) + assert str(node_result.result) == "Test exception" diff --git a/tests_integ/test_multiagent_function_node.py b/tests_integ/test_multiagent_function_node.py new file mode 100644 index 000000000..11be9e3fa --- /dev/null +++ b/tests_integ/test_multiagent_function_node.py @@ -0,0 +1,44 @@ +"""Integration tests for FunctionNode with multiagent systems.""" + +import pytest + +from strands import Agent +from strands.multiagent.base import Status +from strands.multiagent.function_node import FunctionNode +from strands.multiagent.graph import GraphBuilder + +# Global variable to test function execution +test_global_var = None + + +def set_global_var(task, invocation_state=None, **kwargs): + """Simple function that sets a global variable.""" + global test_global_var + test_global_var = f"Function executed with: {task}" + return "Global variable set" + + +@pytest.mark.asyncio +async def test_agent_with_function_node(): + """Test graph with agent and function node.""" + global test_global_var + test_global_var = None + + # Create nodes + agent = Agent() + function_node = FunctionNode(set_global_var, "setter") + + # Build graph + builder = GraphBuilder() + builder.add_node(agent, "agent") + builder.add_node(function_node, "setter") + builder.add_edge("agent", "setter") + builder.set_entry_point("agent") + graph = builder.build() + + # Execute + result = await graph.invoke_async("Say hello") + + # Verify function was called + assert "Function executed with:" in test_global_var + assert result.status == Status.COMPLETED From 709ef4b72ab865d04c4853742fe808fb3bf01f4c Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 7 Oct 2025 17:49:02 -0400 Subject: [PATCH 2/3] comments --- src/strands/multiagent/function_node.py | 29 +++++++------------------ 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/src/strands/multiagent/function_node.py b/src/strands/multiagent/function_node.py index 56774283d..d3d71685a 100644 --- a/src/strands/multiagent/function_node.py +++ b/src/strands/multiagent/function_node.py @@ -7,7 +7,7 @@ import logging import time -from typing import Any, Protocol, Union +from typing import Any, Callable, Union from opentelemetry import trace as trace_api @@ -21,23 +21,7 @@ logger = logging.getLogger(__name__) -class FunctionNodeCallable(Protocol): - """Protocol defining the required signature for functions used in FunctionNode. - - Functions must accept: - - task: The input task (string or ContentBlock list) - - invocation_state: Additional state/context from the calling environment - - **kwargs: Additional keyword arguments for future extensibility - - Functions must return: - - A string result that will be converted to a Message - """ - - def __call__( - self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any - ) -> str: - """Execute the node with the given task.""" - ... +FunctionNodeCallable = Callable[[Union[str, list[ContentBlock]], dict[str, Any] | None], str] class FunctionNode(MultiAgentBase): @@ -82,16 +66,19 @@ async def invoke_async( logger.debug("task=<%s> | starting function node execution", task) logger.debug("function_name=<%s> | executing function", self.name) - start_time = time.time() span = self.tracer.start_multiagent_span(task, "function_node") with trace_api.use_span(span, end_on_exit=True): try: + start_time = time.time() # Execute the wrapped function with proper parameters function_result = self.func(task, invocation_state, **kwargs) - logger.debug("function_result=<%s> | function executed successfully", function_result) - # Calculate execution time execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds + logger.debug( + "function_result=<%s>, execution_time=<%dms> | function executed successfully", + function_result, + execution_time, + ) # Convert function result to Message message = Message(role="assistant", content=[ContentBlock(text=str(function_result))]) From 4b9b2f561cfa8509199c44fa7bea7ce563d964f8 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 10 Oct 2025 10:59:50 -0400 Subject: [PATCH 3/3] comments --- src/strands/multiagent/function_node.py | 23 ++++++-- .../strands/multiagent/test_function_node.py | 53 ++++++++++++++++++- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/src/strands/multiagent/function_node.py b/src/strands/multiagent/function_node.py index d3d71685a..039ad09de 100644 --- a/src/strands/multiagent/function_node.py +++ b/src/strands/multiagent/function_node.py @@ -7,7 +7,7 @@ import logging import time -from typing import Any, Callable, Union +from typing import Any, Protocol, Union from opentelemetry import trace as trace_api @@ -21,7 +21,14 @@ logger = logging.getLogger(__name__) -FunctionNodeCallable = Callable[[Union[str, list[ContentBlock]], dict[str, Any] | None], str] +class FunctionNodeCallable(Protocol): + """Protocol for functions that can be executed within FunctionNode.""" + + def __call__( + self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> str | list[ContentBlock] | Message: + """Execute deterministic logic within the multiagent system.""" + ... class FunctionNode(MultiAgentBase): @@ -80,8 +87,16 @@ async def invoke_async( execution_time, ) - # Convert function result to Message - message = Message(role="assistant", content=[ContentBlock(text=str(function_result))]) + # Convert function result to Message based on type + if isinstance(function_result, dict) and "role" in function_result and "content" in function_result: + # Already a Message + message = function_result + elif isinstance(function_result, list): + # List of ContentBlocks + message = Message(role="assistant", content=function_result) + else: + # String or other type - convert to string + message = Message(role="assistant", content=[ContentBlock(text=str(function_result))]) agent_result = AgentResult( stop_reason="end_turn", # "Normal completion of the response" - function executed successfully message=message, diff --git a/tests/strands/multiagent/test_function_node.py b/tests/strands/multiagent/test_function_node.py index aee48827e..b46ad8818 100644 --- a/tests/strands/multiagent/test_function_node.py +++ b/tests/strands/multiagent/test_function_node.py @@ -6,7 +6,7 @@ from strands.multiagent.base import Status from strands.multiagent.function_node import FunctionNode -from strands.types.content import ContentBlock +from strands.types.content import ContentBlock, Message @pytest.fixture @@ -92,3 +92,54 @@ def failing_function(task, invocation_state=None, **kwargs): assert node_result.status == Status.FAILED assert isinstance(node_result.result, ValueError) assert str(node_result.result) == "Test exception" + + +@pytest.mark.asyncio +async def test_function_returns_string(mock_tracer): + """Test function returning string.""" + + def string_function(task, invocation_state=None, **kwargs): + return "Hello World" + + node = FunctionNode(string_function, "string_node") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test") + + agent_result = result.results["string_node"].result + assert agent_result.message["content"][0]["text"] == "Hello World" + + +@pytest.mark.asyncio +async def test_function_returns_content_blocks(mock_tracer): + """Test function returning list of ContentBlocks.""" + + def content_block_function(task, invocation_state=None, **kwargs): + return [ContentBlock(text="Block 1"), ContentBlock(text="Block 2")] + + node = FunctionNode(content_block_function, "content_node") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test") + + agent_result = result.results["content_node"].result + assert len(agent_result.message["content"]) == 2 + assert agent_result.message["content"][0]["text"] == "Block 1" + assert agent_result.message["content"][1]["text"] == "Block 2" + + +@pytest.mark.asyncio +async def test_function_returns_message(mock_tracer): + """Test function returning Message.""" + + def message_function(task, invocation_state=None, **kwargs): + return Message(role="user", content=[ContentBlock(text="Custom message")]) + + node = FunctionNode(message_function, "message_node") + + with patch.object(node, "tracer", mock_tracer): + result = await node.invoke_async("test") + + agent_result = result.results["message_node"].result + assert agent_result.message["role"] == "user" + assert agent_result.message["content"][0]["text"] == "Custom message"