Skip to content

Commit b58f075

Browse files
Add dspy.Tool.from_mcp_tool (#8130)
* Add util for using mcp tools as DSPy tools Co-authored-by: ThanabordeeN <[email protected]> * rename * update dependency * add from_mcp_tool * fix circular import * test fix --------- Co-authored-by: ThanabordeeN <[email protected]>
1 parent 103d3d7 commit b58f075

File tree

6 files changed

+239
-7
lines changed

6 files changed

+239
-7
lines changed

dspy/primitives/tool.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import asyncio
22
import inspect
3-
from typing import Any, Callable, Optional, get_origin, get_type_hints
3+
from typing import Any, Callable, Optional, get_origin, get_type_hints, TYPE_CHECKING
44

55
from jsonschema import ValidationError, validate
66
from pydantic import BaseModel, TypeAdapter, create_model
77

88
from dspy.utils.callback import with_callbacks
99

10+
if TYPE_CHECKING:
11+
import mcp
1012

1113
class Tool:
1214
"""Tool class.
@@ -176,3 +178,18 @@ async def acall(self, **kwargs):
176178
if not asyncio.iscoroutine(result):
177179
raise ValueError("You are calling `acall` on a non-async tool, please use `__call__` instead.")
178180
return await result
181+
182+
@classmethod
183+
def from_mcp_tool(cls, session: "mcp.client.session.ClientSession", tool: "mcp.types.Tool") -> "Tool":
184+
"""
185+
Build a DSPy tool from an MCP tool and a ClientSession.
186+
187+
Args:
188+
session: The MCP session to use.
189+
tool: The MCP tool to convert.
190+
191+
Returns:
192+
A Tool object.
193+
"""
194+
from dspy.utils.mcp import convert_mcp_tool
195+
return convert_mcp_tool(session, tool)

dspy/utils/mcp.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import Any, Tuple, Type, Union, TYPE_CHECKING
2+
from dspy.primitives.tool import Tool
3+
4+
if TYPE_CHECKING:
5+
import mcp
6+
7+
TYPE_MAPPING = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}
8+
9+
10+
def _convert_input_schema_to_tool_args(
11+
schema: dict[str, Any],
12+
) -> Tuple[dict[str, Any], dict[str, Type], dict[str, str]]:
13+
"""Convert an input schema to tool arguments compatible with DSPy Tool.
14+
15+
Args:
16+
schema: An input schema describing the tool's input parameters
17+
18+
Returns:
19+
A tuple of (args, arg_types, arg_desc) for DSPy Tool definition.
20+
"""
21+
args, arg_types, arg_desc = {}, {}, {}
22+
properties = schema.get("properties", None)
23+
if properties is None:
24+
return args, arg_types, arg_desc
25+
26+
required = schema.get("required", [])
27+
28+
for name, prop in properties.items():
29+
args[name] = prop
30+
# MCP tools are validated through jsonschema using args, so arg_types are not strictly required.
31+
arg_types[name] = TYPE_MAPPING.get(prop.get("type"), Any)
32+
arg_desc[name] = prop.get("description", "No description provided.")
33+
if name in required:
34+
arg_desc[name] += " (Required)"
35+
36+
return args, arg_types, arg_desc
37+
38+
39+
def _convert_mcp_tool_result(call_tool_result: "mcp.types.CallToolResult") -> Union[str, list[Any]]:
40+
from mcp.types import (
41+
TextContent,
42+
)
43+
44+
text_contents: list[TextContent] = []
45+
non_text_contents = []
46+
for content in call_tool_result.content:
47+
if isinstance(content, TextContent):
48+
text_contents.append(content)
49+
else:
50+
non_text_contents.append(content)
51+
52+
tool_content = [content.text for content in text_contents]
53+
if len(text_contents) == 1:
54+
tool_content = tool_content[0]
55+
56+
if call_tool_result.isError:
57+
raise RuntimeError(f"Failed to call a MCP tool: {tool_content}")
58+
59+
return tool_content or non_text_contents
60+
61+
62+
def convert_mcp_tool(session: "mcp.client.session.ClientSession", tool: "mcp.types.Tool") -> Tool:
63+
"""Build a DSPy tool from an MCP tool.
64+
65+
Args:
66+
session: The MCP session to use.
67+
tool: The MCP tool to convert.
68+
69+
Returns:
70+
A dspy Tool object.
71+
"""
72+
args, arg_types, arg_desc = _convert_input_schema_to_tool_args(tool.inputSchema)
73+
74+
# Convert the MCP tool and Session to a single async method
75+
async def func(*args, **kwargs):
76+
result = await session.call_tool(tool.name, arguments=kwargs)
77+
return _convert_mcp_tool_result(result)
78+
79+
return Tool(func=func, name=tool.name, desc=tool.description, args=args, arg_types=arg_types, arg_desc=arg_desc)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ dev = [
6262
"build>=1.0.3",
6363
"litellm>=1.60.3; sys_platform == 'win32'",
6464
"litellm[proxy]>=1.60.3; sys_platform != 'win32'",
65+
"mcp>=1.5.0; python_version >= '3.10'",
6566
]
67+
mcp = ["mcp; python_version >= '3.10'"]
6668

6769
[tool.setuptools.packages.find]
6870
where = ["."]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from mcp.server.fastmcp import FastMCP
2+
3+
mcp = FastMCP("test")
4+
5+
6+
@mcp.tool()
7+
def add(a: int, b: int) -> int:
8+
"""Add two numbers"""
9+
return a + b
10+
11+
12+
@mcp.tool()
13+
def hello(names: list[str]) -> str:
14+
"""Greet people"""
15+
return [f"Hello, {name}!" for name in names]
16+
17+
@mcp.tool()
18+
def wrong_tool():
19+
"""This tool raises an error"""
20+
raise ValueError("error!")
21+
22+
if __name__ == "__main__":
23+
mcp.run()

tests/utils/test_mcp.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
import asyncio
3+
4+
from mcp import ClientSession, StdioServerParameters
5+
from mcp.client.stdio import stdio_client
6+
7+
from dspy.utils.mcp import convert_mcp_tool
8+
9+
10+
@pytest.mark.asyncio
11+
async def test_convert_mcp_tool():
12+
server_params = StdioServerParameters(
13+
command="python",
14+
args=["tests/utils/resources/mcp_server.py"],
15+
env=None,
16+
)
17+
async with stdio_client(server_params) as (read, write):
18+
async with ClientSession(read, write) as session:
19+
await asyncio.wait_for(session.initialize(), timeout=5)
20+
response = await session.list_tools()
21+
22+
# Check add
23+
add_tool = convert_mcp_tool(session, response.tools[0])
24+
assert add_tool.name == "add"
25+
assert add_tool.desc == "Add two numbers"
26+
assert add_tool.args == {"a": {"title": "A", "type": "integer"}, "b": {"title": "B", "type": "integer"}}
27+
assert add_tool.arg_types == {"a": int, "b": int}
28+
assert add_tool.arg_desc == {
29+
"a": "No description provided. (Required)",
30+
"b": "No description provided. (Required)",
31+
}
32+
assert await add_tool.acall(a=1, b=2) == "3"
33+
34+
# Check hello
35+
hello_tool = convert_mcp_tool(session, response.tools[1])
36+
assert hello_tool.name == "hello"
37+
assert hello_tool.desc == "Greet people"
38+
assert hello_tool.args == {"names": {"title": "Names", "type": "array", "items": {"type": "string"}}}
39+
assert hello_tool.arg_types == {"names": list}
40+
assert hello_tool.arg_desc == {"names": "No description provided. (Required)"}
41+
assert await hello_tool.acall(names=["Bob", "Tom"]) == ["Hello, Bob!", "Hello, Tom!"]
42+
43+
# Check error handling
44+
error_tool = convert_mcp_tool(session, response.tools[2])
45+
assert error_tool.name == "wrong_tool"
46+
assert error_tool.desc == "This tool raises an error"
47+
with pytest.raises(
48+
RuntimeError, match="Failed to call a MCP tool: Error executing tool wrong_tool: error!"
49+
):
50+
await error_tool.acall()

uv.lock

Lines changed: 67 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)