Skip to content

Commit ac2d24c

Browse files
refactor tool
1 parent c42bb34 commit ac2d24c

File tree

4 files changed

+57
-43
lines changed

4 files changed

+57
-43
lines changed

src/msgflux/nn/modules/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from msgflux.nn.modules.module import Module
77
from msgflux.nn.modules.retriever import Retriever
88
from msgflux.nn.modules.speaker import Speaker
9-
from msgflux.nn.modules.tool import Tool, ToolLibrary
10-
11-
ToolBase = Tool
9+
from msgflux.nn.modules.tool import LocalTool, MCPTool, Tool, ToolLibrary
1210
from msgflux.nn.modules.transcriber import Transcriber
1311

1412
__all__ = [
1513
"Agent",
1614
"Embedder",
1715
"LM",
16+
"LocalTool",
17+
"MCPTool",
1818
"MediaMaker",
1919
"Module",
2020
"ModuleDict",
@@ -23,7 +23,6 @@
2323
"Sequential",
2424
"Speaker",
2525
"Tool",
26-
"ToolBase",
2726
"ToolLibrary",
2827
"Transcriber",
2928
]

src/msgflux/nn/modules/lm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Union
1+
from typing import Union
22

33
from msgflux.models.gateway import ModelGateway
44
from msgflux.models.response import ModelResponse, ModelStreamResponse
@@ -44,4 +44,4 @@ def _set_model(self, model: Union[ChatCompletionModel, ModelGateway]):
4444
raise TypeError(
4545
f"`model` must be a `chat_completion` model, given `{type(model)}`"
4646
)
47-
self.register_buffer("model", model)
47+
self.register_buffer("model", model)

src/msgflux/nn/modules/module.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -848,8 +848,9 @@ async def _aexecute_input_guardrail(self, model_execution_params: Dict[str, Any]
848848
elif inspect.iscoroutinefunction(input_guardrail):
849849
guardrail_response = await input_guardrail(**guardrail_params)
850850
else:
851-
# Fallback to sync call
852-
guardrail_response = input_guardrail(**guardrail_params)
851+
# Fallback to sync call in executor to avoid blocking event loop
852+
loop = asyncio.get_event_loop()
853+
guardrail_response = await loop.run_in_executor(None, lambda: input_guardrail(**guardrail_params))
853854

854855
if isinstance(guardrail_response, ModelResponse):
855856
guardrail_response = self._extract_raw_response(guardrail_response)
@@ -884,8 +885,9 @@ async def _aexecute_output_guardrail(self, model_response: Dict[str, Any]):
884885
elif inspect.iscoroutinefunction(output_guardrail):
885886
guardrail_response = await output_guardrail(**guardrail_params)
886887
else:
887-
# Fallback to sync call
888-
guardrail_response = output_guardrail(**guardrail_params)
888+
# Fallback to sync call in executor to avoid blocking event loop
889+
loop = asyncio.get_event_loop()
890+
guardrail_response = await loop.run_in_executor(None, lambda: output_guardrail(**guardrail_params))
889891

890892
if isinstance(guardrail_response, ModelResponse):
891893
guardrail_response = self._extract_raw_response(guardrail_response)

src/msgflux/nn/modules/tool.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1+
import asyncio
12
import inspect
23
from dataclasses import asdict, dataclass, field
34
from functools import partial
45
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Tuple
56

67
import msgspec
7-
from opentelemetry import trace
8-
from opentelemetry.trace import Status, StatusCode
98

109
from msgflux.dotdict import dotdict
11-
from msgflux.envs import envs
1210
from msgflux.logger import logger
1311
from msgflux.nn import functional as F
1412
from msgflux.nn.modules.container import ModuleDict
@@ -149,6 +147,41 @@ async def aforward(self, **kwargs) -> Any:
149147
# Extract and return result
150148
return extract_tool_result_text(result)
151149

150+
class LocalTool(Tool):
151+
"""Local tool implementation."""
152+
def __init__(
153+
self,
154+
name: str,
155+
description: str,
156+
annotations: Dict[str, Any],
157+
tool_config: Dict[str, Any],
158+
impl: Callable,
159+
):
160+
super().__init__()
161+
self.set_name(name)
162+
self.set_description(description)
163+
self.set_annotations(annotations)
164+
self.register_buffer("tool_config", tool_config)
165+
self.impl = impl # Not a buffer for now
166+
167+
@tool_retry
168+
@set_tool_attributes(execution_type="local")
169+
def forward(self, **kwargs):
170+
if inspect.iscoroutinefunction(self.impl):
171+
return F.wait_for(self.impl, **kwargs)
172+
return self.impl(**kwargs)
173+
174+
@tool_retry
175+
@aset_tool_attributes(execution_type="local")
176+
async def aforward(self, *args, **kwargs):
177+
if hasattr(self.impl, "acall"):
178+
return await self.impl.acall(*args, **kwargs)
179+
elif inspect.iscoroutinefunction(self.impl):
180+
return await self.impl(*args, **kwargs)
181+
# Fall back to sync call in executor to avoid blocking event loop
182+
loop = asyncio.get_event_loop()
183+
return await loop.run_in_executor(None, lambda: self.impl(*args, **kwargs))
184+
152185
def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
153186
"""Convert a callable in nn.Tool."""
154187
tool_config = impl.__dict__.get("tool_config", dotdict())
@@ -168,7 +201,7 @@ def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
168201
or
169202
getattr(impl, "__doc__", None)
170203
or
171-
getattr(impl.__call__, "__doc__", None)
204+
getattr(impl.__call__, "__doc__", None)
172205
)
173206
if doc is None:
174207
raise NotImplementedError(
@@ -183,7 +216,7 @@ def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
183216
or
184217
getattr(impl, "__annotations__", None)
185218
or
186-
getattr(impl.__call__, "__annotations__", None)
219+
getattr(impl.__call__, "__annotations__", None)
187220
)
188221
if annotations is None:
189222
if fn_has_parameters(impl.__call__):
@@ -217,7 +250,7 @@ def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
217250
)
218251

219252
annotations = impl.__annotations__
220-
253+
221254
if annotations is None:
222255
if fn_has_parameters(impl):
223256
raise NotImplementedError(
@@ -241,33 +274,13 @@ def _convert_module_to_nn_tool(impl: Callable) -> Tool: # noqa: C901
241274
if tool_config.get("background"):
242275
doc = "This tool will run in the background. \n" + doc
243276

244-
class LocalTool(Tool):
245-
"""Local tool implementation."""
246-
def __init__(self):
247-
super().__init__()
248-
self.set_name(name)
249-
self.set_description(doc)
250-
self.set_annotations(annotations)
251-
self.register_buffer("tool_config", tool_config)
252-
self.impl = impl # Not a buffer for now
253-
254-
@tool_retry
255-
@set_tool_attributes(execution_type="local")
256-
def forward(self, **kwargs):
257-
if inspect.iscoroutinefunction(self.impl):
258-
return F.wait_for(self.impl, **kwargs)
259-
return self.impl(**kwargs)
260-
261-
@tool_retry
262-
@aset_tool_attributes(execution_type="local")
263-
async def aforward(self, *args, **kwargs):
264-
if hasattr(self.impl, "acall"):
265-
return await self.impl.acall(*args, **kwargs)
266-
elif inspect.iscoroutinefunction(self.impl):
267-
return await self.impl(*args, **kwargs)
268-
# Fall back to sync call
269-
return self.impl(*args, **kwargs)
270-
return LocalTool()
277+
return LocalTool(
278+
name=name,
279+
description=doc,
280+
annotations=annotations,
281+
tool_config=tool_config,
282+
impl=impl,
283+
)
271284

272285

273286
class ToolLibrary(Module):

0 commit comments

Comments
 (0)