Skip to content

Commit e597e07

Browse files
authored
Automatically flatten nested tool collections (#508)
Fixes issue #50 Customers naturally want to pass nested collections of tools - the above issue has gathered enough data points proving that.
1 parent 022ec55 commit e597e07

File tree

3 files changed

+55
-2
lines changed

3 files changed

+55
-2
lines changed

src/strands/tools/registry.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from importlib import import_module, util
1212
from os.path import expanduser
1313
from pathlib import Path
14-
from typing import Any, Dict, List, Optional
14+
from typing import Any, Dict, Iterable, List, Optional
1515

1616
from typing_extensions import TypedDict, cast
1717

@@ -54,7 +54,7 @@ def process_tools(self, tools: List[Any]) -> List[str]:
5454
"""
5555
tool_names = []
5656

57-
for tool in tools:
57+
def add_tool(tool: Any) -> None:
5858
# Case 1: String file path
5959
if isinstance(tool, str):
6060
# Extract tool name from path
@@ -97,9 +97,16 @@ def process_tools(self, tools: List[Any]) -> List[str]:
9797
elif isinstance(tool, AgentTool):
9898
self.register_tool(tool)
9999
tool_names.append(tool.tool_name)
100+
# Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool
101+
elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)):
102+
for t in tool:
103+
add_tool(t)
100104
else:
101105
logger.warning("tool=<%s> | unrecognized tool specification", tool)
102106

107+
for a_tool in tools:
108+
add_tool(a_tool)
109+
103110
return tool_names
104111

105112
def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None:

tests/strands/agent/test_agent.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,25 @@ def test_agent__init__with_string_model_id():
231231
assert agent.model.config["model_id"] == "nonsense"
232232

233233

234+
def test_agent__init__nested_tools_flattening(tool_decorated, tool_module, tool_imported, tool_registry):
235+
_ = tool_registry
236+
# Nested structure: [tool_decorated, [tool_module, [tool_imported]]]
237+
agent = Agent(tools=[tool_decorated, [tool_module, [tool_imported]]])
238+
tru_tool_names = sorted(agent.tool_names)
239+
exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
240+
assert tru_tool_names == exp_tool_names
241+
242+
243+
def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_imported, tool_registry):
244+
_ = tool_registry
245+
# Deeply nested structure
246+
nested_tools = [[[[tool_decorated]], [[tool_module]], tool_imported]]
247+
agent = Agent(tools=nested_tools)
248+
tru_tool_names = sorted(agent.tool_names)
249+
exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
250+
assert tru_tool_names == exp_tool_names
251+
252+
234253
def test_agent__call__(
235254
mock_model,
236255
system_prompt,

tests/strands/tools/test_registry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,30 @@ def tool_function_4(d):
9393

9494
assert len(tools) == 2
9595
assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools)
96+
97+
98+
def test_process_tools_flattens_lists_and_tuples_and_sets():
99+
def function() -> str:
100+
return "done"
101+
102+
tool_a = tool(name="tool_a")(function)
103+
tool_b = tool(name="tool_b")(function)
104+
tool_c = tool(name="tool_c")(function)
105+
tool_d = tool(name="tool_d")(function)
106+
tool_e = tool(name="tool_e")(function)
107+
tool_f = tool(name="tool_f")(function)
108+
109+
registry = ToolRegistry()
110+
111+
all_tools = [tool_a, (tool_b, tool_c), [{tool_d, tool_e}, [tool_f]]]
112+
113+
tru_tool_names = sorted(registry.process_tools(all_tools))
114+
exp_tool_names = [
115+
"tool_a",
116+
"tool_b",
117+
"tool_c",
118+
"tool_d",
119+
"tool_e",
120+
"tool_f",
121+
]
122+
assert tru_tool_names == exp_tool_names

0 commit comments

Comments
 (0)