Skip to content

Commit 1c27898

Browse files
support complex args in dspy mcp (#8142)
1 parent 348fedd commit 1c27898

File tree

4 files changed

+88
-33
lines changed

4 files changed

+88
-33
lines changed

dspy/primitives/tool.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -65,33 +65,6 @@ def foo(x: int, y: str = "hello"):
6565

6666
self._parse_function(func, arg_desc)
6767

68-
def _resolve_pydantic_schema(self, model: type[BaseModel]) -> dict:
69-
"""Recursively resolve Pydantic model schema, expanding all references."""
70-
schema = model.model_json_schema()
71-
72-
# If there are no definitions to resolve, return the main schema
73-
if "$defs" not in schema and "definitions" not in schema:
74-
return schema
75-
76-
def resolve_refs(obj: Any) -> Any:
77-
if not isinstance(obj, (dict, list)):
78-
return obj
79-
80-
if isinstance(obj, dict):
81-
if "$ref" in obj:
82-
ref_path = obj["$ref"].split("/")[-1]
83-
return resolve_refs(schema["$defs"][ref_path])
84-
return {k: resolve_refs(v) for k, v in obj.items()}
85-
86-
# Must be a list
87-
return [resolve_refs(item) for item in obj]
88-
89-
# Resolve all references in the main schema
90-
resolved_schema = resolve_refs(schema)
91-
# Remove the $defs key as it's no longer needed
92-
resolved_schema.pop("$defs", None)
93-
return resolved_schema
94-
9568
def _parse_function(self, func: Callable, arg_desc: dict[str, str] = None):
9669
"""Helper method that parses a function to extract the name, description, and args.
9770
@@ -121,7 +94,7 @@ def _parse_function(self, func: Callable, arg_desc: dict[str, str] = None):
12194
origin = get_origin(v) or v
12295
if isinstance(origin, type) and issubclass(origin, BaseModel):
12396
# Get json schema, and replace $ref with the actual schema
124-
v_json_schema = self._resolve_pydantic_schema(v)
97+
v_json_schema = resolve_json_schema_reference(v.model_json_schema())
12598
args[k] = v_json_schema
12699
else:
127100
args[k] = TypeAdapter(v).json_schema()
@@ -197,3 +170,29 @@ def from_mcp_tool(cls, session: "mcp.client.session.ClientSession", tool: "mcp.t
197170
from dspy.utils.mcp import convert_mcp_tool
198171

199172
return convert_mcp_tool(session, tool)
173+
174+
175+
def resolve_json_schema_reference(schema: dict) -> dict:
176+
"""Recursively resolve json model schema, expanding all references."""
177+
178+
# If there are no definitions to resolve, return the main schema
179+
if "$defs" not in schema and "definitions" not in schema:
180+
return schema
181+
182+
def resolve_refs(obj: Any) -> Any:
183+
if not isinstance(obj, (dict, list)):
184+
return obj
185+
if isinstance(obj, dict):
186+
if "$ref" in obj:
187+
ref_path = obj["$ref"].split("/")[-1]
188+
return resolve_refs(schema["$defs"][ref_path])
189+
return {k: resolve_refs(v) for k, v in obj.items()}
190+
191+
# Must be a list
192+
return [resolve_refs(item) for item in obj]
193+
194+
# Resolve all references in the main schema
195+
resolved_schema = resolve_refs(schema)
196+
# Remove the $defs key as it's no longer needed
197+
resolved_schema.pop("$defs", None)
198+
return resolved_schema

dspy/utils/mcp.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Any, Tuple, Type, Union, TYPE_CHECKING
2-
from dspy.primitives.tool import Tool
1+
from typing import TYPE_CHECKING, Any, Tuple, Type, Union
2+
3+
from dspy.primitives.tool import Tool, resolve_json_schema_reference
34

45
if TYPE_CHECKING:
56
import mcp
@@ -25,7 +26,11 @@ def _convert_input_schema_to_tool_args(
2526

2627
required = schema.get("required", [])
2728

29+
defs = schema.get("$defs", {})
30+
2831
for name, prop in properties.items():
32+
if len(defs) > 0:
33+
prop = resolve_json_schema_reference({"$defs": defs, **prop})
2934
args[name] = prop
3035
# MCP tools are validated through jsonschema using args, so arg_types are not strictly required.
3136
arg_types[name] = TYPE_MAPPING.get(prop.get("type"), Any)
@@ -37,9 +42,7 @@ def _convert_input_schema_to_tool_args(
3742

3843

3944
def _convert_mcp_tool_result(call_tool_result: "mcp.types.CallToolResult") -> Union[str, list[Any]]:
40-
from mcp.types import (
41-
TextContent,
42-
)
45+
from mcp.types import TextContent
4346

4447
text_contents: list[TextContent] = []
4548
non_text_contents = []

tests/utils/resources/mcp_server.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
from mcp.server.fastmcp import FastMCP
2+
from pydantic import BaseModel
23

34
mcp = FastMCP("test")
45

56

7+
class Profile(BaseModel):
8+
name: str
9+
age: int
10+
11+
12+
class Account(BaseModel):
13+
profile: Profile
14+
account_id: str
15+
16+
617
@mcp.tool()
718
def add(a: int, b: int) -> int:
819
"""Add two numbers"""
@@ -14,10 +25,18 @@ def hello(names: list[str]) -> str:
1425
"""Greet people"""
1526
return [f"Hello, {name}!" for name in names]
1627

28+
1729
@mcp.tool()
1830
def wrong_tool():
1931
"""This tool raises an error"""
2032
raise ValueError("error!")
2133

34+
35+
@mcp.tool()
36+
def get_account_name(account: Account):
37+
"""This extracts the name from account"""
38+
return account.profile.name
39+
40+
2241
if __name__ == "__main__":
2342
mcp.run()

tests/utils/test_mcp.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,37 @@ async def test_convert_mcp_tool():
4848
RuntimeError, match="Failed to call a MCP tool: Error executing tool wrong_tool: error!"
4949
):
5050
await error_tool.acall()
51+
52+
# Check nested Pydantic arg
53+
nested_pydantic_tool = convert_mcp_tool(session, response.tools[3])
54+
55+
assert nested_pydantic_tool.name == "get_account_name"
56+
assert nested_pydantic_tool.desc == "This extracts the name from account"
57+
assert nested_pydantic_tool.args == {
58+
"account": {
59+
"title": "Account",
60+
"type": "object",
61+
"required": ["profile", "account_id"],
62+
"properties": {
63+
"profile": {
64+
"title": "Profile",
65+
"type": "object",
66+
"properties": {
67+
"name": {"title": "Name", "type": "string"},
68+
"age": {"title": "Age", "type": "integer"},
69+
},
70+
"required": ["name", "age"],
71+
},
72+
"account_id": {"title": "Account Id", "type": "string"},
73+
},
74+
}
75+
}
76+
account_in_json = {
77+
"profile": {
78+
"name": "Bob",
79+
"age": 20,
80+
},
81+
"account_id": "123",
82+
}
83+
result = await nested_pydantic_tool.acall(account=account_in_json)
84+
assert result == "Bob"

0 commit comments

Comments
 (0)