1+ import asyncio
12import inspect
23from dataclasses import asdict , dataclass , field
34from functools import partial
45from typing import Any , Callable , Dict , Iterator , List , Mapping , Optional , Union , Tuple
56
67import msgspec
7- from opentelemetry import trace
8- from opentelemetry .trace import Status , StatusCode
98
109from msgflux .dotdict import dotdict
11- from msgflux .envs import envs
1210from msgflux .logger import logger
1311from msgflux .nn import functional as F
1412from msgflux .nn .modules .container import ModuleDict
@@ -149,6 +147,41 @@ async def aforward(self, **kwargs) -> Any:
149147 # Extract and return result
150148 return extract_tool_result_text (result )
151149
150+ class LocalTool (Tool ):
151+ """Local tool implementation."""
152+ def __init__ (
153+ self ,
154+ name : str ,
155+ description : str ,
156+ annotations : Dict [str , Any ],
157+ tool_config : Dict [str , Any ],
158+ impl : Callable ,
159+ ):
160+ super ().__init__ ()
161+ self .set_name (name )
162+ self .set_description (description )
163+ self .set_annotations (annotations )
164+ self .register_buffer ("tool_config" , tool_config )
165+ self .impl = impl # Not a buffer for now
166+
167+ @tool_retry
168+ @set_tool_attributes (execution_type = "local" )
169+ def forward (self , ** kwargs ):
170+ if inspect .iscoroutinefunction (self .impl ):
171+ return F .wait_for (self .impl , ** kwargs )
172+ return self .impl (** kwargs )
173+
174+ @tool_retry
175+ @aset_tool_attributes (execution_type = "local" )
176+ async def aforward (self , * args , ** kwargs ):
177+ if hasattr (self .impl , "acall" ):
178+ return await self .impl .acall (* args , ** kwargs )
179+ elif inspect .iscoroutinefunction (self .impl ):
180+ return await self .impl (* args , ** kwargs )
181+ # Fall back to sync call in executor to avoid blocking event loop
182+ loop = asyncio .get_event_loop ()
183+ return await loop .run_in_executor (None , lambda : self .impl (* args , ** kwargs ))
184+
152185def _convert_module_to_nn_tool (impl : Callable ) -> Tool : # noqa: C901
153186 """Convert a callable in nn.Tool."""
154187 tool_config = impl .__dict__ .get ("tool_config" , dotdict ())
@@ -168,7 +201,7 @@ def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
168201 or
169202 getattr (impl , "__doc__" , None )
170203 or
171- getattr (impl .__call__ , "__doc__" , None )
204+ getattr (impl .__call__ , "__doc__" , None )
172205 )
173206 if doc is None :
174207 raise NotImplementedError (
@@ -183,7 +216,7 @@ def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
183216 or
184217 getattr (impl , "__annotations__" , None )
185218 or
186- getattr (impl .__call__ , "__annotations__" , None )
219+ getattr (impl .__call__ , "__annotations__" , None )
187220 )
188221 if annotations is None :
189222 if fn_has_parameters (impl .__call__ ):
@@ -217,7 +250,7 @@ def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
217250 )
218251
219252 annotations = impl .__annotations__
220-
253+
221254 if annotations is None :
222255 if fn_has_parameters (impl ):
223256 raise NotImplementedError (
@@ -241,33 +274,13 @@ def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
241274 if tool_config .get ("background" ):
242275 doc = "This tool will run in the background. \n " + doc
243276
244- class LocalTool (Tool ):
245- """Local tool implementation."""
246- def __init__ (self ):
247- super ().__init__ ()
248- self .set_name (name )
249- self .set_description (doc )
250- self .set_annotations (annotations )
251- self .register_buffer ("tool_config" , tool_config )
252- self .impl = impl # Not a buffer for now
253-
254- @tool_retry
255- @set_tool_attributes (execution_type = "local" )
256- def forward (self , ** kwargs ):
257- if inspect .iscoroutinefunction (self .impl ):
258- return F .wait_for (self .impl , ** kwargs )
259- return self .impl (** kwargs )
260-
261- @tool_retry
262- @aset_tool_attributes (execution_type = "local" )
263- async def aforward (self , * args , ** kwargs ):
264- if hasattr (self .impl , "acall" ):
265- return await self .impl .acall (* args , ** kwargs )
266- elif inspect .iscoroutinefunction (self .impl ):
267- return await self .impl (* args , ** kwargs )
268- # Fall back to sync call
269- return self .impl (* args , ** kwargs )
270- return LocalTool ()
277+ return LocalTool (
278+ name = name ,
279+ description = doc ,
280+ annotations = annotations ,
281+ tool_config = tool_config ,
282+ impl = impl ,
283+ )
271284
272285
273286class ToolLibrary (Module ):
0 commit comments