Skip to content

Commit e04dd62

Browse files
committed
✨ feat(common): integrate DeepWiki MCP tools with dynamic tool loading
- Add enable_deepwiki flag to Context with env var support and default qwen-turbo model - Implement MCP client management and tool retrieval for DeepWiki MCP server - Dynamically load DeepWiki tools in common.tools.get_tools if enabled in context - Modify agent graph to use dynamic tools node executing tools based on current config - Update Makefile test target to run unit and integration tests - Add langchain-mcp-adapters dependency to pyproject.toml - Refactor test fixtures and e2e tests to cover DeepWiki MCP tool usage strictly - Improve test react agent suite for tool usage, streaming, and context persistence
1 parent 9da47ce commit e04dd62

File tree

18 files changed

+1120
-452
lines changed

18 files changed

+1120
-452
lines changed

Makefile

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@ all: help
77
# TESTING
88
######################
99

10-
# Legacy test command (defaults to unit tests for backward compatibility)
11-
test:
12-
python -m pytest tests/unit_tests/
10+
# Legacy test command (defaults to unit and integration tests for backward compatibility)
11+
test: test_unit test_integration
1312

1413
# Specific test targets
1514
test_unit:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"python-dotenv>=1.0.1",
1717
"langchain-tavily>=0.1",
1818
"langchain-qwq>=0.2.1",
19+
"langchain-mcp-adapters>=0.1.9",
1920
]
2021

2122

src/common/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
from . import prompts
44
from .context import Context
55
from .models import create_qwen_model, get_supported_qwen_models
6-
from .tools import TOOLS, web_search
6+
from .tools import web_search
77
from .utils import load_chat_model
88

99
__all__ = [
1010
"Context",
1111
"create_qwen_model",
1212
"get_supported_qwen_models",
13-
"TOOLS",
1413
"web_search",
1514
"load_chat_model",
1615
"prompts",

src/common/context.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from __future__ import annotations
44

5-
import os
6-
from dataclasses import dataclass, field, fields
5+
from dataclasses import dataclass, field
76
from typing import Annotated
87

98
from . import prompts
@@ -22,7 +21,7 @@ class Context:
2221
)
2322

2423
model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
25-
default="qwen:qwen-plus",
24+
default="qwen:qwen-turbo",
2625
metadata={
2726
"description": "The name of the language model to use for the agent's main interactions. "
2827
"Should be in the form: provider:model-name."
@@ -36,11 +35,33 @@ class Context:
3635
},
3736
)
3837

38+
enable_deepwiki: bool = field(
39+
default=False,
40+
metadata={
41+
"description": "Whether to enable the DeepWiki MCP tool for accessing open source project documentation."
42+
},
43+
)
44+
3945
def __post_init__(self) -> None:
4046
"""Fetch env vars for attributes that were not passed as args."""
47+
import os
48+
from dataclasses import fields
49+
4150
for f in fields(self):
4251
if not f.init:
4352
continue
4453

45-
if getattr(self, f.name) == f.default:
46-
setattr(self, f.name, os.environ.get(f.name.upper(), f.default))
54+
current_value = getattr(self, f.name)
55+
default_value = f.default
56+
env_var_name = f.name.upper()
57+
env_value = os.environ.get(env_var_name)
58+
59+
# Only override with environment variable if current value equals default
60+
# This preserves explicit configuration from LangGraph configurable
61+
if current_value == default_value and env_value is not None:
62+
if isinstance(default_value, bool):
63+
# Handle boolean environment variables
64+
env_bool_value = env_value.lower() in ("true", "1", "yes", "on")
65+
setattr(self, f.name, env_bool_value)
66+
else:
67+
setattr(self, f.name, env_value)

src/common/mcp.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""MCP Client setup and management for LangGraph ReAct Agent."""
2+
3+
import logging
4+
from typing import Any, Callable, Dict, List, Optional, cast
5+
6+
from langchain_mcp_adapters.client import ( # type: ignore[import-untyped]
7+
MultiServerMCPClient,
8+
)
9+
10+
logger = logging.getLogger(__name__)
11+
12+
# Global MCP client and tools cache
13+
_mcp_client: Optional[MultiServerMCPClient] = None
14+
_mcp_tools_cache: Dict[str, List[Callable[..., Any]]] = {}
15+
16+
# MCP Server configurations
17+
MCP_SERVERS = {
18+
"deepwiki": {
19+
"url": "https://mcp.deepwiki.com/mcp",
20+
"transport": "streamable_http",
21+
},
22+
# Add more MCP servers here as needed
23+
# "context7": {
24+
# "url": "https://mcp.context7.com/sse",
25+
# "transport": "sse",
26+
# },
27+
}
28+
29+
30+
async def get_mcp_client(
31+
server_configs: Optional[Dict[str, Any]] = None,
32+
) -> Optional[MultiServerMCPClient]:
33+
"""Get or initialize the global MCP client with given server configurations."""
34+
global _mcp_client
35+
36+
if _mcp_client is None:
37+
configs = server_configs or MCP_SERVERS
38+
try:
39+
_mcp_client = MultiServerMCPClient(configs) # pyright: ignore[reportArgumentType]
40+
logger.info(f"Initialized MCP client with servers: {list(configs.keys())}")
41+
except Exception as e:
42+
logger.error("Failed to initialize MCP client: %s", e)
43+
return None
44+
return _mcp_client
45+
46+
47+
async def get_mcp_tools(server_name: str) -> List[Callable[..., Any]]:
48+
"""Get MCP tools for a specific server, initializing client if needed."""
49+
global _mcp_tools_cache
50+
51+
# Return cached tools if available
52+
if server_name in _mcp_tools_cache:
53+
return _mcp_tools_cache[server_name]
54+
55+
try:
56+
client = await get_mcp_client()
57+
if client is None:
58+
_mcp_tools_cache[server_name] = []
59+
return []
60+
61+
# Get all tools and filter by server (if tools have server metadata)
62+
all_tools = await client.get_tools()
63+
tools = cast(List[Callable[..., Any]], all_tools)
64+
65+
_mcp_tools_cache[server_name] = tools
66+
logger.info(f"Loaded {len(tools)} tools from MCP server '{server_name}'")
67+
return tools
68+
except Exception as e:
69+
logger.warning(f"Failed to load tools from MCP server '{server_name}': %s", e)
70+
_mcp_tools_cache[server_name] = []
71+
return []
72+
73+
74+
async def get_deepwiki_tools() -> List[Callable[..., Any]]:
75+
"""Get DeepWiki MCP tools."""
76+
return await get_mcp_tools("deepwiki")
77+
78+
79+
async def get_all_mcp_tools() -> List[Callable[..., Any]]:
80+
"""Get all tools from all configured MCP servers."""
81+
all_tools = []
82+
for server_name in MCP_SERVERS.keys():
83+
tools = await get_mcp_tools(server_name)
84+
all_tools.extend(tools)
85+
return all_tools
86+
87+
88+
def add_mcp_server(name: str, config: Dict[str, Any]) -> None:
89+
"""Add a new MCP server configuration."""
90+
MCP_SERVERS[name] = config
91+
# Clear client to force reinitialization with new config
92+
clear_mcp_cache()
93+
94+
95+
def clear_mcp_cache() -> None:
96+
"""Clear the MCP client and tools cache (useful for testing)."""
97+
global _mcp_client, _mcp_tools_cache
98+
_mcp_client = None
99+
_mcp_tools_cache = {}

src/common/tools.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
consider implementing more robust and specialized tools tailored to your needs.
77
"""
88

9+
import logging
910
from typing import Any, Callable, List, Optional, cast
1011

1112
from langchain_tavily import TavilySearch
1213
from langgraph.runtime import get_runtime
1314

1415
from common.context import Context
16+
from common.mcp import get_deepwiki_tools
17+
18+
logger = logging.getLogger(__name__)
1519

1620

1721
async def web_search(query: str) -> Optional[dict[str, Any]]:
@@ -26,4 +30,15 @@ async def web_search(query: str) -> Optional[dict[str, Any]]:
2630
return cast(dict[str, Any], await wrapped.ainvoke({"query": query}))
2731

2832

29-
TOOLS: List[Callable[..., Any]] = [web_search]
33+
async def get_tools() -> List[Callable[..., Any]]:
34+
"""Get all available tools based on configuration."""
35+
tools = [web_search]
36+
37+
runtime = get_runtime(Context)
38+
39+
if runtime.context.enable_deepwiki:
40+
deepwiki_tools = await get_deepwiki_tools()
41+
tools.extend(deepwiki_tools)
42+
logger.info(f"Loaded {len(deepwiki_tools)} deepwiki tools")
43+
44+
return tools

src/react_agent/graph.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
from datetime import UTC, datetime
77
from typing import Dict, List, Literal, cast
88

9-
from langchain_core.messages import AIMessage
9+
from langchain_core.messages import AIMessage, ToolMessage
1010
from langgraph.graph import StateGraph
1111
from langgraph.prebuilt import ToolNode
1212
from langgraph.runtime import Runtime
1313

1414
from common.context import Context
15-
from common.tools import TOOLS
15+
from common.tools import get_tools
1616
from common.utils import load_chat_model
1717
from react_agent.state import InputState, State
1818

@@ -33,8 +33,11 @@ async def call_model(
3333
Returns:
3434
dict: A dictionary containing the model's response message.
3535
"""
36+
# Get available tools based on configuration
37+
available_tools = await get_tools()
38+
3639
# Initialize the model with tool binding. Change the model or add more tools here.
37-
model = load_chat_model(runtime.context.model).bind_tools(TOOLS)
40+
model = load_chat_model(runtime.context.model).bind_tools(available_tools)
3841

3942
# Format the system prompt. Customize this to change the agent's behavior.
4043
system_message = runtime.context.system_prompt.format(
@@ -64,13 +67,33 @@ async def call_model(
6467
return {"messages": [response]}
6568

6669

70+
async def dynamic_tools_node(
71+
state: State, runtime: Runtime[Context]
72+
) -> Dict[str, List[ToolMessage]]:
73+
"""Execute tools dynamically based on configuration.
74+
75+
This function gets the available tools based on the current configuration
76+
and executes the requested tool calls from the last message.
77+
"""
78+
# Get available tools based on configuration
79+
available_tools = await get_tools()
80+
81+
# Create a ToolNode with the available tools
82+
tool_node = ToolNode(available_tools)
83+
84+
# Execute the tool node
85+
result = await tool_node.ainvoke(state)
86+
87+
return cast(Dict[str, List[ToolMessage]], result)
88+
89+
6790
# Define a new graph
6891

6992
builder = StateGraph(State, input_schema=InputState, context_schema=Context)
7093

7194
# Define the two nodes we will cycle between
7295
builder.add_node(call_model)
73-
builder.add_node("tools", ToolNode(TOOLS))
96+
builder.add_node("tools", dynamic_tools_node)
7497

7598
# Set the entrypoint as `call_model`
7699
# This means that this node is the first one called

tests/conftest.py

Lines changed: 3 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,12 @@
22

33
import os
44
from pathlib import Path
5-
from unittest.mock import AsyncMock, Mock
65

76
import pytest
87
from dotenv import load_dotenv
9-
from langchain_core.messages import AIMessage, HumanMessage
8+
from langchain_core.messages import HumanMessage
109
from langgraph_sdk import get_client
1110

12-
from tests.test_data import TestModels, TestUrls
13-
1411

1512
@pytest.fixture(scope="session", autouse=True)
1613
def load_env():
@@ -24,51 +21,17 @@ def load_env():
2421

2522
# Ensure required environment variables are available for tests
2623
# You can add fallback values or skip tests if keys are missing
27-
required_keys = ["OPENAI_API_KEY", "TAVILY_API_KEY"]
24+
required_keys = ["DASHSCOPE_API_KEY", "TAVILY_API_KEY"]
2825
missing_keys = [key for key in required_keys if not os.getenv(key)]
2926

3027
if missing_keys:
3128
pytest.skip(f"Missing required environment variables: {missing_keys}")
3229

3330

34-
@pytest.fixture
35-
def mock_model():
36-
"""Create a mock model for testing."""
37-
mock = Mock()
38-
mock.bind_tools.return_value = mock
39-
mock.ainvoke = AsyncMock()
40-
return mock
41-
42-
43-
@pytest.fixture
44-
def sample_ai_response():
45-
"""Create a sample AI response for testing."""
46-
return AIMessage(
47-
content="This is a test response from the AI model.", tool_calls=[]
48-
)
49-
50-
51-
@pytest.fixture
52-
def sample_human_message():
53-
"""Create a sample human message for testing."""
54-
return HumanMessage(content="This is a test message from human.")
55-
56-
57-
@pytest.fixture
58-
def test_context():
59-
"""Create a test context with common configuration."""
60-
from common.context import Context
61-
62-
return Context(
63-
model=TestModels.QWEN_PLUS,
64-
system_prompt="You are a helpful AI assistant for testing.",
65-
)
66-
67-
6831
@pytest.fixture
6932
async def langgraph_client():
7033
"""Create a LangGraph client for e2e testing."""
71-
return get_client(url=TestUrls.LANGGRAPH_LOCAL)
34+
return get_client(url="http://127.0.0.1:2024")
7235

7336

7437
@pytest.fixture
@@ -80,14 +43,6 @@ async def assistant_id(langgraph_client):
8043
return assistants[0]["assistant_id"]
8144

8245

83-
@pytest.fixture
84-
async def test_thread(langgraph_client):
85-
"""Create a test thread for e2e testing."""
86-
thread = await langgraph_client.threads.create()
87-
yield thread
88-
# Cleanup could be added here if needed
89-
90-
9146
class TestHelpers:
9247
"""Helper methods for common test operations."""
9348

0 commit comments

Comments
 (0)