@@ -65,33 +65,6 @@ def foo(x: int, y: str = "hello"):
6565
6666 self ._parse_function (func , arg_desc )
6767
68- def _resolve_pydantic_schema (self , model : type [BaseModel ]) -> dict :
69- """Recursively resolve Pydantic model schema, expanding all references."""
70- schema = model .model_json_schema ()
71-
72- # If there are no definitions to resolve, return the main schema
73- if "$defs" not in schema and "definitions" not in schema :
74- return schema
75-
76- def resolve_refs (obj : Any ) -> Any :
77- if not isinstance (obj , (dict , list )):
78- return obj
79-
80- if isinstance (obj , dict ):
81- if "$ref" in obj :
82- ref_path = obj ["$ref" ].split ("/" )[- 1 ]
83- return resolve_refs (schema ["$defs" ][ref_path ])
84- return {k : resolve_refs (v ) for k , v in obj .items ()}
85-
86- # Must be a list
87- return [resolve_refs (item ) for item in obj ]
88-
89- # Resolve all references in the main schema
90- resolved_schema = resolve_refs (schema )
91- # Remove the $defs key as it's no longer needed
92- resolved_schema .pop ("$defs" , None )
93- return resolved_schema
94-
9568 def _parse_function (self , func : Callable , arg_desc : dict [str , str ] = None ):
9669 """Helper method that parses a function to extract the name, description, and args.
9770
@@ -121,7 +94,7 @@ def _parse_function(self, func: Callable, arg_desc: dict[str, str] = None):
12194 origin = get_origin (v ) or v
12295 if isinstance (origin , type ) and issubclass (origin , BaseModel ):
12396 # Get json schema, and replace $ref with the actual schema
124- v_json_schema = self . _resolve_pydantic_schema ( v )
97+ v_json_schema = resolve_json_schema_reference ( v . model_json_schema () )
12598 args [k ] = v_json_schema
12699 else :
127100 args [k ] = TypeAdapter (v ).json_schema ()
@@ -197,3 +170,29 @@ def from_mcp_tool(cls, session: "mcp.client.session.ClientSession", tool: "mcp.t
197170 from dspy .utils .mcp import convert_mcp_tool
198171
199172 return convert_mcp_tool (session , tool )
173+
174+
175+ def resolve_json_schema_reference (schema : dict ) -> dict :
176+ """Recursively resolve json model schema, expanding all references."""
177+
178+ # If there are no definitions to resolve, return the main schema
179+ if "$defs" not in schema and "definitions" not in schema :
180+ return schema
181+
182+ def resolve_refs (obj : Any ) -> Any :
183+ if not isinstance (obj , (dict , list )):
184+ return obj
185+ if isinstance (obj , dict ):
186+ if "$ref" in obj :
187+ ref_path = obj ["$ref" ].split ("/" )[- 1 ]
188+ return resolve_refs (schema ["$defs" ][ref_path ])
189+ return {k : resolve_refs (v ) for k , v in obj .items ()}
190+
191+ # Must be a list
192+ return [resolve_refs (item ) for item in obj ]
193+
194+ # Resolve all references in the main schema
195+ resolved_schema = resolve_refs (schema )
196+ # Remove the $defs key as it's no longer needed
197+ resolved_schema .pop ("$defs" , None )
198+ return resolved_schema
0 commit comments