Skip to content

Commit 171779a

Browse files
Unshurepgrayy
andauthored
feat: Refactor and update tool loading to support modules (#989)
* feat: Refactor and update tool loading to support modules * Update registry.py * feat: Address pr feedback * Update src/strands/tools/registry.py Co-authored-by: Patrick Gray <[email protected]> * Update src/strands/tools/loader.py Co-authored-by: Patrick Gray <[email protected]> --------- Co-authored-by: Patrick Gray <[email protected]>
1 parent 2a26ffa commit 171779a

File tree

8 files changed

+364
-63
lines changed

8 files changed

+364
-63
lines changed

.github/workflows/test-lint.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ jobs:
6666
id: tests
6767
run: hatch test tests --cover
6868
continue-on-error: false
69+
70+
- name: Upload coverage reports to Codecov
71+
uses: codecov/codecov-action@v5
72+
with:
73+
token: ${{ secrets.CODECOV_TOKEN }}
6974
lint:
7075
name: Lint
7176
runs-on: ubuntu-latest

src/strands/tools/loader.py

Lines changed: 148 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import os
66
import sys
77
import warnings
8+
from importlib.machinery import ModuleSpec
89
from pathlib import Path
10+
from posixpath import expanduser
11+
from types import ModuleType
912
from typing import List, cast
1013

1114
from ..types.tools import AgentTool
@@ -15,16 +18,151 @@
1518
logger = logging.getLogger(__name__)
1619

1720

21+
def load_tool_from_string(tool_string: str) -> List[AgentTool]:
22+
"""Load tools follows strands supported input string formats.
23+
24+
This function can load a tool based on a string in the following ways:
25+
1. Local file path to a module based tool: `./path/to/module/tool.py`
26+
2. Module import path
27+
2.1. Path to a module based tool: `strands_tools.file_read`
28+
2.2. Path to a module with multiple AgentTool instances (@tool decorated): `tests.fixtures.say_tool`
29+
2.3. Path to a module and a specific function: `tests.fixtures.say_tool:say`
30+
"""
31+
# Case 1: Local file path to a tool
32+
# Ex: ./path/to/my_cool_tool.py
33+
tool_path = expanduser(tool_string)
34+
if os.path.exists(tool_path):
35+
return load_tools_from_file_path(tool_path)
36+
37+
# Case 2: Module import path
38+
# Ex: test.fixtures.say_tool:say (Load specific @tool decorated function)
39+
# Ex: strands_tools.file_read (Load all @tool decorated functions, or module tool)
40+
return load_tools_from_module_path(tool_string)
41+
42+
43+
def load_tools_from_file_path(tool_path: str) -> List[AgentTool]:
44+
"""Load module from specified path, and then load tools from that module.
45+
46+
This function attempts to load the passed in path as a python module, and if it succeeds,
47+
then it tries to import strands tool(s) from that module.
48+
"""
49+
abs_path = str(Path(tool_path).resolve())
50+
logger.debug("tool_path=<%s> | loading python tool from path", abs_path)
51+
52+
# Load the module by spec
53+
54+
# Using this to determine the module name
55+
# ./path/to/my_cool_tool.py -> my_cool_tool
56+
module_name = os.path.basename(tool_path).split(".")[0]
57+
58+
# This function imports a module based on its path, and gives it the provided name
59+
60+
spec: ModuleSpec = cast(ModuleSpec, importlib.util.spec_from_file_location(module_name, abs_path))
61+
if not spec:
62+
raise ImportError(f"Could not create spec for {module_name}")
63+
if not spec.loader:
64+
raise ImportError(f"No loader available for {module_name}")
65+
66+
module = importlib.util.module_from_spec(spec)
67+
# Load, or re-load, the module
68+
sys.modules[module_name] = module
69+
# Execute the module to run any top level code
70+
spec.loader.exec_module(module)
71+
72+
return load_tools_from_module(module, module_name)
73+
74+
75+
def load_tools_from_module_path(module_tool_path: str) -> list[AgentTool]:
76+
"""Load strands tool from a module path.
77+
78+
Example module paths:
79+
my.module.path
80+
my.module.path:tool_name
81+
"""
82+
if ":" in module_tool_path:
83+
module_path, tool_func_name = module_tool_path.split(":")
84+
else:
85+
module_path, tool_func_name = (module_tool_path, None)
86+
87+
try:
88+
module = importlib.import_module(module_path)
89+
except ModuleNotFoundError as e:
90+
raise AttributeError(f'Tool string: "{module_tool_path}" is not a valid tool string.') from e
91+
92+
# If a ':' is present in the string, then its a targeted function in a module
93+
if tool_func_name:
94+
if hasattr(module, tool_func_name):
95+
target_tool = getattr(module, tool_func_name)
96+
if isinstance(target_tool, DecoratedFunctionTool):
97+
return [target_tool]
98+
99+
raise AttributeError(f"Tool {tool_func_name} not found in module {module_path}")
100+
101+
# Else, try to import all of the @tool decorated tools, or the module based tool
102+
module_name = module_path.split(".")[-1]
103+
return load_tools_from_module(module, module_name)
104+
105+
106+
def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTool]:
107+
"""Load tools from a module.
108+
109+
First checks if the passed in module has instances of DecoratedToolFunction classes as atributes to the module.
110+
If so, then it returns them as a list of tools. If not, then it attempts to load the module as a module based tool.
111+
"""
112+
logger.debug("tool_name=<%s>, module=<%s> | loading tools from module", module_name, module_name)
113+
114+
# Try and see if any of the attributes in the module are function-based tools decorated with @tool
115+
# This means that there may be more than one tool available in this module, so we load them all
116+
117+
function_tools: List[AgentTool] = []
118+
# Function tools will appear as attributes in the module
119+
for attr_name in dir(module):
120+
attr = getattr(module, attr_name)
121+
# Check if the module attribute is a DecoratedFunctiontool
122+
if isinstance(attr, DecoratedFunctionTool):
123+
logger.debug("tool_name=<%s>, module=<%s> | found function-based tool in module", attr_name, module_name)
124+
function_tools.append(cast(AgentTool, attr))
125+
126+
if function_tools:
127+
return function_tools
128+
129+
# Finally, if no DecoratedFunctionTools are found in the module, fall back
130+
# to module based tools, and search for TOOL_SPEC + function
131+
module_tool_name = module_name
132+
tool_spec = getattr(module, "TOOL_SPEC", None)
133+
if not tool_spec:
134+
raise AttributeError(
135+
f"The module {module_tool_name} is not a valid module for loading tools."
136+
"This module must contain @tool decorated function(s), or must be a module based tool."
137+
)
138+
139+
# If this is a module based tool, the module should have a function with the same name as the module itself
140+
if not hasattr(module, module_tool_name):
141+
raise AttributeError(f"Module-based tool {module_tool_name} missing function {module_tool_name}")
142+
143+
tool_func = getattr(module, module_tool_name)
144+
if not callable(tool_func):
145+
raise TypeError(f"Tool {module_tool_name} function is not callable")
146+
147+
return [PythonAgentTool(module_tool_name, tool_spec, tool_func)]
148+
149+
18150
class ToolLoader:
19151
"""Handles loading of tools from different sources."""
20152

21153
@staticmethod
22154
def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]:
23-
"""Load a Python tool module and return all discovered function-based tools as a list.
155+
"""DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list.
24156
25157
This method always returns a list of AgentTool (possibly length 1). It is the
26158
canonical API for retrieving multiple tools from a single Python file.
27159
"""
160+
warnings.warn(
161+
"ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. "
162+
"Use the `load_tools_from_string` or `load_tools_from_module` methods instead.",
163+
DeprecationWarning,
164+
stacklevel=2,
165+
)
28166
try:
29167
# Support module:function style (e.g. package.module:function)
30168
if not os.path.exists(tool_path) and ":" in tool_path:
@@ -108,7 +246,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
108246
"""
109247
warnings.warn(
110248
"ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. "
111-
"Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.",
249+
"Use the `load_tools_from_string` or `load_tools_from_module` methods instead.",
112250
DeprecationWarning,
113251
stacklevel=2,
114252
)
@@ -127,7 +265,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool:
127265
"""
128266
warnings.warn(
129267
"ToolLoader.load_tool is deprecated and will be removed in Strands SDK 2.0. "
130-
"Use ToolLoader.load_tools(...) which always returns a list of AgentTool.",
268+
"Use the `load_tools_from_string` or `load_tools_from_module` methods instead.",
131269
DeprecationWarning,
132270
stacklevel=2,
133271
)
@@ -140,7 +278,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool:
140278

141279
@classmethod
142280
def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]:
143-
"""Load tools from a file based on its file extension.
281+
"""DEPRECATED: Load tools from a file based on its file extension.
144282
145283
Args:
146284
tool_path: Path to the tool file.
@@ -154,6 +292,12 @@ def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]:
154292
ValueError: If the tool file has an unsupported extension.
155293
Exception: For other errors during tool loading.
156294
"""
295+
warnings.warn(
296+
"ToolLoader.load_tools is deprecated and will be removed in Strands SDK 2.0. "
297+
"Use the `load_tools_from_string` or `load_tools_from_module` methods instead.",
298+
DeprecationWarning,
299+
stacklevel=2,
300+
)
157301
ext = Path(tool_path).suffix.lower()
158302
abs_path = str(Path(tool_path).resolve())
159303

0 commit comments

Comments
 (0)