Skip to content

Commit 1a753ef

Browse files
authored
Merge pull request #114 from apconw/dev
v1.1.7
2 parents 1b44250 + bfde1e2 commit 1a753ef

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

65 files changed

+3078
-939
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ include web/Makefile
55
SERVER_PROJECT_NAME = sanic-web
66

77
# 服务端 Docker 镜像标签
8-
SERVER_DOCKER_IMAGE = apconw/$(SERVER_PROJECT_NAME):1.1.6
8+
SERVER_DOCKER_IMAGE = apconw/$(SERVER_PROJECT_NAME):1.1.7
99

1010
# 阿里云镜像仓库地址 (需要根据实际情况修改)
1111
ALIYUN_REGISTRY = crpi-7xkxsdc0iki61l0q.cn-hangzhou.personal.cr.aliyuncs.com
@@ -27,7 +27,7 @@ docker-build-server-multi:
2727

2828

2929
# 构建服务端arm64/amd64架构镜像并推送至阿里云镜像仓库
30-
docker-build-aliyun-multi:
31-
docker buildx build --platform linux/amd64,linux/arm64 --push -t $(ALIYUN_IMAGE_NAME):1.1.6 -f ./docker/Dockerfile .
30+
docker-build-aliyun-server-multi:
31+
docker buildx build --platform linux/amd64,linux/arm64 --push -t $(ALIYUN_IMAGE_NAME):1.1.7 -f ./docker/Dockerfile .
3232

3333
.PHONY: web-build service-build

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@
9393

9494
| 技术支持方式 | 赞助 |
9595
|:------------------------------------------------|:----------:|
96-
| 一对一技术支持 我将亲自远程帮您 **配置环境并部署** **讲解项目架构&大模型学习资料** | **300元/次** |
96+
| 一对一技术支持 我将亲自远程帮您 **配置环境并部署** **讲解项目架构&大模型学习资料** | **100元/次** |
9797
| 需求开发支持 **具体场景Dify画布开发** **下面开源Dify画布前后端适配开发** | **500元起** |
9898

9999

agent/langgraph_react_agent.py

Lines changed: 52 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import os
55
import traceback
66
from typing import Optional
7+
from uuid import uuid4
78

89
from langchain_core.messages import SystemMessage, HumanMessage
910
from langchain_core.messages.utils import trim_messages
@@ -13,7 +14,7 @@
1314
from langgraph.prebuilt import create_react_agent
1415

1516
from constants.code_enum import DataTypeEnum, DiFyAppEnum
16-
from services.user_service import add_user_record
17+
from services.user_service import add_user_record, decode_jwt_token
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -83,6 +84,9 @@ def __init__(self):
8384
# 全局checkpointer用于持久化所有用户的对话状态
8485
self.checkpointer = InMemorySaver()
8586

87+
# 存储运行中的任务
88+
self.running_tasks = {}
89+
8690
@staticmethod
8791
def _create_response(
8892
content: str, message_type: str = "continue", data_type: str = DataTypeEnum.ANSWER.value[0]
@@ -126,6 +130,13 @@ async def run_agent(
126130
:param user_token:
127131
:return:
128132
"""
133+
134+
# 获取用户信息 标识对话状态
135+
user_dict = await decode_jwt_token(user_token)
136+
task_id = user_dict["id"]
137+
task_context = {"cancelled": False}
138+
self.running_tasks[task_id] = task_context
139+
129140
try:
130141
t02_answer_data = []
131142

@@ -150,6 +161,15 @@ async def run_agent(
150161
config=config,
151162
stream_mode="messages",
152163
):
164+
# 检查是否已取消
165+
if self.running_tasks[task_id]["cancelled"]:
166+
await response.write(
167+
self._create_response("\n> 这条消息已停止", "info", DataTypeEnum.ANSWER.value[0])
168+
)
169+
# 发送最终停止确认消息
170+
await response.write(self._create_response("", "end", DataTypeEnum.STREAM_END.value[0]))
171+
break
172+
153173
# print(message_chunk)
154174
# 工具输出
155175
if metadata["langgraph_node"] == "tools":
@@ -172,12 +192,40 @@ async def run_agent(
172192
await response.flush()
173193
await asyncio.sleep(0)
174194

175-
await add_user_record(
176-
uuid_str, session_id, query, t02_answer_data, {}, DiFyAppEnum.COMMON_QA.value[0], user_token
177-
)
195+
# 只有在未取消的情况下才保存记录
196+
if not self.running_tasks[task_id]["cancelled"]:
197+
await add_user_record(
198+
uuid_str, session_id, query, t02_answer_data, {}, DiFyAppEnum.COMMON_QA.value[0], user_token
199+
)
200+
201+
except asyncio.CancelledError:
202+
await response.write(self._create_response("\n> 这条消息已停止", "info", DataTypeEnum.ANSWER.value[0]))
203+
await response.write(self._create_response("", "end", DataTypeEnum.STREAM_END.value[0]))
178204
except Exception as e:
179205
print(f"[ERROR] Agent运行异常: {e}")
180206
traceback.print_exception(e)
181207
await response.write(
182208
self._create_response("[ERROR] 智能体运行异常:", "error", DataTypeEnum.ANSWER.value[0])
183209
)
210+
finally:
211+
# 清理任务记录
212+
if task_id in self.running_tasks:
213+
del self.running_tasks[task_id]
214+
215+
async def cancel_task(self, task_id: str) -> bool:
216+
"""
217+
取消指定的任务
218+
:param task_id: 任务ID
219+
:return: 是否成功取消
220+
"""
221+
if task_id in self.running_tasks:
222+
self.running_tasks[task_id]["cancelled"] = True
223+
return True
224+
return False
225+
226+
def get_running_tasks(self):
227+
"""
228+
获取当前运行中的任务列表
229+
:return: 运行中的任务列表
230+
"""
231+
return list(self.running_tasks.keys())

agent/text2sql/analysis/data_render_antv.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@ async def data_render_ant(state: AgentState):
6868
- 不要解释图表内容或生成文字说明。
6969
- 必须返回符合格式的图表链接。
7070
- 图表需清晰表达数据关系,符合可视化最佳实践。
71+
- x轴和y轴的标签必须使用中文显示。
7172
7273
### 返回格式
7374
![图表](https://example.com/chart.png)

agent/text2sql/analysis/graph.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,16 +33,16 @@ def create_graph():
3333
graph = StateGraph(AgentState)
3434

3535
graph.add_node("schema_inspector", DatabaseService.get_table_schema)
36-
graph.add_node("llm_reasoning", create_reasoning_steps)
36+
# graph.add_node("llm_reasoning", create_reasoning_steps)
3737
graph.add_node("sql_generator", sql_generate)
3838
graph.add_node("sql_executor", DatabaseService.execute_sql)
3939
graph.add_node("data_render", data_render_ant)
4040
graph.add_node("data_render_apache", data_render_apache)
4141
graph.add_node("summarize", summarize)
4242

4343
graph.set_entry_point("schema_inspector")
44-
graph.add_edge("schema_inspector", "llm_reasoning")
45-
graph.add_edge("llm_reasoning", "sql_generator")
44+
# graph.add_edge("schema_inspector", "llm_reasoning")
45+
graph.add_edge("schema_inspector", "sql_generator")
4646
graph.add_edge("sql_generator", "sql_executor")
4747
graph.add_edge("sql_executor", "summarize")
4848

agent/text2sql/analysis/llm_reasoning.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from datetime import datetime
23

34
from langchain.prompts import ChatPromptTemplate
45

@@ -20,7 +21,7 @@ def create_reasoning_steps(state):
2021
prompt = ChatPromptTemplate.from_template(
2122
"""
2223
You are a helpful data analyst who is great at thinking deeply and reasoning about the user's question and the database schema, and you provide a step-by-step reasoning plan in order to answer the user's question.
23-
24+
2425
1. Think deeply and reason about the user's question and the database schema.
2526
2. Give a step by step reasoning plan in order to answer user's question.
2627
3. The reasoning plan should be in the language same as the language user provided in the input.
@@ -34,8 +35,10 @@ def create_reasoning_steps(state):
3435
3536
Database Schema:
3637
{db_schema}
37-
38+
3839
User's Question: {user_query}
40+
41+
Current Time: {current_time}
3942
4043
Let's think step by step.
4144
"""
@@ -48,10 +51,11 @@ def create_reasoning_steps(state):
4851
{
4952
"db_schema": state["db_info"],
5053
"user_query": state["user_query"],
54+
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
5155
}
5256
)
5357

54-
# logger.info(f"Raw LLM response: {response.content}")
58+
logger.info(f"Raw LLM response: {response.content}")
5559

5660
state["sql_reasoning"] = response.content
5761

agent/text2sql/analysis/llm_summarizer.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -33,28 +33,26 @@ def summarize(state: AgentState):
3333
Current Time: {current_time}
3434
3535
## 核心能力
36-
- 趋势识别:捕捉数据变动方向与强度
37-
- 模式归纳:提炼周期性或阶段性规律
38-
- 异常检测:识别显著偏离正常范围的点
39-
- 关键指标提取:聚焦驱动变化的核心维度
36+
- 趋势识别:判断变动方向与持续性(若有时间序列)
37+
- 结构洞察:在截面数据中识别关键分布特征与异常模式
38+
- 模式归纳:提炼可解释的品类/维度差异与行为信号
39+
- 异常检测:发现偏离常规的数值或比例关系
40+
- 驱动分析:定位主导整体表现的核心因素
4041
41-
## 分析流程
42-
1. 解析数据结构,确认时间轴与观测指标
43-
2. 检测整体趋势方向(上升、下降、平稳)
44-
3. 计算相邻周期变化率(环比/同比)
45-
4. 识别突变点、拐点或异常波动
46-
5. 提炼可复用的模式或信号
42+
## 分析策略(动态适配)
43+
- 若含时间维度:执行趋势分析(环比/拐点/周期性)
44+
- 若为单期数据:转向结构分析,聚焦分布不均、高值集中、量价背离等信号
45+
- 统一提取关键指标:如客单价、订单密度、销售额集中度等
46+
- 结合业务常识推断潜在动因或风险
4747
4848
## 输出规范
4949
- **格式**:Markdown 文本,禁用代码块
5050
- **结构**:
51-
## 趋势概述
51+
## 数据分析
5252
一句话概括整体走势
5353
**关键发现**
5454
- 列出2-3项核心结论(**加粗**重点)
55-
**注意**
56-
- 指出异常、波动或数据局限
57-
- **要求**:≤300字,仅简体中文,结论有数据支撑,数据不足则返回“无法判断”
55+
- **要求**:≤300字,仅使用简体中文,语言简洁、数据驱动、逻辑闭环
5856
"""
5957
)
6058

agent/text2sql/database/db_service.py

Lines changed: 90 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
11
import logging
2+
import re
23
import traceback
3-
4+
from typing import List
5+
import jieba
46
import pandas as pd
57
from sqlalchemy.inspection import inspect
68
from sqlalchemy.sql.expression import text
7-
9+
from rank_bm25 import BM25Okapi
810
from agent.text2sql.state.agent_state import AgentState, ExecutionResult
911
from model.db_connection_pool import get_db_pool
1012

@@ -19,7 +21,37 @@ class DatabaseService:
1921

2022
def __init__(self):
2123
pass
22-
# self.engine = db_pool.get_engine()
24+
25+
@staticmethod
26+
def _build_document(table_name: str, table_info: dict) -> str:
27+
"""
28+
将表结构拼接成一段文本,用于匹配
29+
"""
30+
parts = []
31+
32+
# 添加表名和注释
33+
table_comment = table_info.get("table_comment", "")
34+
parts.append(f"{table_name} {table_comment}")
35+
36+
# 添加列信息
37+
for col_name, col_info in table_info.get("columns", {}).items():
38+
col_comment = col_info.get("comment", "")
39+
col_cn_name = col_info.get("cn_name", "")
40+
parts.append(f"{col_name} {col_cn_name} {col_comment}")
41+
42+
return " ".join(parts)
43+
44+
@staticmethod
45+
def _tokenize_text(text: str) -> List[str]:
46+
"""
47+
:param text
48+
对文本进行分词
49+
"""
50+
# 过滤掉标点符号和特殊字符,只保留中文、英文和数字
51+
filtered_text = re.sub(r"[^\u4e00-\u9fa5a-zA-Z0-9]", " ", text)
52+
tokens = list(jieba.cut(filtered_text))
53+
# 过滤空字符串
54+
return [token.strip() for token in tokens if token.strip()]
2355

2456
@staticmethod
2557
def get_table_schema(state: AgentState):
@@ -28,6 +60,7 @@ def get_table_schema(state: AgentState):
2860
:param state:
2961
:return:
3062
获取数据中所有表schema信息
63+
使用BM25算法过滤出相关表信息
3164
:return: 表schema信息
3265
"""
3366
try:
@@ -46,7 +79,60 @@ def get_table_schema(state: AgentState):
4679
]
4780

4881
table_info[table_name] = {"columns": columns, "foreign_keys": foreign_keys}
49-
state["db_info"] = table_info
82+
83+
# 如果有用户查询,则根据查询过滤表信息
84+
user_query = state.get("user_query", "")
85+
if user_query and table_info:
86+
# 构建表文档
87+
corpus = []
88+
table_names = []
89+
table_comments = []
90+
for table_name, info in table_info.items():
91+
doc = DatabaseService._build_document(table_name, info)
92+
corpus.append(doc)
93+
table_names.append(table_name)
94+
table_comments.append(info.get("table_comment", ""))
95+
96+
# 对文档进行分词
97+
tokenized_corpus = [DatabaseService._tokenize_text(doc) for doc in corpus]
98+
99+
# 使用BM25算法训练模型
100+
bm25 = BM25Okapi(tokenized_corpus)
101+
102+
# 对查询进行分词
103+
query_tokens = DatabaseService._tokenize_text(user_query)
104+
105+
# 计算文档得分
106+
doc_scores = bm25.get_scores(query_tokens)
107+
108+
# 优化算法:提高表注释匹配的优先级
109+
# 如果查询内容直接匹配表注释,则给予更高的权重
110+
for i, (table_comment, score) in enumerate(zip(table_comments, doc_scores)):
111+
if score > 0 and table_comment:
112+
# 检查查询是否直接包含表注释中的关键词
113+
comment_tokens = DatabaseService._tokenize_text(table_comment)
114+
query_text = "".join(query_tokens)
115+
comment_text = "".join(comment_tokens)
116+
117+
# 如果查询中包含表注释的关键内容,增加权重
118+
if comment_text and (comment_text in query_text or query_text in comment_text):
119+
doc_scores[i] *= 2 # 给予两倍权重
120+
121+
# 按得分排序,取前3个最相关的表
122+
top_indices = sorted(range(len(doc_scores)), key=lambda i: doc_scores[i], reverse=True)[:3]
123+
124+
# 只保留最相关的表
125+
filtered_table_info = {
126+
table_names[idx]: table_info[table_names[idx]]
127+
for idx in top_indices
128+
if doc_scores[idx] > 0 # 只保留得分大于0的表
129+
}
130+
131+
state["db_info"] = filtered_table_info
132+
else:
133+
state["db_info"] = table_info
134+
135+
logger.info(f"获取数据库表信息成功: {state.get('db_info')}")
50136
except Exception as e:
51137
logger.error(f"获取数据库表信息失败: {e}")
52138
state["db_info"] = {}

0 commit comments

Comments (0)