Skip to content

Commit 2936f84

Browse files
committed
feat: 添加 SqlReporterAgent,支持生成 SQL 查询报告并调用图表工具
1 parent a8491c5 commit 2936f84

File tree

8 files changed

+94
-19
lines changed

8 files changed

+94
-19
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
语析是一个功能强大的智能问答平台,融合了 RAG 知识库与知识图谱技术,基于 LangGraph + Vue.js + FastAPI + LightRAG 架构构建。
1919

20+
---
21+
2022
🙏 感谢 Star ~ ⭐⭐⭐⭐⭐
2123

2224
详细文档请查看全新的 [**📄文档中心**](https://xerrors.github.io/Yuxi-Know/)[📽️ 点击查看视频演示 v0.2](https://www.bilibili.com/video/BV1ETedzREgY/?share_source=copy_web&vd_source=37b0bdbf95b72ea38b2dc959cfadc4d8)

server/routers/chat_router.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,11 @@ async def get_tools(agent_id: str, current_user: User = Depends(get_required_use
408408
if not (agent := agent_manager.get_agent(agent_id)):
409409
raise HTTPException(status_code=404, detail=f"智能体 {agent_id} 不存在")
410410

411-
if hasattr(agent, "get_tools"):
412-
tools = agent.get_tools()
411+
if hasattr(agent, "get_tools") and callable(agent.get_tools):
412+
if asyncio.iscoroutinefunction(agent.get_tools):
413+
tools = await agent.get_tools()
414+
else:
415+
tools = agent.get_tools()
413416
else:
414417
tools = get_buildin_tools()
415418

src/agents/common/mcp.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,19 @@ async def get_mcp_client(
5858
return None
5959

6060

61-
async def get_mcp_tools(server_name: str) -> list[Callable[..., Any]]:
61+
async def get_mcp_tools(server_name: str, additional_servers: dict[str, dict] = None) -> list[Callable[..., Any]]:
6262
"""Get MCP tools for a specific server, initializing client if needed."""
6363
global _mcp_tools_cache
6464

6565
# Return cached tools if available
6666
if server_name in _mcp_tools_cache:
6767
return _mcp_tools_cache[server_name]
6868

69+
mcp_servers = MCP_SERVERS | (additional_servers or {})
70+
6971
try:
70-
assert server_name in MCP_SERVERS, f"Server {server_name} not found in MCP_SERVERS"
71-
client = await get_mcp_client({server_name: MCP_SERVERS[server_name]})
72+
assert server_name in mcp_servers, f"Server {server_name} not found in MCP_SERVERS"
73+
client = await get_mcp_client({server_name: mcp_servers[server_name]})
7274
if client is None:
7375
return []
7476

@@ -86,7 +88,6 @@ async def get_mcp_tools(server_name: str) -> list[Callable[..., Any]]:
8688
logger.error(f"Failed to load tools from MCP server '{server_name}': {e}")
8789
return []
8890

89-
9091
async def get_all_mcp_tools() -> list[Callable[..., Any]]:
9192
"""Get all tools from all configured MCP servers."""
9293
all_tools = []

src/agents/common/toolkits/mysql/tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class TableListModel(BaseModel):
4444
pass
4545

4646

47-
@tool(args_schema=TableListModel)
47+
@tool(name_or_callable="查询表名", args_schema=TableListModel)
4848
def mysql_list_tables() -> str:
4949
"""获取数据库中的所有表名
5050
@@ -94,7 +94,7 @@ class TableDescribeModel(BaseModel):
9494
table_name: str = Field(description="要查询的表名", example="users")
9595

9696

97-
@tool(args_schema=TableDescribeModel)
97+
@tool(name_or_callable="描述表", args_schema=TableDescribeModel)
9898
def mysql_describe_table(table_name: Annotated[str, "要查询结构的表名"]) -> str:
9999
"""获取指定表的详细结构信息
100100
@@ -168,7 +168,7 @@ class QueryModel(BaseModel):
168168
timeout: int | None = Field(default=10, description="查询超时时间(秒),默认10秒,最大60秒", ge=1, le=60)
169169

170170

171-
@tool(args_schema=QueryModel)
171+
@tool(name_or_callable="执行 SQL 查询", args_schema=QueryModel)
172172
def mysql_query(
173173
sql: Annotated[str, "要执行的SQL查询语句(只能是SELECT语句)"],
174174
limit: Annotated[int | None, "限制返回的最大行数,默认100,最大1000"] = 100,

src/agents/common/tools.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ def gen_tool_info(tools) -> list[dict[str, Any]]:
146146
}
147147

148148
if hasattr(tool_obj, "args_schema") and tool_obj.args_schema:
149-
schema = tool_obj.args_schema.schema()
149+
if isinstance(tool_obj.args_schema, dict):
150+
schema = tool_obj.args_schema
151+
else:
152+
schema = tool_obj.args_schema.schema()
153+
150154
for arg_name, arg_info in schema.get("properties", {}).items():
151155
info["args"].append(
152156
{
@@ -161,7 +165,8 @@ def gen_tool_info(tools) -> list[dict[str, Any]]:
161165

162166
except Exception as e:
163167
logger.error(
164-
f"Failed to process tool {getattr(tool_obj, 'name', 'unknown')}: {e}\n{traceback.format_exc()}"
168+
f"Failed to process tool {getattr(tool_obj, 'name', 'unknown')}: {e}\n{traceback.format_exc()}. "
169+
f"Details: {dict(tool_obj.__dict__)}"
165170
)
166171
continue
167172

src/agents/react/graph.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@
33
from langchain.agents import create_agent
44
from langchain.agents.middleware import ModelRequest, ModelResponse, dynamic_prompt, wrap_model_call
55

6-
from src import config as sys_config
76
from src.agents.common.base import BaseAgent
87
from src.agents.common.models import load_chat_model
98
from src.agents.common.tools import get_buildin_tools
109
from src.utils import logger
1110

12-
model = load_chat_model("siliconflow/Qwen/Qwen3-235B-A22B-Instruct-2507")
13-
1411

1512
@dynamic_prompt
1613
def context_aware_prompt(request: ModelRequest) -> str:
@@ -34,9 +31,6 @@ class ReActAgent(BaseAgent):
3431

3532
def __init__(self, **kwargs):
3633
super().__init__(**kwargs)
37-
self.graph = None
38-
self.workdir = Path(sys_config.save_dir) / "agents" / self.module_name
39-
self.workdir.mkdir(parents=True, exist_ok=True)
4034

4135
def get_tools(self):
4236
return get_buildin_tools()
@@ -47,12 +41,11 @@ async def get_graph(self, **kwargs):
4741

4842
# 创建 ReActAgent
4943
graph = create_agent(
50-
model=model,
44+
model=load_chat_model("siliconflow/Qwen/Qwen3-235B-A22B-Instruct-2507"), # 实际会被覆盖
5145
tools=self.get_tools(),
5246
middleware=[context_aware_prompt, context_based_model],
5347
checkpointer=await self._get_checkpointer(),
5448
)
5549

5650
self.graph = graph
57-
logger.info("ReActAgent 使用内存 checkpointer 构建成功")
5851
return graph

src/agents/reporter/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .graph import SqlReporterAgent
2+
3+
4+
__all__ = ["SqlReporterAgent"]

src/agents/reporter/graph.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import textwrap
2+
from pathlib import Path
3+
4+
from langchain.agents import create_agent
5+
from langchain.agents.middleware import ModelRequest, ModelResponse, dynamic_prompt, wrap_model_call
6+
7+
from src.agents.common.base import BaseAgent
8+
from src.agents.common.models import load_chat_model
9+
from src.agents.common.mcp import get_mcp_tools
10+
from src.agents.common.toolkits.mysql import get_mysql_tools
11+
from src.utils import logger
12+
13+
_mcp_servers = {
14+
"mcp-server-chart": {
15+
"url": "https://mcp.api-inference.modelscope.net/9993ae42524c4c/mcp",
16+
"transport": "streamable_http",
17+
},
18+
}
19+
20+
@dynamic_prompt
21+
def context_aware_prompt(request: ModelRequest) -> str:
22+
user_prompt = request.runtime.context.system_prompt
23+
agent_prompt = user_prompt + textwrap.dedent("""
24+
You are an SQL reporting assistant. Your task is to generate SQL queries based on user requests
25+
and provide insights from the database. Use the tools provided to you to answer the questions.
26+
""")
27+
28+
return agent_prompt
29+
30+
31+
@wrap_model_call
32+
async def context_based_model(request: ModelRequest, handler) -> ModelResponse:
33+
# 从 runtime context 读取配置
34+
model_spec = request.runtime.context.model
35+
model = load_chat_model(model_spec)
36+
37+
request = request.override(model=model)
38+
return await handler(request)
39+
40+
41+
class SqlReporterAgent(BaseAgent):
42+
name = "SQL 报告助手"
43+
description = "一个能够生成 SQL 查询报告的智能体助手。同时调用 Charts MCP 生成图表。"
44+
45+
def __init__(self, **kwargs):
46+
super().__init__(**kwargs)
47+
48+
async def get_tools(self):
49+
chart_tools = await get_mcp_tools("mcp-server-chart", additional_servers=_mcp_servers)
50+
mysql_tools = get_mysql_tools()
51+
return chart_tools + mysql_tools
52+
53+
async def get_graph(self, **kwargs):
54+
if self.graph:
55+
return self.graph
56+
57+
# 创建 SqlReporterAgent
58+
graph = create_agent(
59+
model=load_chat_model("siliconflow/Qwen/Qwen3-235B-A22B-Instruct-2507"),
60+
tools=await self.get_tools(),
61+
middleware=[context_aware_prompt, context_based_model],
62+
checkpointer=await self._get_checkpointer(),
63+
)
64+
65+
self.graph = graph
66+
logger.info("SqlReporterAgent 构建成功")
67+
return graph

0 commit comments

Comments
 (0)