|
| 1 | +"""MCP Client setup and management for LangGraph ReAct Agent.""" |
| 2 | + |
| 3 | +import logging |
| 4 | +from typing import Any, Callable, Dict, List, Optional, cast |
| 5 | + |
| 6 | +from langchain_mcp_adapters.client import ( # type: ignore[import-untyped] |
| 7 | + MultiServerMCPClient, |
| 8 | +) |
| 9 | + |
| 10 | +logger = logging.getLogger(__name__) |
| 11 | + |
| 12 | +# Global MCP client and tools cache |
| 13 | +_mcp_client: Optional[MultiServerMCPClient] = None |
| 14 | +_mcp_tools_cache: Dict[str, List[Callable[..., Any]]] = {} |
| 15 | + |
| 16 | +# MCP Server configurations |
| 17 | +MCP_SERVERS = { |
| 18 | + "deepwiki": { |
| 19 | + "url": "https://mcp.deepwiki.com/mcp", |
| 20 | + "transport": "streamable_http", |
| 21 | + }, |
| 22 | + # Add more MCP servers here as needed |
| 23 | + # "context7": { |
| 24 | + # "url": "https://mcp.context7.com/sse", |
| 25 | + # "transport": "sse", |
| 26 | + # }, |
| 27 | +} |
| 28 | + |
| 29 | + |
| 30 | +async def get_mcp_client( |
| 31 | + server_configs: Optional[Dict[str, Any]] = None, |
| 32 | +) -> Optional[MultiServerMCPClient]: |
| 33 | + """Get or initialize the global MCP client with given server configurations.""" |
| 34 | + global _mcp_client |
| 35 | + |
| 36 | + if _mcp_client is None: |
| 37 | + configs = server_configs or MCP_SERVERS |
| 38 | + try: |
| 39 | + _mcp_client = MultiServerMCPClient(configs) # pyright: ignore[reportArgumentType] |
| 40 | + logger.info(f"Initialized MCP client with servers: {list(configs.keys())}") |
| 41 | + except Exception as e: |
| 42 | + logger.error("Failed to initialize MCP client: %s", e) |
| 43 | + return None |
| 44 | + return _mcp_client |
| 45 | + |
| 46 | + |
| 47 | +async def get_mcp_tools(server_name: str) -> List[Callable[..., Any]]: |
| 48 | + """Get MCP tools for a specific server, initializing client if needed.""" |
| 49 | + global _mcp_tools_cache |
| 50 | + |
| 51 | + # Return cached tools if available |
| 52 | + if server_name in _mcp_tools_cache: |
| 53 | + return _mcp_tools_cache[server_name] |
| 54 | + |
| 55 | + try: |
| 56 | + client = await get_mcp_client() |
| 57 | + if client is None: |
| 58 | + _mcp_tools_cache[server_name] = [] |
| 59 | + return [] |
| 60 | + |
| 61 | + # Get all tools and filter by server (if tools have server metadata) |
| 62 | + all_tools = await client.get_tools() |
| 63 | + tools = cast(List[Callable[..., Any]], all_tools) |
| 64 | + |
| 65 | + _mcp_tools_cache[server_name] = tools |
| 66 | + logger.info(f"Loaded {len(tools)} tools from MCP server '{server_name}'") |
| 67 | + return tools |
| 68 | + except Exception as e: |
| 69 | + logger.warning(f"Failed to load tools from MCP server '{server_name}': %s", e) |
| 70 | + _mcp_tools_cache[server_name] = [] |
| 71 | + return [] |
| 72 | + |
| 73 | + |
| 74 | +async def get_deepwiki_tools() -> List[Callable[..., Any]]: |
| 75 | + """Get DeepWiki MCP tools.""" |
| 76 | + return await get_mcp_tools("deepwiki") |
| 77 | + |
| 78 | + |
| 79 | +async def get_all_mcp_tools() -> List[Callable[..., Any]]: |
| 80 | + """Get all tools from all configured MCP servers.""" |
| 81 | + all_tools = [] |
| 82 | + for server_name in MCP_SERVERS.keys(): |
| 83 | + tools = await get_mcp_tools(server_name) |
| 84 | + all_tools.extend(tools) |
| 85 | + return all_tools |
| 86 | + |
| 87 | + |
| 88 | +def add_mcp_server(name: str, config: Dict[str, Any]) -> None: |
| 89 | + """Add a new MCP server configuration.""" |
| 90 | + MCP_SERVERS[name] = config |
| 91 | + # Clear client to force reinitialization with new config |
| 92 | + clear_mcp_cache() |
| 93 | + |
| 94 | + |
| 95 | +def clear_mcp_cache() -> None: |
| 96 | + """Clear the MCP client and tools cache (useful for testing).""" |
| 97 | + global _mcp_client, _mcp_tools_cache |
| 98 | + _mcp_client = None |
| 99 | + _mcp_tools_cache = {} |
0 commit comments