Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/strands/multiagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
"""

from .base import MultiAgentBase, MultiAgentResult
from .function_node import FunctionNode
from .graph import GraphBuilder, GraphResult
from .swarm import Swarm, SwarmResult

__all__ = [
"FunctionNode",
"GraphBuilder",
"GraphResult",
"MultiAgentBase",
Expand Down
142 changes: 142 additions & 0 deletions src/strands/multiagent/function_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""FunctionNode implementation for executing deterministic Python functions as graph nodes.

This module provides the FunctionNode class that extends MultiAgentBase to execute
regular Python functions while maintaining compatibility with the existing graph
execution framework, proper error handling, metrics collection, and result formatting.
"""

import logging
import time
from typing import Any, Callable, Union

from opentelemetry import trace as trace_api

from ..agent import AgentResult
from ..telemetry import get_tracer
from ..telemetry.metrics import EventLoopMetrics
from ..types.content import ContentBlock, Message
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status

logger = logging.getLogger(__name__)


FunctionNodeCallable = Callable[[Union[str, list[ContentBlock]], dict[str, Any] | None], str]


class FunctionNode(MultiAgentBase):
"""Execute deterministic Python functions as graph nodes.

FunctionNode wraps any callable Python function and executes it within the
established multiagent framework, handling input conversion, error management,
metrics collection, and result formatting automatically.

Args:
func: The callable function to wrap and execute
name: Required name for the node
"""

def __init__(self, func: FunctionNodeCallable, name: str):
"""Initialize FunctionNode with a callable function and required name.

Args:
func: The callable function to wrap and execute
name: Required name for the node
"""
self.func = func
self.name = name
self.tracer = get_tracer()

async def invoke_async(
self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
"""Execute the wrapped function and return formatted results.

Args:
task: The task input (string or ContentBlock list) to pass to the function
invocation_state: Additional state/context (preserved for interface compatibility)
**kwargs: Additional keyword arguments (preserved for future extensibility)

Returns:
MultiAgentResult containing the function execution results and metadata
"""
if invocation_state is None:
invocation_state = {}

logger.debug("task=<%s> | starting function node execution", task)
logger.debug("function_name=<%s> | executing function", self.name)

span = self.tracer.start_multiagent_span(task, "function_node")
with trace_api.use_span(span, end_on_exit=True):
try:
start_time = time.time()
# Execute the wrapped function with proper parameters
function_result = self.func(task, invocation_state, **kwargs)
# Calculate execution time
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds
logger.debug(
"function_result=<%s>, execution_time=<%dms> | function executed successfully",
function_result,
execution_time,
)

# Convert function result to Message
message = Message(role="assistant", content=[ContentBlock(text=str(function_result))])
agent_result = AgentResult(
stop_reason="end_turn", # "Normal completion of the response" - function executed successfully
message=message,
metrics=EventLoopMetrics(),
state={},
)

# Create NodeResult for this function execution
node_result = NodeResult(
result=agent_result, # type is AgentResult
execution_time=execution_time,
status=Status.COMPLETED,
execution_count=1,
)

# Create MultiAgentResult with the NodeResult
multi_agent_result = MultiAgentResult(
status=Status.COMPLETED,
results={self.name: node_result},
execution_count=1,
execution_time=execution_time,
)

logger.debug(
"function_name=<%s>, execution_time=<%dms> | function node completed successfully",
self.name,
execution_time,
)

return multi_agent_result

except Exception as e:
# Calculate execution time even for failed executions
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds

logger.error("function_name=<%s>, error=<%s> | function node failed", self.name, e)

# Create failed NodeResult with exception
node_result = NodeResult(
result=e,
execution_time=execution_time,
status=Status.FAILED,
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
accumulated_metrics=Metrics(latencyMs=execution_time),
execution_count=1,
)

# Create failed MultiAgentResult
multi_agent_result = MultiAgentResult(
status=Status.FAILED,
results={self.name: node_result},
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
accumulated_metrics=Metrics(latencyMs=execution_time),
execution_count=1,
execution_time=execution_time,
)

return multi_agent_result
94 changes: 94 additions & 0 deletions tests/strands/multiagent/test_function_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Unit tests for FunctionNode implementation."""

from unittest.mock import Mock, patch

import pytest

from strands.multiagent.base import Status
from strands.multiagent.function_node import FunctionNode
from strands.types.content import ContentBlock


@pytest.fixture
def mock_tracer():
"""Create a mock tracer for testing."""
tracer = Mock()
span = Mock()
span.__enter__ = Mock(return_value=span)
span.__exit__ = Mock(return_value=None)
tracer.start_multiagent_span.return_value = span
return tracer


@pytest.mark.asyncio
async def test_invoke_async_string_input_success(mock_tracer):
"""Test successful function execution with string input."""

def test_function(task, invocation_state=None, **kwargs):
return f"Processed: {task}"

node = FunctionNode(test_function, "string_test")

with patch.object(node, "tracer", mock_tracer):
result = await node.invoke_async("test input")

assert result.status == Status.COMPLETED
assert "string_test" in result.results
assert result.results["string_test"].status == Status.COMPLETED
assert result.accumulated_usage["inputTokens"] == 0
assert result.accumulated_usage["outputTokens"] == 0


@pytest.mark.asyncio
async def test_invoke_async_content_block_input_success(mock_tracer):
"""Test successful function execution with ContentBlock input."""

def test_function(task, invocation_state=None, **kwargs):
return "ContentBlock processed"

node = FunctionNode(test_function, "content_block_test")
content_blocks = [ContentBlock(text="First block"), ContentBlock(text="Second block")]

with patch.object(node, "tracer", mock_tracer):
result = await node.invoke_async(content_blocks)

assert result.status == Status.COMPLETED
assert "content_block_test" in result.results
node_result = result.results["content_block_test"]
assert node_result.status == Status.COMPLETED


@pytest.mark.asyncio
async def test_invoke_async_with_kwargs(mock_tracer):
"""Test function execution with additional kwargs."""

def test_function(task, invocation_state=None, **kwargs):
extra_param = kwargs.get("extra_param", "none")
return f"Extra: {extra_param}"

node = FunctionNode(test_function, "kwargs_test")

with patch.object(node, "tracer", mock_tracer):
result = await node.invoke_async("test", None, extra_param="test_value")

assert result.status == Status.COMPLETED


@pytest.mark.asyncio
async def test_invoke_async_function_exception(mock_tracer):
"""Test proper exception handling when function raises an error."""

def failing_function(task, invocation_state=None, **kwargs):
raise ValueError("Test exception")

node = FunctionNode(failing_function, "exception_test")

with patch.object(node, "tracer", mock_tracer):
result = await node.invoke_async("test input")

assert result.status == Status.FAILED
assert "exception_test" in result.results
node_result = result.results["exception_test"]
assert node_result.status == Status.FAILED
assert isinstance(node_result.result, ValueError)
assert str(node_result.result) == "Test exception"
44 changes: 44 additions & 0 deletions tests_integ/test_multiagent_function_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Integration tests for FunctionNode with multiagent systems."""

import pytest

from strands import Agent
from strands.multiagent.base import Status
from strands.multiagent.function_node import FunctionNode
from strands.multiagent.graph import GraphBuilder

# Global variable to test function execution
test_global_var = None


def set_global_var(task, invocation_state=None, **kwargs):
"""Simple function that sets a global variable."""
global test_global_var
test_global_var = f"Function executed with: {task}"
return "Global variable set"


@pytest.mark.asyncio
async def test_agent_with_function_node():
"""Test graph with agent and function node."""
global test_global_var
test_global_var = None

# Create nodes
agent = Agent()
function_node = FunctionNode(set_global_var, "setter")

# Build graph
builder = GraphBuilder()
builder.add_node(agent, "agent")
builder.add_node(function_node, "setter")
builder.add_edge("agent", "setter")
builder.set_entry_point("agent")
graph = builder.build()

# Execute
result = await graph.invoke_async("Say hello")

# Verify function was called
assert "Function executed with:" in test_global_var
assert result.status == Status.COMPLETED
Loading