Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 51 additions & 47 deletions src/strands/tools/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import logging
import os
import sys
import warnings
from pathlib import Path
from typing import cast
from typing import List, cast

from ..types.tools import AgentTool
from .decorator import DecoratedFunctionTool
Expand All @@ -18,60 +19,42 @@ class ToolLoader:
"""Handles loading of tools from different sources."""

@staticmethod
def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
"""Load a Python tool module.

Args:
tool_path: Path to the Python tool file.
tool_name: Name of the tool.

Returns:
Tool instance.
def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]:
"""Load a Python tool module and return all discovered function-based tools as a list.

Raises:
AttributeError: If required attributes are missing from the tool module.
ImportError: If there are issues importing the tool module.
TypeError: If the tool function is not callable.
ValueError: If function in module is not a valid tool.
Exception: For other errors during tool loading.
This method always returns a list of AgentTool (possibly length 1). It is the
canonical API for retrieving multiple tools from a single Python file.
"""
try:
# Check if tool_path is in the format "package.module:function"; but keep in mind windows whose file path
# could have a colon so also ensure that it's not a file
# Support module:function style (e.g. package.module:function)
if not os.path.exists(tool_path) and ":" in tool_path:
module_path, function_name = tool_path.rsplit(":", 1)
logger.debug("tool_name=<%s>, module_path=<%s> | importing tool from path", function_name, module_path)

try:
# Import the module
module = __import__(module_path, fromlist=["*"])

# Get the function
if not hasattr(module, function_name):
raise AttributeError(f"Module {module_path} has no function named {function_name}")

func = getattr(module, function_name)

if isinstance(func, DecoratedFunctionTool):
logger.debug(
"tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path
)
# mypy has problems converting between DecoratedFunctionTool <-> AgentTool
return cast(AgentTool, func)
else:
raise ValueError(
f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)"
)

except ImportError as e:
raise ImportError(f"Failed to import module {module_path}: {str(e)}") from e

if not hasattr(module, function_name):
raise AttributeError(f"Module {module_path} has no function named {function_name}")

func = getattr(module, function_name)
if isinstance(func, DecoratedFunctionTool):
logger.debug(
"tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path
)
return [cast(AgentTool, func)]
else:
raise ValueError(
f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)"
)

# Normal file-based tool loading
abs_path = str(Path(tool_path).resolve())

logger.debug("tool_path=<%s> | loading python tool from path", abs_path)

# First load the module to get TOOL_SPEC and check for Lambda deployment
# Load the module by spec
spec = importlib.util.spec_from_file_location(tool_name, abs_path)
if not spec:
raise ImportError(f"Could not create spec for {tool_name}")
Expand All @@ -82,24 +65,26 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
sys.modules[tool_name] = module
spec.loader.exec_module(module)

# First, check for function-based tools with @tool decorator
# Collect function-based tools decorated with @tool
function_tools: List[AgentTool] = []
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, DecoratedFunctionTool):
logger.debug(
"tool_name=<%s>, tool_path=<%s> | found function-based tool in path", attr_name, tool_path
)
# mypy has problems converting between DecoratedFunctionTool <-> AgentTool
return cast(AgentTool, attr)
function_tools.append(cast(AgentTool, attr))

# If no function-based tools found, fall back to traditional module-level tool
if function_tools:
return function_tools

# Fall back to module-level TOOL_SPEC + function
tool_spec = getattr(module, "TOOL_SPEC", None)
if not tool_spec:
raise AttributeError(
f"Tool {tool_name} missing TOOL_SPEC (neither at module level nor as a decorated function)"
)

# Standard local tool loading
tool_func_name = tool_name
if not hasattr(module, tool_func_name):
raise AttributeError(f"Tool {tool_name} missing function {tool_func_name}")
Expand All @@ -108,12 +93,31 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
if not callable(tool_func):
raise TypeError(f"Tool {tool_name} function is not callable")

return PythonAgentTool(tool_name, tool_spec, tool_func)
return [PythonAgentTool(tool_name, tool_spec, tool_func)]

except Exception:
logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path)
logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool(s)", tool_name, sys.path)
raise

@staticmethod
def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
"""DEPRECATED: Load a Python tool module and return a single AgentTool for backwards compatibility.

Use `load_python_tools` to retrieve all tools defined in a .py file (returns a list).
This function will emit a `DeprecationWarning` and return the first discovered tool.
"""
warnings.warn(
"ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. "
"Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.",
DeprecationWarning,
stacklevel=2,
)

tools = ToolLoader.load_python_tools(tool_path, tool_name)
if not tools:
raise RuntimeError(f"No tools found in {tool_path} for {tool_name}")
return tools[0]

@classmethod
def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool:
"""Load a tool based on its file extension.
Expand All @@ -123,7 +127,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool:
tool_name: Name of the tool.

Returns:
Tool instance.
A single Tool instance.

Raises:
FileNotFoundError: If the tool file does not exist.
Expand Down
11 changes: 7 additions & 4 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,19 +318,22 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes
"""
self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))

mapped_content = [
mapped_content
# Build a typed list of ToolResultContent. Use a clearer local name to avoid shadowing
# and annotate the result for mypy so it knows the intended element type.
mapped_contents: list[ToolResultContent] = [
mc
for content in call_tool_result.content
if (mapped_content := self._map_mcp_content_to_tool_result_content(content)) is not None
if (mc := self._map_mcp_content_to_tool_result_content(content)) is not None
]

status: ToolResultStatus = "error" if call_tool_result.isError else "success"
self._log_debug_with_thread("tool execution completed with status: %s", status)
result = MCPToolResult(
status=status,
toolUseId=tool_use_id,
content=mapped_content,
content=mapped_contents,
)

if call_tool_result.structuredContent:
result["structuredContent"] = call_tool_result.structuredContent

Expand Down
11 changes: 6 additions & 5 deletions src/strands/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from strands.tools.decorator import DecoratedFunctionTool

from ..types.tools import AgentTool, ToolSpec
from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec
from .tools import PythonAgentTool, normalize_loaded_tools, normalize_schema, normalize_tool_spec

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -128,10 +128,11 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None:
raise FileNotFoundError(f"Tool file not found: {tool_path}")

loaded_tool = ToolLoader.load_tool(tool_path, tool_name)
loaded_tool.mark_dynamic()

# Because we're explicitly registering the tool we don't need an allowlist
self.register_tool(loaded_tool)
# normalize_loaded_tools handles single tool or list of tools
for t in normalize_loaded_tools(loaded_tool):
t.mark_dynamic()
# Because we're explicitly registering the tool we don't need an allowlist
self.register_tool(t)
except Exception as e:
exception_str = str(e)
logger.exception("tool_name=<%s> | failed to load tool", tool_name)
Expand Down
17 changes: 12 additions & 5 deletions src/strands/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import inspect
import logging
import re
from typing import Any
from typing import Any, List, Union

from typing_extensions import override

Expand All @@ -33,22 +33,22 @@ def validate_tool_use(tool: ToolUse) -> None:
validate_tool_use_name(tool)


def validate_tool_use_name(tool: ToolUse) -> None:
def validate_tool_use_name(tool_use: ToolUse) -> None:
"""Validate the name of a tool use.

Args:
tool: The tool use to validate.
tool_use: The tool use to validate.

Raises:
InvalidToolUseNameException: If the tool name is invalid.
"""
# We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any]
if "name" not in tool:
if "name" not in tool_use:
message = "tool name missing" # type: ignore[unreachable]
logger.warning(message)
raise InvalidToolUseNameException(message)

tool_name = tool["name"]
tool_name = tool_use["name"]
tool_name_pattern = r"^[a-zA-Z0-9_\-]{1,}$"
tool_name_max_length = 64
valid_name_pattern = bool(re.match(tool_name_pattern, tool_name))
Expand Down Expand Up @@ -146,6 +146,13 @@ def normalize_tool_spec(tool_spec: ToolSpec) -> ToolSpec:
return normalized


def normalize_loaded_tools(loaded: Union[AgentTool, List[AgentTool]]) -> List[AgentTool]:
"""Normalize ToolLoader.load_tool return value to always be a list of AgentTool."""
if isinstance(loaded, list):
return loaded
return [loaded]


class PythonAgentTool(AgentTool):
"""Tool implementation for Python-based tools.

Expand Down
29 changes: 29 additions & 0 deletions tests/strands/tools/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,32 @@ def no_spec():
def test_load_tool_no_spec(tool_path):
with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"):
ToolLoader.load_tool(tool_path, "no_spec")


@pytest.mark.parametrize(
"tool_path",
[
textwrap.dedent(
"""
import strands

@strands.tools.tool
def alpha():
return "alpha"

@strands.tools.tool
def bravo():
return "bravo"
"""
)
],
indirect=True,
)
def test_load_python_tool_path_multiple_function_based(tool_path):
# load_python_tool returns a list when multiple decorated tools are present
loaded = ToolLoader.load_python_tools(tool_path, "alpha")
assert isinstance(loaded, list)
assert len(loaded) == 2
assert all(isinstance(t, DecoratedFunctionTool) for t in loaded)
names = {getattr(t, "tool_name", getattr(t, "name", None)) for t in loaded}
assert "alpha" in names and "bravo" in names