|
6 | 6 |
|
7 | 7 | from strands.multiagent.base import Status
|
8 | 8 | from strands.multiagent.function_node import FunctionNode
|
9 |
| -from strands.types.content import ContentBlock |
| 9 | +from strands.types.content import ContentBlock, Message |
10 | 10 |
|
11 | 11 |
|
12 | 12 | @pytest.fixture
|
@@ -92,3 +92,54 @@ def failing_function(task, invocation_state=None, **kwargs):
|
92 | 92 | assert node_result.status == Status.FAILED
|
93 | 93 | assert isinstance(node_result.result, ValueError)
|
94 | 94 | assert str(node_result.result) == "Test exception"
|
| 95 | + |
| 96 | + |
| 97 | +@pytest.mark.asyncio |
| 98 | +async def test_function_returns_string(mock_tracer): |
| 99 | + """Test function returning string.""" |
| 100 | + |
| 101 | + def string_function(task, invocation_state=None, **kwargs): |
| 102 | + return "Hello World" |
| 103 | + |
| 104 | + node = FunctionNode(string_function, "string_node") |
| 105 | + |
| 106 | + with patch.object(node, "tracer", mock_tracer): |
| 107 | + result = await node.invoke_async("test") |
| 108 | + |
| 109 | + agent_result = result.results["string_node"].result |
| 110 | + assert agent_result.message["content"][0]["text"] == "Hello World" |
| 111 | + |
| 112 | + |
| 113 | +@pytest.mark.asyncio |
| 114 | +async def test_function_returns_content_blocks(mock_tracer): |
| 115 | + """Test function returning list of ContentBlocks.""" |
| 116 | + |
| 117 | + def content_block_function(task, invocation_state=None, **kwargs): |
| 118 | + return [ContentBlock(text="Block 1"), ContentBlock(text="Block 2")] |
| 119 | + |
| 120 | + node = FunctionNode(content_block_function, "content_node") |
| 121 | + |
| 122 | + with patch.object(node, "tracer", mock_tracer): |
| 123 | + result = await node.invoke_async("test") |
| 124 | + |
| 125 | + agent_result = result.results["content_node"].result |
| 126 | + assert len(agent_result.message["content"]) == 2 |
| 127 | + assert agent_result.message["content"][0]["text"] == "Block 1" |
| 128 | + assert agent_result.message["content"][1]["text"] == "Block 2" |
| 129 | + |
| 130 | + |
| 131 | +@pytest.mark.asyncio |
| 132 | +async def test_function_returns_message(mock_tracer): |
| 133 | + """Test function returning Message.""" |
| 134 | + |
| 135 | + def message_function(task, invocation_state=None, **kwargs): |
| 136 | + return Message(role="user", content=[ContentBlock(text="Custom message")]) |
| 137 | + |
| 138 | + node = FunctionNode(message_function, "message_node") |
| 139 | + |
| 140 | + with patch.object(node, "tracer", mock_tracer): |
| 141 | + result = await node.invoke_async("test") |
| 142 | + |
| 143 | + agent_result = result.results["message_node"].result |
| 144 | + assert agent_result.message["role"] == "user" |
| 145 | + assert agent_result.message["content"][0]["text"] == "Custom message" |
0 commit comments