generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 425
feat(multiagent): add FunctionNode to improve DX for deterministic cases #991
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dbschmigelski
wants to merge
3
commits into
strands-agents:main
Choose a base branch
from
dbschmigelski:function-node
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+348
−0
Open
Changes from 2 commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
"""FunctionNode implementation for executing deterministic Python functions as graph nodes. | ||
|
||
This module provides the FunctionNode class that extends MultiAgentBase to execute | ||
regular Python functions while maintaining compatibility with the existing graph | ||
execution framework, proper error handling, metrics collection, and result formatting. | ||
""" | ||
|
||
import logging | ||
import time | ||
from typing import Any, Callable, Union | ||
|
||
from opentelemetry import trace as trace_api | ||
|
||
from ..agent import AgentResult | ||
from ..telemetry import get_tracer | ||
from ..telemetry.metrics import EventLoopMetrics | ||
from ..types.content import ContentBlock, Message | ||
from ..types.event_loop import Metrics, Usage | ||
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
FunctionNodeCallable = Callable[[Union[str, list[ContentBlock]], dict[str, Any] | None], str] | ||
dbschmigelski marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
dbschmigelski marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
|
||
class FunctionNode(MultiAgentBase): | ||
"""Execute deterministic Python functions as graph nodes. | ||
|
||
FunctionNode wraps any callable Python function and executes it within the | ||
established multiagent framework, handling input conversion, error management, | ||
metrics collection, and result formatting automatically. | ||
|
||
Args: | ||
func: The callable function to wrap and execute | ||
name: Required name for the node | ||
""" | ||
|
||
def __init__(self, func: FunctionNodeCallable, name: str): | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Initialize FunctionNode with a callable function and required name. | ||
|
||
Args: | ||
func: The callable function to wrap and execute | ||
name: Required name for the node | ||
""" | ||
self.func = func | ||
self.name = name | ||
self.tracer = get_tracer() | ||
|
||
async def invoke_async( | ||
self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> MultiAgentResult: | ||
"""Execute the wrapped function and return formatted results. | ||
|
||
Args: | ||
task: The task input (string or ContentBlock list) to pass to the function | ||
invocation_state: Additional state/context (preserved for interface compatibility) | ||
**kwargs: Additional keyword arguments (preserved for future extensibility) | ||
|
||
Returns: | ||
MultiAgentResult containing the function execution results and metadata | ||
""" | ||
if invocation_state is None: | ||
invocation_state = {} | ||
|
||
logger.debug("task=<%s> | starting function node execution", task) | ||
logger.debug("function_name=<%s> | executing function", self.name) | ||
|
||
span = self.tracer.start_multiagent_span(task, "function_node") | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with trace_api.use_span(span, end_on_exit=True): | ||
try: | ||
start_time = time.time() | ||
# Execute the wrapped function with proper parameters | ||
function_result = self.func(task, invocation_state, **kwargs) | ||
# Calculate execution time | ||
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds | ||
logger.debug( | ||
"function_result=<%s>, execution_time=<%dms> | function executed successfully", | ||
function_result, | ||
execution_time, | ||
) | ||
|
||
# Convert function result to Message | ||
message = Message(role="assistant", content=[ContentBlock(text=str(function_result))]) | ||
dbschmigelski marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
agent_result = AgentResult( | ||
stop_reason="end_turn", # "Normal completion of the response" - function executed successfully | ||
message=message, | ||
metrics=EventLoopMetrics(), | ||
state={}, | ||
) | ||
|
||
# Create NodeResult for this function execution | ||
node_result = NodeResult( | ||
result=agent_result, # type is AgentResult | ||
execution_time=execution_time, | ||
status=Status.COMPLETED, | ||
execution_count=1, | ||
) | ||
|
||
# Create MultiAgentResult with the NodeResult | ||
multi_agent_result = MultiAgentResult( | ||
status=Status.COMPLETED, | ||
results={self.name: node_result}, | ||
execution_count=1, | ||
execution_time=execution_time, | ||
) | ||
|
||
logger.debug( | ||
"function_name=<%s>, execution_time=<%dms> | function node completed successfully", | ||
self.name, | ||
execution_time, | ||
) | ||
|
||
return multi_agent_result | ||
|
||
except Exception as e: | ||
# Calculate execution time even for failed executions | ||
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
logger.error("function_name=<%s>, error=<%s> | function node failed", self.name, e) | ||
|
||
# Create failed NodeResult with exception | ||
node_result = NodeResult( | ||
result=e, | ||
execution_time=execution_time, | ||
status=Status.FAILED, | ||
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), | ||
accumulated_metrics=Metrics(latencyMs=execution_time), | ||
execution_count=1, | ||
) | ||
|
||
# Create failed MultiAgentResult | ||
multi_agent_result = MultiAgentResult( | ||
status=Status.FAILED, | ||
results={self.name: node_result}, | ||
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), | ||
accumulated_metrics=Metrics(latencyMs=execution_time), | ||
execution_count=1, | ||
execution_time=execution_time, | ||
) | ||
|
||
return multi_agent_result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
"""Unit tests for FunctionNode implementation.""" | ||
|
||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
|
||
from strands.multiagent.base import Status | ||
from strands.multiagent.function_node import FunctionNode | ||
from strands.types.content import ContentBlock | ||
|
||
|
||
@pytest.fixture | ||
def mock_tracer(): | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Create a mock tracer for testing.""" | ||
tracer = Mock() | ||
span = Mock() | ||
span.__enter__ = Mock(return_value=span) | ||
span.__exit__ = Mock(return_value=None) | ||
tracer.start_multiagent_span.return_value = span | ||
return tracer | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invoke_async_string_input_success(mock_tracer): | ||
"""Test successful function execution with string input.""" | ||
|
||
def test_function(task, invocation_state=None, **kwargs): | ||
return f"Processed: {task}" | ||
|
||
node = FunctionNode(test_function, "string_test") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test input") | ||
|
||
assert result.status == Status.COMPLETED | ||
assert "string_test" in result.results | ||
assert result.results["string_test"].status == Status.COMPLETED | ||
assert result.accumulated_usage["inputTokens"] == 0 | ||
assert result.accumulated_usage["outputTokens"] == 0 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invoke_async_content_block_input_success(mock_tracer): | ||
"""Test successful function execution with ContentBlock input.""" | ||
|
||
def test_function(task, invocation_state=None, **kwargs): | ||
return "ContentBlock processed" | ||
|
||
node = FunctionNode(test_function, "content_block_test") | ||
content_blocks = [ContentBlock(text="First block"), ContentBlock(text="Second block")] | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async(content_blocks) | ||
|
||
assert result.status == Status.COMPLETED | ||
assert "content_block_test" in result.results | ||
node_result = result.results["content_block_test"] | ||
assert node_result.status == Status.COMPLETED | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invoke_async_with_kwargs(mock_tracer): | ||
"""Test function execution with additional kwargs.""" | ||
|
||
def test_function(task, invocation_state=None, **kwargs): | ||
extra_param = kwargs.get("extra_param", "none") | ||
return f"Extra: {extra_param}" | ||
|
||
node = FunctionNode(test_function, "kwargs_test") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test", None, extra_param="test_value") | ||
|
||
assert result.status == Status.COMPLETED | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invoke_async_function_exception(mock_tracer): | ||
"""Test proper exception handling when function raises an error.""" | ||
|
||
def failing_function(task, invocation_state=None, **kwargs): | ||
raise ValueError("Test exception") | ||
|
||
node = FunctionNode(failing_function, "exception_test") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test input") | ||
|
||
assert result.status == Status.FAILED | ||
assert "exception_test" in result.results | ||
node_result = result.results["exception_test"] | ||
assert node_result.status == Status.FAILED | ||
assert isinstance(node_result.result, ValueError) | ||
assert str(node_result.result) == "Test exception" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""Integration tests for FunctionNode with multiagent systems.""" | ||
|
||
import pytest | ||
|
||
from strands import Agent | ||
from strands.multiagent.base import Status | ||
from strands.multiagent.function_node import FunctionNode | ||
from strands.multiagent.graph import GraphBuilder | ||
|
||
# Global variable to test function execution | ||
test_global_var = None | ||
|
||
|
||
def set_global_var(task, invocation_state=None, **kwargs): | ||
"""Simple function that sets a global variable.""" | ||
global test_global_var | ||
test_global_var = f"Function executed with: {task}" | ||
return "Global variable set" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_agent_with_function_node(): | ||
"""Test graph with agent and function node.""" | ||
global test_global_var | ||
test_global_var = None | ||
|
||
# Create nodes | ||
agent = Agent() | ||
function_node = FunctionNode(set_global_var, "setter") | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Build graph | ||
builder = GraphBuilder() | ||
builder.add_node(agent, "agent") | ||
builder.add_node(function_node, "setter") | ||
builder.add_edge("agent", "setter") | ||
builder.set_entry_point("agent") | ||
graph = builder.build() | ||
|
||
# Execute | ||
result = await graph.invoke_async("Say hello") | ||
|
||
# Verify function was called | ||
assert "Function executed with:" in test_global_var | ||
assert result.status == Status.COMPLETED |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.