forked from modelscope/ms-agent
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtool_manager.py
More file actions
249 lines (225 loc) · 11 KB
/
tool_manager.py
File metadata and controls
249 lines (225 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
# Copyright (c) ModelScope Contributors. All rights reserved.
import asyncio
import importlib
import inspect
import os
import sys
from copy import copy
from types import TracebackType
from typing import Any, Dict, List, Optional
import json
from ms_agent.llm.utils import Tool, ToolCall
from ms_agent.tools.agent_tool import AgentTool
from ms_agent.tools.base import ToolBase
from ms_agent.tools.code import CodeExecutionTool, LocalCodeExecutionTool
from ms_agent.tools.filesystem_tool import FileSystemTool
from ms_agent.tools.image_generator import ImageGenerator
from ms_agent.tools.mcp_client import MCPClient
from ms_agent.tools.shell.shell import Shell
from ms_agent.tools.split_task import SplitTask
from ms_agent.tools.video_generator import VideoGenerator
from ms_agent.utils import get_logger
from ms_agent.utils.constants import TOOL_PLUGIN_NAME
logger = get_logger()
MAX_TOOL_NAME_LEN = int(os.getenv('MAX_TOOL_NAME_LEN', 64))
TOOL_CALL_TIMEOUT = int(os.getenv('TOOL_CALL_TIMEOUT', 30))
MAX_CONCURRENT_TOOLS = int(os.getenv('MAX_CONCURRENT_TOOLS', 20))
class ToolManager:
"""Interacting with Agent class, hold all tools
"""
TOOL_SPLITER = '---'
def __init__(self,
config,
mcp_config: Optional[Dict[str, Any]] = None,
mcp_client: Optional[MCPClient] = None,
**kwargs):
self.config = config
self.trust_remote_code = kwargs.get('trust_remote_code', False)
self.extra_tools: List[ToolBase] = []
self.has_split_task_tool = False
if hasattr(config, 'tools') and hasattr(config.tools, 'split_task'):
self.extra_tools.append(SplitTask(config))
if hasattr(config, 'tools') and hasattr(config.tools,
'image_generator'):
self.extra_tools.append(ImageGenerator(config))
if hasattr(config, 'tools') and hasattr(config.tools,
'video_generator'):
self.extra_tools.append(VideoGenerator(config))
if hasattr(config, 'tools') and hasattr(config.tools, 'shell'):
self.extra_tools.append(Shell(config))
if hasattr(config, 'tools') and hasattr(config.tools, 'file_system'):
self.extra_tools.append(
FileSystemTool(
config, trust_remote_code=self.trust_remote_code))
if hasattr(config, 'tools') and hasattr(config.tools, 'code_executor'):
code_exec_cfg = getattr(config.tools, 'code_executor')
implementation = getattr(code_exec_cfg, 'implementation',
'sandbox')
if isinstance(implementation,
str) and implementation.lower() == 'python_env':
self.extra_tools.append(LocalCodeExecutionTool(config))
elif isinstance(implementation,
str) and implementation.lower() == 'sandbox':
self.extra_tools.append(CodeExecutionTool(config))
else:
logger.warning(
f'Unknown code execution implementation: {implementation},'
f'using sandbox instead.')
self.extra_tools.append(CodeExecutionTool(config))
if hasattr(config, 'tools') and hasattr(config.tools,
'financial_data_fetcher'):
from ms_agent.tools.findata.findata_fetcher import FinancialDataFetcher
self.extra_tools.append(FinancialDataFetcher(config))
if hasattr(config, 'tools') and getattr(config.tools, 'agent_tools',
None):
agent_tool = AgentTool(
config, trust_remote_code=self.trust_remote_code)
if agent_tool.enabled:
self.extra_tools.append(agent_tool)
self.tool_call_timeout = getattr(config, 'tool_call_timeout',
TOOL_CALL_TIMEOUT)
local_dir = self.config.local_dir if hasattr(self.config,
'local_dir') else None
if hasattr(config, 'tools') and hasattr(config.tools,
TOOL_PLUGIN_NAME):
plugins = getattr(config.tools, TOOL_PLUGIN_NAME)
for plugin in plugins:
subdir = os.path.dirname(plugin)
_plugin = os.path.basename(plugin)
assert local_dir is not None, 'Using external py files, but local_dir cannot be found.'
if subdir:
subdir = os.path.join(local_dir, str(subdir))
if not self.trust_remote_code:
raise AssertionError(
'[External Code Found] Your config file contains external code, '
'instantiate the code may be UNSAFE, if you trust the code, '
'please pass `trust_remote_code=True` or `--trust_remote_code true`'
)
if local_dir not in sys.path:
sys.path.insert(0, local_dir)
if subdir and subdir not in sys.path:
sys.path.insert(0, subdir)
if _plugin.endswith('.py'):
_plugin = _plugin[:-3]
plugin_file = importlib.import_module(_plugin)
module_classes = {
name: cls
for name, cls in inspect.getmembers(
plugin_file, inspect.isclass)
}
for name, cls in module_classes.items():
# Find cls which base class is `ToolBase`
if issubclass(cls, ToolBase) and cls.__module__ == _plugin:
self.register_tool(cls(self.config))
self._tool_index = {}
# Used temporarily during async initialization; the actual client is managed in self.servers
self.mcp_client = mcp_client
self.mcp_config = mcp_config
self.servers = None
self._managed_client = mcp_client is None
# Initialize concurrency limiter (will be set in connect)
self._concurrent_limiter = None
self._init_lock = None
def register_tool(self, tool: ToolBase):
self.extra_tools.append(tool)
async def connect(self):
if self.mcp_client and isinstance(self.mcp_client, MCPClient):
self.servers = self.mcp_client
await self.servers.add_mcp_config(self.mcp_config)
self.mcp_config = self.servers.mcp_config
else:
self.servers = MCPClient(self.mcp_config, self.config)
await self.servers.connect()
for tool in self.extra_tools:
await tool.connect()
await self.reindex_tool()
# Initialize concurrency limiter
self._concurrent_limiter = asyncio.Semaphore(MAX_CONCURRENT_TOOLS)
logger.info(f'Tool concurrency limit set to {MAX_CONCURRENT_TOOLS}')
async def cleanup(self):
if self._managed_client and self.servers:
try:
await self.servers.cleanup()
except Exception: # noqa
pass
self.servers = None
for tool in self.extra_tools:
try:
await tool.cleanup()
except Exception: # noqa
pass
async def reindex_tool(self):
def extend_tool(tool_ins: ToolBase, server_name: str,
tool_list: List[Tool]):
for tool in tool_list:
# Subtract the length of the tool name splitter
max_server_len = MAX_TOOL_NAME_LEN - len(
tool['tool_name']) - len(self.TOOL_SPLITER)
if len(server_name) > max_server_len:
key = f"{server_name[:max(0, max_server_len)]}{self.TOOL_SPLITER}{tool['tool_name']}"
else:
key = f"{server_name}{self.TOOL_SPLITER}{tool['tool_name']}"
assert key not in self._tool_index, f'Tool name duplicated {tool["tool_name"]}'
tool = copy(tool)
tool['tool_name'] = key
self._tool_index[key] = (tool_ins, server_name, tool)
mcps = await self.servers.get_tools()
for server_name, tool_list in mcps.items():
extend_tool(self.servers, server_name, tool_list)
for extra_tool in self.extra_tools:
tools = await extra_tool.get_tools()
for server_name, tool_list in tools.items():
extend_tool(extra_tool, server_name, tool_list)
async def get_tools(self):
return [value[2] for value in self._tool_index.values()]
async def single_call_tool(self, tool_info: ToolCall):
if self._concurrent_limiter is None:
if self._init_lock is None:
self._init_lock = asyncio.Lock()
async with self._init_lock:
if self._concurrent_limiter is None:
self._concurrent_limiter = asyncio.Semaphore(
MAX_CONCURRENT_TOOLS)
async with self._concurrent_limiter:
brief_info = json.dumps(tool_info, ensure_ascii=False)
if len(brief_info) > 1024:
brief_info = brief_info[:1024] + '...'
try:
tool_name = tool_info['tool_name']
tool_args = tool_info['arguments']
while isinstance(tool_args, str):
try:
tool_args = json.loads(tool_args)
except Exception: # noqa
return f'The input {tool_args} is not a valid JSON, fix your arguments and try again'
assert tool_name in self._tool_index, f'Tool name {tool_name} not found'
tool_ins, server_name, _ = self._tool_index[tool_name]
response = await asyncio.wait_for(
tool_ins.call_tool(
server_name,
tool_name=tool_name.split(self.TOOL_SPLITER)[1],
tool_args=tool_args),
timeout=self.tool_call_timeout)
return response
except asyncio.TimeoutError:
import traceback
logger.warning(traceback.format_exc())
# TODO: How to get the information printed by the tool before hanging to return to the model?
return f'Execute tool call timeout: {brief_info}'
except Exception as e:
import traceback
logger.warning(traceback.format_exc())
return f'Tool calling failed: {brief_info}, details: {str(e)}'
async def parallel_call_tool(self, tool_list: List[ToolCall]):
tasks = [self.single_call_tool(tool) for tool in tool_list]
result = await asyncio.gather(*tasks)
return result
async def __aenter__(self) -> 'ToolManager':
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass