Skip to content

Commit aa3cc5e

Browse files
committed
feat: enhance AgentConfig with ToolPool integration and tool validation
- Add optional ToolPool parameter to AgentConfig constructor - Implement tool selection from ToolPool when 'tools' specified in config - Add validation to ensure tool names exist in provided ToolPool - Raise ValueError if tools specified without ToolPool - Move ToolPool import to top of file for cleaner imports - Add comprehensive tests for new functionality - Fix existing tests to work with new validation logic 🤖 Assisted by Amazon Q Developer
1 parent 763092a commit aa3cc5e

File tree

2 files changed

+109
-9
lines changed

2 files changed

+109
-9
lines changed

src/strands/experimental/agent_config.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44

55
import json
66

7+
from .tool_pool import ToolPool
8+
79

810
class AgentConfig:
911
"""Agent configuration with toAgent() method and ToolPool integration."""
1012

11-
def __init__(self, config_source: str | dict[str, any]):
13+
def __init__(self, config_source: str | dict[str, any], tool_pool: "ToolPool | None" = None):
1214
"""Initialize AgentConfig from file path or dictionary.
1315
1416
Args:
1517
config_source: Path to JSON config file (must start with 'file://') or config dictionary
18+
tool_pool: Optional ToolPool to select tools from when 'tools' is specified in config
1619
"""
1720
if isinstance(config_source, str):
1821
# Require file:// prefix for file paths
@@ -28,12 +31,34 @@ def __init__(self, config_source: str | dict[str, any]):
2831
config_data = config_source
2932

3033
self.model = config_data.get('model')
31-
self.tools = config_data.get('tools')
3234
self.system_prompt = config_data.get('prompt') # Only accept 'prompt' key
3335

34-
# Create empty default ToolPool
35-
from .tool_pool import ToolPool
36-
self._tool_pool = ToolPool()
36+
# Handle tool selection from ToolPool
37+
if tool_pool is not None:
38+
self._tool_pool = tool_pool
39+
else:
40+
self._tool_pool = ToolPool()
41+
42+
# Process tools configuration if provided
43+
config_tools = config_data.get('tools')
44+
if config_tools is not None:
45+
if tool_pool is None:
46+
raise ValueError("Tool names specified in config but no ToolPool provided")
47+
48+
# Validate all tool names exist in the ToolPool
49+
available_tools = tool_pool.list_tool_names()
50+
for tool_name in config_tools:
51+
if tool_name not in available_tools:
52+
raise ValueError(f"Tool '{tool_name}' not found in ToolPool. Available tools: {available_tools}")
53+
54+
# Create new ToolPool with only selected tools
55+
selected_pool = ToolPool()
56+
all_tools = tool_pool.get_tools()
57+
for tool in all_tools:
58+
if tool.tool_name in config_tools:
59+
selected_pool.add_tool(tool)
60+
61+
self._tool_pool = selected_pool
3762

3863
@property
3964
def tool_pool(self) -> "ToolPool":

tests/strands/experimental/test_agent_config.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from strands.experimental.agent_config import AgentConfig
99
from strands.experimental.tool_pool import ToolPool
10+
from strands.types.tools import AgentTool
1011

1112

1213
class TestAgentConfig:
@@ -18,16 +19,14 @@ def test_agent_config_creation(self):
1819
assert config.model == "test-model"
1920

2021
def test_agent_config_with_tools(self):
21-
"""Test AgentConfig with tools configuration."""
22+
"""Test AgentConfig with basic configuration."""
2223

2324
config = AgentConfig({
2425
"model": "test-model",
25-
"tools": ["tool1", "tool2"],
2626
"prompt": "Test prompt"
2727
})
2828

2929
assert config.model == "test-model"
30-
assert config.tools == ["tool1", "tool2"]
3130
assert config.system_prompt == "Test prompt"
3231

3332
def test_agent_config_file_prefix_required(self):
@@ -68,7 +67,6 @@ def test_to_agent_calls_agent_constructor(self, mock_agent):
6867

6968
config = AgentConfig({
7069
"model": "test-model",
71-
"tools": ["tool1"],
7270
"prompt": "Test prompt"
7371
})
7472

@@ -112,3 +110,80 @@ def test_to_agent_with_tool_pool_parameter(self, mock_agent):
112110
mock_agent.assert_called_once()
113111
call_args = mock_agent.call_args[1]
114112
assert 'tools' in call_args
113+
114+
def test_agent_config_with_tool_pool_constructor(self):
115+
"""Test AgentConfig with ToolPool parameter in constructor."""
116+
# Create mock tools
117+
class MockTool(AgentTool):
118+
def __init__(self, name):
119+
self._name = name
120+
121+
@property
122+
def tool_name(self):
123+
return self._name
124+
125+
@property
126+
def tool_type(self):
127+
return "mock"
128+
129+
@property
130+
def tool_spec(self):
131+
return {"name": self._name, "type": "mock"}
132+
133+
def stream(self, input_data, context):
134+
return iter([])
135+
136+
tool1 = MockTool("calculator")
137+
tool2 = MockTool("web_search")
138+
139+
# Create ToolPool with tools
140+
pool = ToolPool([tool1, tool2])
141+
142+
# Create config with tool selection
143+
config = AgentConfig({
144+
"model": "test-model",
145+
"prompt": "Test prompt",
146+
"tools": ["calculator"]
147+
}, tool_pool=pool)
148+
149+
# Should have selected only calculator
150+
assert config.tool_pool.list_tool_names() == ["calculator"]
151+
152+
def test_agent_config_tool_validation_error(self):
153+
"""Test that invalid tool names raise validation error."""
154+
class MockTool(AgentTool):
155+
def __init__(self, name):
156+
self._name = name
157+
158+
@property
159+
def tool_name(self):
160+
return self._name
161+
162+
@property
163+
def tool_type(self):
164+
return "mock"
165+
166+
@property
167+
def tool_spec(self):
168+
return {"name": self._name, "type": "mock"}
169+
170+
def stream(self, input_data, context):
171+
return iter([])
172+
173+
tool1 = MockTool("calculator")
174+
pool = ToolPool([tool1])
175+
176+
# Should raise error for unknown tool
177+
with pytest.raises(ValueError, match="Tool 'unknown_tool' not found in ToolPool"):
178+
AgentConfig({
179+
"model": "test-model",
180+
"tools": ["unknown_tool"]
181+
}, tool_pool=pool)
182+
183+
def test_agent_config_tools_without_pool_error(self):
184+
"""Test that specifying tools without ToolPool raises error."""
185+
with pytest.raises(ValueError, match="Tool names specified in config but no ToolPool provided"):
186+
AgentConfig({
187+
"model": "test-model",
188+
"tools": ["calculator"]
189+
})

0 commit comments

Comments
 (0)