diff --git a/src/langchain_mcp/toolkit.py b/src/langchain_mcp/toolkit.py
index dc0da0a..0a7b311 100644
--- a/src/langchain_mcp/toolkit.py
+++ b/src/langchain_mcp/toolkit.py
@@ -2,16 +2,18 @@
 # SPDX-License-Identifier: MIT
 
 import asyncio
+import sys
 import warnings
 from collections.abc import Callable
+from enum import Enum
+from typing import Any, Union
 
 import pydantic
 import pydantic_core
 import typing_extensions as t
 from langchain_core.tools.base import BaseTool, BaseToolkit, ToolException
 from mcp import ClientSession, ListToolsResult
-from pydantic.json_schema import JsonSchemaValue
-from pydantic_core import core_schema as cs
+from pydantic import BaseModel, Field, create_model
 
 
 class MCPToolkit(BaseToolkit):
@@ -42,59 +44,149 @@ def get_tools(self) -> list[BaseTool]:
                 session=self.session,
                 name=tool.name,
                 description=tool.description or "",
-                args_schema=create_schema_model(tool.inputSchema),
+                args_schema=create_model_from_schema(tool.inputSchema, tool.name),
             )
-            # list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools
             for tool in self._tools.tools
         ]
 
 
-TYPEMAP = {
+# Type alias for the values stored in TYPEMAP
+JsonSchemaType = type[Any]
+
+TYPEMAP: dict[str, JsonSchemaType] = {
+    "string": str,
     "integer": int,
     "number": float,
-    "array": list,
     "boolean": bool,
-    "string": str,
+    "array": list,
+    "object": dict,
     "null": type(None),
 }
 
-FIELD_DEFAULTS = {
-    int: 0,
-    float: 0.0,
-    list: [],
-    bool: False,
-    str: "",
-    type(None): None,
-}
+
+def resolve_ref(root_schema: dict[str, Any], ref: str) -> dict[str, Any]:
+    """Resolve a local $ref pointer against the root schema."""
+    if not ref.startswith("#/"):
+        raise ValueError(f"Only local references supported: {ref}")
+
+    path = ref.removeprefix("#/").split("/")
+    current = root_schema
+
+    for part in path:
+        if part not in current:
+            raise ValueError(f"Could not find {part} in schema. Available keys: {list(current.keys())}")
+        current = current[part]
+
+    return current
+
+
+def get_field_type(root_schema: dict[str, Any], type_def: dict[str, Any]) -> Any:
+    """Convert a JSON schema type definition to a Python/Pydantic type."""
+    # Handle non-dict type definitions (e.g. when additionalProperties is a boolean)
+    if not isinstance(type_def, dict):
+        return Any
+
+    if "$ref" in type_def:
+        resolve_ref(root_schema, type_def["$ref"])  # fail fast on dangling pointers
+        # Return a forward reference (string), since the referenced model may not
+        # exist yet; models built from $defs are registered under their $defs key,
+        # which is the final segment of the pointer.
+        return type_def["$ref"].split("/")[-1]
+
+    if "enum" in type_def:
+        # Create an Enum class for this field
+        enum_name = f"Enum_{hash(str(type_def['enum']))}"
+        enum_values = {str(v): v for v in type_def["enum"]}
+        return Enum(enum_name, enum_values)
+
+    if "anyOf" in type_def:
+        types = [get_field_type(root_schema, sub) for sub in type_def["anyOf"]]
+        optional = type(None) in types
+        types = [typ for typ in types if typ is not type(None)]  # noqa: E721
+        if optional:
+            # Union[...] also accepts string forward references, unlike the | operator
+            return Union[tuple(types) + (type(None),)]  # noqa: UP007
+        if len(types) == 1:
+            return types[0]
+        return Union[tuple(types)]  # noqa: UP007
+
+    if "type" not in type_def:
+        return Any
+
+    type_name = type_def["type"]
+    if type_name == "array":
+        if "items" in type_def:
+            item_type = get_field_type(root_schema, type_def["items"])
+            return list[item_type]  # type: ignore
+        return list[Any]
+
+    if type_name == "object":
+        if "additionalProperties" in type_def:
+            additional_props = type_def["additionalProperties"]
+            # additionalProperties may be a boolean rather than a schema
+            if isinstance(additional_props, bool):
+                return dict[str, Any]
+            value_type = get_field_type(root_schema, additional_props)
+            return dict[str, value_type]  # type: ignore
+        return dict[str, Any]
+
+    return TYPEMAP.get(type_name, Any)
+
+
+def create_model_from_schema(
+    schema: dict[str, Any],
+    name: str,
+    root_schema: dict[str, Any] | None = None,
+    created_models: set[str] | None = None,
+) -> type[BaseModel]:
+    """Create a Pydantic model from a JSON schema definition.
+
+    Args:
+        schema: The schema for this specific model
+        name: Name for the model
+        root_schema: The complete schema containing all definitions
+        created_models: Set tracking which models have already been created
+    """
+    # Initialize tracking of created models
+    if created_models is None:
+        created_models = set()
 
-def configure_field(name: str, type_: dict[str, t.Any], required: list[str]) -> tuple[type, t.Any]:
-    field_type = TYPEMAP[type_["type"]]
-    default_ = FIELD_DEFAULTS.get(field_type) if name not in required else ...
-    return field_type, default_
+    # If root_schema is not provided, use the current schema as root
+    if root_schema is None:
+        root_schema = schema
 
+    # If we've already created this model, return its class from the module
+    if name in created_models:
+        return getattr(sys.modules[__name__], name)
 
-def create_schema_model(schema: dict[str, t.Any]) -> type[pydantic.BaseModel]:
-    # Create a new model class that returns our JSON schema.
-    # LangChain requires a BaseModel class.
-    class SchemaBase(pydantic.BaseModel):
-        model_config = pydantic.ConfigDict(extra="allow")
+    # Add this model to created_models before processing, to handle circular references
+    created_models.add(name)
 
-        @t.override
-        @classmethod
-        def __get_pydantic_json_schema__(
-            cls, core_schema: cs.CoreSchema, handler: pydantic.GetJsonSchemaHandler
-        ) -> JsonSchemaValue:
-            return schema
+    # Create referenced models first if we have definitions
+    if "$defs" in root_schema:
+        for model_name, model_schema in root_schema["$defs"].items():
+            if model_schema.get("type") == "object" and model_name not in created_models:
+                create_model_from_schema(model_schema, model_name, root_schema, created_models)
 
-    # Since this langchain patch, we need to synthesize pydantic fields from the schema
-    # https://github.com/langchain-ai/langchain/commit/033ac417609297369eb0525794d8b48a425b8b33
+    properties = schema.get("properties", {})
     required = schema.get("required", [])
-    fields: dict[str, t.Any] = {
-        name: configure_field(name, type_, required) for name, type_ in schema["properties"].items()
-    }
-    return pydantic.create_model("Schema", __base__=SchemaBase, **fields)
+    fields: dict[str, tuple[Any, Any]] = {}
+    for field_name, field_schema in properties.items():
+        field_type = get_field_type(root_schema, field_schema)
+        default = field_schema.get("default", ...)
+        if field_name not in required and default is ...:
+            # Union[...] rather than |, so string forward references stay valid
+            field_type = Union[field_type, type(None)]  # noqa: UP007
+            default = None
+
+        description = field_schema.get("description", "")
+        fields[field_name] = (field_type, Field(default=default, description=description))
+
+    model = create_model(name, **fields)  # type: ignore
+    # Register the model in this module's namespace so forward references resolve
+    setattr(sys.modules[__name__], name, model)
+    return model
 
 
 class MCPTool(BaseTool):
@@ -106,7 +198,7 @@ class MCPTool(BaseTool):
     handle_tool_error: bool | str | Callable[[ToolException], str] | None = True
 
     @t.override
-    def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
+    def _run(self, *args: Any, **kwargs: Any) -> Any:
         warnings.warn(
             "Invoke this tool asynchronously using `ainvoke`. This method exists only to satisfy standard tests.",
             stacklevel=1,
@@ -114,7 +206,7 @@ def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
         return asyncio.run(self._arun(*args, **kwargs))
 
     @t.override
-    async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
+    async def _arun(self, *args: Any, **kwargs: Any) -> Any:
         result = await self.session.call_tool(self.name, arguments=kwargs)
         content = pydantic_core.to_json(result.content).decode()
         if result.isError:
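
For review, a quick sketch of how the new helpers behave. The schema fragments and the `root` dict below are invented for illustration; the import path just mirrors the file this diff touches, assuming the package is installed locally.

```python
# Illustrative only: exercises resolve_ref and get_field_type from the diff above.
from langchain_mcp.toolkit import get_field_type, resolve_ref

root = {"$defs": {"Point": {"type": "object", "title": "Point", "properties": {"x": {"type": "number"}}}}}

# resolve_ref walks a local "#/..." JSON pointer through the root schema
assert resolve_ref(root, "#/$defs/Point")["title"] == "Point"

# $ref yields a forward-reference string naming the $defs entry
assert get_field_type(root, {"$ref": "#/$defs/Point"}) == "Point"

# Primitive and container types map onto TYPEMAP and generic aliases
assert get_field_type(root, {"type": "string"}) is str
assert get_field_type(root, {"type": "array", "items": {"type": "integer"}}) == list[int]
assert get_field_type(root, {"type": "object", "additionalProperties": {"type": "number"}}) == dict[str, float]

# anyOf containing "null" collapses to an optional Union
print(get_field_type(root, {"anyOf": [{"type": "string"}, {"type": "null"}]}))  # typing.Optional[str]
```

Note that `$ref` deliberately returns a string: the referenced class may not exist yet, and pydantic resolves the name later against the module namespace where the $defs models are registered.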
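
And an end-to-end sketch of the conversion `get_tools()` now performs per tool. The `Forecast`/`Location` schema is made up for the example; only `create_model_from_schema` comes from this diff.

```python
# Illustrative only: builds an args_schema model the way get_tools() now does.
from langchain_mcp.toolkit import create_model_from_schema

schema = {
    "type": "object",
    "properties": {
        "city": {"type": "string", "description": "City to forecast"},
        "days": {"type": "integer", "default": 1},
        "units": {"enum": ["metric", "imperial"]},
        "location": {"$ref": "#/$defs/Location"},
    },
    "required": ["city"],
    "$defs": {
        "Location": {
            "type": "object",
            "properties": {"lat": {"type": "number"}, "lon": {"type": "number"}},
            "required": ["lat", "lon"],
        }
    },
}

Forecast = create_model_from_schema(schema, "Forecast")
args = Forecast(city="Berlin", location={"lat": 52.5, "lon": 13.4})

print(args.days)      # 1    -- schema default is kept
print(args.units)     # None -- optional field with no default becomes Optional
print(args.location)  # lat=52.5 lon=13.4 -- nested model built from $defs
```

Because the `$defs` models are registered in the module namespace before the parent model is created, the `"Location"` forward reference resolves when pydantic builds `Forecast`.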