Skip to content

Commit 838c236

Browse files
committed
feat(multiagent): add FunctionNode to improve DX for deterministic cases
1 parent 776fd93 commit 838c236

File tree

4 files changed

+295
-0
lines changed

4 files changed

+295
-0
lines changed

src/strands/multiagent/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
"""
1010

1111
from .base import MultiAgentBase, MultiAgentResult
12+
from .function_node import FunctionNode
1213
from .graph import GraphBuilder, GraphResult
1314
from .swarm import Swarm, SwarmResult
1415

1516
__all__ = [
17+
"FunctionNode",
1618
"GraphBuilder",
1719
"GraphResult",
1820
"MultiAgentBase",
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""FunctionNode implementation for executing deterministic Python functions as graph nodes.
2+
3+
This module provides the FunctionNode class that extends MultiAgentBase to execute
4+
regular Python functions while maintaining compatibility with the existing graph
5+
execution framework, proper error handling, metrics collection, and result formatting.
6+
"""
7+
8+
import logging
9+
import time
10+
from typing import Any, Protocol, Union
11+
12+
from opentelemetry import trace as trace_api
13+
14+
from ..agent import AgentResult
15+
from ..telemetry import get_tracer
16+
from ..telemetry.metrics import EventLoopMetrics
17+
from ..types.content import ContentBlock, Message
18+
from ..types.event_loop import Metrics, Usage
19+
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class FunctionNodeCallable(Protocol):
25+
"""Protocol defining the required signature for functions used in FunctionNode.
26+
27+
Functions must accept:
28+
- task: The input task (string or ContentBlock list)
29+
- invocation_state: Additional state/context from the calling environment
30+
- **kwargs: Additional keyword arguments for future extensibility
31+
32+
Functions must return:
33+
- A string result that will be converted to a Message
34+
"""
35+
36+
def __call__(
37+
self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any
38+
) -> str:
39+
"""Execute the node with the given task."""
40+
...
41+
42+
43+
class FunctionNode(MultiAgentBase):
44+
"""Execute deterministic Python functions as graph nodes.
45+
46+
FunctionNode wraps any callable Python function and executes it within the
47+
established multiagent framework, handling input conversion, error management,
48+
metrics collection, and result formatting automatically.
49+
50+
Args:
51+
func: The callable function to wrap and execute
52+
name: Required name for the node
53+
"""
54+
55+
def __init__(self, func: FunctionNodeCallable, name: str):
56+
"""Initialize FunctionNode with a callable function and required name.
57+
58+
Args:
59+
func: The callable function to wrap and execute
60+
name: Required name for the node
61+
"""
62+
self.func = func
63+
self.name = name
64+
self.tracer = get_tracer()
65+
66+
async def invoke_async(
67+
self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any
68+
) -> MultiAgentResult:
69+
"""Execute the wrapped function and return formatted results.
70+
71+
Args:
72+
task: The task input (string or ContentBlock list) to pass to the function
73+
invocation_state: Additional state/context (preserved for interface compatibility)
74+
**kwargs: Additional keyword arguments (preserved for future extensibility)
75+
76+
Returns:
77+
MultiAgentResult containing the function execution results and metadata
78+
"""
79+
if invocation_state is None:
80+
invocation_state = {}
81+
82+
logger.debug("task=<%s> | starting function node execution", task)
83+
logger.debug("function_name=<%s> | executing function", self.name)
84+
85+
start_time = time.time()
86+
span = self.tracer.start_multiagent_span(task, "function_node")
87+
with trace_api.use_span(span, end_on_exit=True):
88+
try:
89+
# Execute the wrapped function with proper parameters
90+
function_result = self.func(task, invocation_state, **kwargs)
91+
logger.debug("function_result=<%s> | function executed successfully", function_result)
92+
93+
# Calculate execution time
94+
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds
95+
96+
# Convert function result to Message
97+
message = Message(role="assistant", content=[ContentBlock(text=str(function_result))])
98+
agent_result = AgentResult(
99+
stop_reason="end_turn", # "Normal completion of the response" - function executed successfully
100+
message=message,
101+
metrics=EventLoopMetrics(),
102+
state={},
103+
)
104+
105+
# Create NodeResult for this function execution
106+
node_result = NodeResult(
107+
result=agent_result, # type is AgentResult
108+
execution_time=execution_time,
109+
status=Status.COMPLETED,
110+
execution_count=1,
111+
)
112+
113+
# Create MultiAgentResult with the NodeResult
114+
multi_agent_result = MultiAgentResult(
115+
status=Status.COMPLETED,
116+
results={self.name: node_result},
117+
execution_count=1,
118+
execution_time=execution_time,
119+
)
120+
121+
logger.debug(
122+
"function_name=<%s>, execution_time=<%dms> | function node completed successfully",
123+
self.name,
124+
execution_time,
125+
)
126+
127+
return multi_agent_result
128+
129+
except Exception as e:
130+
# Calculate execution time even for failed executions
131+
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds
132+
133+
logger.error("function_name=<%s>, error=<%s> | function node failed", self.name, e)
134+
135+
# Create failed NodeResult with exception
136+
node_result = NodeResult(
137+
result=e,
138+
execution_time=execution_time,
139+
status=Status.FAILED,
140+
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
141+
accumulated_metrics=Metrics(latencyMs=execution_time),
142+
execution_count=1,
143+
)
144+
145+
# Create failed MultiAgentResult
146+
multi_agent_result = MultiAgentResult(
147+
status=Status.FAILED,
148+
results={self.name: node_result},
149+
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
150+
accumulated_metrics=Metrics(latencyMs=execution_time),
151+
execution_count=1,
152+
execution_time=execution_time,
153+
)
154+
155+
return multi_agent_result
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Unit tests for FunctionNode implementation."""
2+
3+
from unittest.mock import Mock, patch
4+
5+
import pytest
6+
7+
from strands.multiagent.base import Status
8+
from strands.multiagent.function_node import FunctionNode
9+
from strands.types.content import ContentBlock
10+
11+
12+
@pytest.fixture
13+
def mock_tracer():
14+
"""Create a mock tracer for testing."""
15+
tracer = Mock()
16+
span = Mock()
17+
span.__enter__ = Mock(return_value=span)
18+
span.__exit__ = Mock(return_value=None)
19+
tracer.start_multiagent_span.return_value = span
20+
return tracer
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_invoke_async_string_input_success(mock_tracer):
25+
"""Test successful function execution with string input."""
26+
27+
def test_function(task, invocation_state=None, **kwargs):
28+
return f"Processed: {task}"
29+
30+
node = FunctionNode(test_function, "string_test")
31+
32+
with patch.object(node, "tracer", mock_tracer):
33+
result = await node.invoke_async("test input")
34+
35+
assert result.status == Status.COMPLETED
36+
assert "string_test" in result.results
37+
assert result.results["string_test"].status == Status.COMPLETED
38+
assert result.accumulated_usage["inputTokens"] == 0
39+
assert result.accumulated_usage["outputTokens"] == 0
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_invoke_async_content_block_input_success(mock_tracer):
44+
"""Test successful function execution with ContentBlock input."""
45+
46+
def test_function(task, invocation_state=None, **kwargs):
47+
return "ContentBlock processed"
48+
49+
node = FunctionNode(test_function, "content_block_test")
50+
content_blocks = [ContentBlock(text="First block"), ContentBlock(text="Second block")]
51+
52+
with patch.object(node, "tracer", mock_tracer):
53+
result = await node.invoke_async(content_blocks)
54+
55+
assert result.status == Status.COMPLETED
56+
assert "content_block_test" in result.results
57+
node_result = result.results["content_block_test"]
58+
assert node_result.status == Status.COMPLETED
59+
60+
61+
@pytest.mark.asyncio
62+
async def test_invoke_async_with_kwargs(mock_tracer):
63+
"""Test function execution with additional kwargs."""
64+
65+
def test_function(task, invocation_state=None, **kwargs):
66+
extra_param = kwargs.get("extra_param", "none")
67+
return f"Extra: {extra_param}"
68+
69+
node = FunctionNode(test_function, "kwargs_test")
70+
71+
with patch.object(node, "tracer", mock_tracer):
72+
result = await node.invoke_async("test", None, extra_param="test_value")
73+
74+
assert result.status == Status.COMPLETED
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_invoke_async_function_exception(mock_tracer):
79+
"""Test proper exception handling when function raises an error."""
80+
81+
def failing_function(task, invocation_state=None, **kwargs):
82+
raise ValueError("Test exception")
83+
84+
node = FunctionNode(failing_function, "exception_test")
85+
86+
with patch.object(node, "tracer", mock_tracer):
87+
result = await node.invoke_async("test input")
88+
89+
assert result.status == Status.FAILED
90+
assert "exception_test" in result.results
91+
node_result = result.results["exception_test"]
92+
assert node_result.status == Status.FAILED
93+
assert isinstance(node_result.result, ValueError)
94+
assert str(node_result.result) == "Test exception"
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Integration tests for FunctionNode with multiagent systems."""
2+
3+
import pytest
4+
5+
from strands import Agent
6+
from strands.multiagent.base import Status
7+
from strands.multiagent.function_node import FunctionNode
8+
from strands.multiagent.graph import GraphBuilder
9+
10+
# Global variable to test function execution
11+
test_global_var = None
12+
13+
14+
def set_global_var(task, invocation_state=None, **kwargs):
15+
"""Simple function that sets a global variable."""
16+
global test_global_var
17+
test_global_var = f"Function executed with: {task}"
18+
return "Global variable set"
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_agent_with_function_node():
23+
"""Test graph with agent and function node."""
24+
global test_global_var
25+
test_global_var = None
26+
27+
# Create nodes
28+
agent = Agent()
29+
function_node = FunctionNode(set_global_var, "setter")
30+
31+
# Build graph
32+
builder = GraphBuilder()
33+
builder.add_node(agent, "agent")
34+
builder.add_node(function_node, "setter")
35+
builder.add_edge("agent", "setter")
36+
builder.set_entry_point("agent")
37+
graph = builder.build()
38+
39+
# Execute
40+
result = await graph.invoke_async("Say hello")
41+
42+
# Verify function was called
43+
assert "Function executed with:" in test_global_var
44+
assert result.status == Status.COMPLETED

0 commit comments

Comments
 (0)