Skip to content

Commit 4b9b2f5

Browse files
committed
comments
1 parent 709ef4b commit 4b9b2f5

File tree

2 files changed

+71
-5
lines changed

2 files changed

+71
-5
lines changed

src/strands/multiagent/function_node.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import logging
99
import time
10-
from typing import Any, Callable, Union
10+
from typing import Any, Protocol, Union
1111

1212
from opentelemetry import trace as trace_api
1313

@@ -21,7 +21,14 @@
2121
logger = logging.getLogger(__name__)
2222

2323

24-
FunctionNodeCallable = Callable[[Union[str, list[ContentBlock]], dict[str, Any] | None], str]
24+
class FunctionNodeCallable(Protocol):
25+
"""Protocol for functions that can be executed within FunctionNode."""
26+
27+
def __call__(
28+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
29+
) -> str | list[ContentBlock] | Message:
30+
"""Execute deterministic logic within the multiagent system."""
31+
...
2532

2633

2734
class FunctionNode(MultiAgentBase):
@@ -80,8 +87,16 @@ async def invoke_async(
8087
execution_time,
8188
)
8289

83-
# Convert function result to Message
84-
message = Message(role="assistant", content=[ContentBlock(text=str(function_result))])
90+
# Convert function result to Message based on type
91+
if isinstance(function_result, dict) and "role" in function_result and "content" in function_result:
92+
# Already a Message
93+
message = function_result
94+
elif isinstance(function_result, list):
95+
# List of ContentBlocks
96+
message = Message(role="assistant", content=function_result)
97+
else:
98+
# String or other type - convert to string
99+
message = Message(role="assistant", content=[ContentBlock(text=str(function_result))])
85100
agent_result = AgentResult(
86101
stop_reason="end_turn", # "Normal completion of the response" - function executed successfully
87102
message=message,

tests/strands/multiagent/test_function_node.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from strands.multiagent.base import Status
88
from strands.multiagent.function_node import FunctionNode
9-
from strands.types.content import ContentBlock
9+
from strands.types.content import ContentBlock, Message
1010

1111

1212
@pytest.fixture
@@ -92,3 +92,54 @@ def failing_function(task, invocation_state=None, **kwargs):
9292
assert node_result.status == Status.FAILED
9393
assert isinstance(node_result.result, ValueError)
9494
assert str(node_result.result) == "Test exception"
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_function_returns_string(mock_tracer):
99+
"""Test function returning string."""
100+
101+
def string_function(task, invocation_state=None, **kwargs):
102+
return "Hello World"
103+
104+
node = FunctionNode(string_function, "string_node")
105+
106+
with patch.object(node, "tracer", mock_tracer):
107+
result = await node.invoke_async("test")
108+
109+
agent_result = result.results["string_node"].result
110+
assert agent_result.message["content"][0]["text"] == "Hello World"
111+
112+
113+
@pytest.mark.asyncio
114+
async def test_function_returns_content_blocks(mock_tracer):
115+
"""Test function returning list of ContentBlocks."""
116+
117+
def content_block_function(task, invocation_state=None, **kwargs):
118+
return [ContentBlock(text="Block 1"), ContentBlock(text="Block 2")]
119+
120+
node = FunctionNode(content_block_function, "content_node")
121+
122+
with patch.object(node, "tracer", mock_tracer):
123+
result = await node.invoke_async("test")
124+
125+
agent_result = result.results["content_node"].result
126+
assert len(agent_result.message["content"]) == 2
127+
assert agent_result.message["content"][0]["text"] == "Block 1"
128+
assert agent_result.message["content"][1]["text"] == "Block 2"
129+
130+
131+
@pytest.mark.asyncio
132+
async def test_function_returns_message(mock_tracer):
133+
"""Test function returning Message."""
134+
135+
def message_function(task, invocation_state=None, **kwargs):
136+
return Message(role="user", content=[ContentBlock(text="Custom message")])
137+
138+
node = FunctionNode(message_function, "message_node")
139+
140+
with patch.object(node, "tracer", mock_tracer):
141+
result = await node.invoke_async("test")
142+
143+
agent_result = result.results["message_node"].result
144+
assert agent_result.message["role"] == "user"
145+
assert agent_result.message["content"][0]["text"] == "Custom message"

0 commit comments

Comments
 (0)