diff --git a/.env.dev b/.env.dev
index 4a86b05b..bf52f1fc 100644
--- a/.env.dev
+++ b/.env.dev
@@ -12,7 +12,7 @@ MINIO_ENDPOINT=127.0.0.1:9000
SQLALCHEMY_DATABASE_URI=postgresql+psycopg2://aix_db:1@127.0.0.1:15432/aix_db
# LangFuse 配置 默认关闭 (可选)
-LANGFUSE_TRACING_ENABLED="true"
+LANGFUSE_TRACING_ENABLED="false"
LANGFUSE_SECRET_KEY = "sk-lf-4bf2a844-4a9c-4626-af69-0cae99bf2bfb"
LANGFUSE_PUBLIC_KEY = "pk-lf-8aff3c29-3239-4a52-8028-bacc185f6c22"
LANGFUSE_BASE_URL = "http://localhost:3000"
diff --git a/agent/deepagent/output/medical_projects_monthly_report.html b/agent/deepagent/output/medical_projects_monthly_report.html
new file mode 100644
index 00000000..ab66c210
--- /dev/null
+++ b/agent/deepagent/output/medical_projects_monthly_report.html
@@ -0,0 +1,329 @@
+<!-- Generated demo report; the full 329-line HTML document is omitted here because its markup
+     was stripped during extraction. Recoverable structure:
+     Title: 医疗项目月度使用数量分析报告 (monthly usage analysis of medical projects)
+     Summary cards, e.g. 137,767 = 最高单月用量(输液泵) (highest single-month usage, infusion pump)
+     Charts: 📊 Top 10 医疗项目 - 最近 3 个月使用量趋势; 📈 各类别项目月度总使用量对比;
+             🔍 护理类 vs 检验类 vs 药品类 - 使用量分布
+     Detail table: 📋 医疗项目月度使用量详细数据 (Top 50), columns:
+     排名 | 项目名称 | 2026-01 | 2025-12 | 2025-11 | 2025-10 | 2025-09 | 总使用量 -->
diff --git a/agent/text2sql/analysis/data_render_antv.py b/agent/text2sql/analysis/data_render_antv.py
index 663171ea..76009eb7 100644
--- a/agent/text2sql/analysis/data_render_antv.py
+++ b/agent/text2sql/analysis/data_render_antv.py
@@ -4,6 +4,7 @@
"""
import json
import logging
+import os
import traceback
from decimal import Decimal
from datetime import datetime, date
@@ -37,6 +38,9 @@
"excel": "postgres", # Excel 使用 PostgreSQL 规则
}
+# Max rows for the front-end table preview (avoids returning the full dataset at once, which can freeze the page)
+TABLE_PREVIEW_MAX_ROWS = int(os.getenv("TABLE_PREVIEW_MAX_ROWS", "100"))
+
def convert_value(v):
"""转换数据类型"""
@@ -513,8 +517,11 @@ async def data_render_ant(state: AgentState):
except Exception as e:
logger.warning(f"获取数据源类型失败: {e},使用默认值 mysql")
+    # Preview only the first TABLE_PREVIEW_MAX_ROWS rows on the view side; rendering everything at once can freeze the front end
+ preview_data = data[:TABLE_PREVIEW_MAX_ROWS]
+
# 获取实际的列名(从第一条数据中提取)
- actual_columns = list(data[0].keys()) if data else []
+ actual_columns = list(preview_data[0].keys()) if preview_data else []
if not actual_columns:
logger.warning("无法从数据中提取列名,跳过数据渲染")
@@ -536,9 +543,9 @@ async def data_render_ant(state: AgentState):
else:
logger.warning(f"列名映射失败或返回空,使用原始列名。actual_columns={actual_columns[:3]}")
- # 转换数据格式: 将英文列名映射为中文列名
+    # Convert the data format: map English column names to Chinese ones (preview rows only)
formatted_data = []
- for row in data:
+ for row in preview_data:
formatted_row = {}
for col_name, value in row.items():
chinese_col_name = column_mapping.get(col_name, col_name)
diff --git a/agent/text2sql/analysis/llm_summarizer.py b/agent/text2sql/analysis/llm_summarizer.py
index 492e2ef8..994623ba 100644
--- a/agent/text2sql/analysis/llm_summarizer.py
+++ b/agent/text2sql/analysis/llm_summarizer.py
@@ -1,4 +1,5 @@
import logging
+import os
from datetime import datetime, date
import json
import re
@@ -11,6 +12,11 @@
from agent.text2sql.template.prompt_builder import PromptBuilder
logger = logging.getLogger(__name__)
+
+# Max data rows fed into the summary, keeping the prompt small so the LLM stays fast (configurable)
+SUMMARIZE_MAX_ROWS = int(os.getenv("SUMMARIZE_MAX_ROWS", "25"))
+# Max characters of the data JSON for summarization; anything beyond is truncated with a note (configurable)
+SUMMARIZE_MAX_CHARS = int(os.getenv("SUMMARIZE_MAX_CHARS", "12000"))
"""
大模型数据总结节点
"""
@@ -65,14 +71,16 @@ def summarize(state: AgentState):
prompt_builder = PromptBuilder()
try:
- # 获取数据结果
+    # Fetch the result data, capping row count and length to speed up LLM summarization
data_result = state["execution_result"].data
-
- # 如果数据是字典或列表,转换为JSON字符串
+ if isinstance(data_result, list) and len(data_result) > SUMMARIZE_MAX_ROWS:
+ data_result = data_result[:SUMMARIZE_MAX_ROWS]
if isinstance(data_result, (dict, list)):
data_result_str = json.dumps(data_result, ensure_ascii=False, indent=2, cls=DecimalEncoder)
else:
data_result_str = str(data_result)
+ if len(data_result_str) > SUMMARIZE_MAX_CHARS:
+ data_result_str = data_result_str[:SUMMARIZE_MAX_CHARS] + "\n\n...(数据已截断,仅展示前一部分供总结)"
# 获取当前时间
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
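For reference, the row-then-character truncation above can be lifted into a standalone helper; a minimal sketch (simplified: `default=str` stands in for the module's `DecimalEncoder`):

```python
import json

def truncate_for_summary(data, max_rows=25, max_chars=12000):
    """Cap the row count first, then the serialized length, noting any cut."""
    if isinstance(data, list) and len(data) > max_rows:
        data = data[:max_rows]
    if isinstance(data, (dict, list)):
        text = json.dumps(data, ensure_ascii=False, indent=2, default=str)
    else:
        text = str(data)
    if len(text) > max_chars:
        text = text[:max_chars] + "\n\n...(truncated for summarization)"
    return text
```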
diff --git a/agent/text2sql/analysis/parallel_collector.py b/agent/text2sql/analysis/parallel_collector.py
index b469bd55..1ea07f00 100644
--- a/agent/text2sql/analysis/parallel_collector.py
+++ b/agent/text2sql/analysis/parallel_collector.py
@@ -80,7 +80,7 @@ def parallel_collect(state: AgentState, tasks: list[str] = None) -> AgentState:
for task, future in futures.items():
try:
- result_state = future.result(timeout=180) # 最多等待60秒
+            result_state = future.result(timeout=60) # wait at most 60 s per task so the whole batch is not slowed down
results[task] = result_state
logger.info(f"✅ 任务完成: {task}")
except Exception as e:
diff --git a/agent/text2sql/analysis/unified_collector.py b/agent/text2sql/analysis/unified_collector.py
index 9599b64b..07e790ce 100644
--- a/agent/text2sql/analysis/unified_collector.py
+++ b/agent/text2sql/analysis/unified_collector.py
@@ -44,11 +44,37 @@ async def unified_collect(state: AgentState) -> AgentState:
# chart_config 为空,检查是否是因为查询结果为空
execution_result = state.get("execution_result")
if execution_result and execution_result.success and not execution_result.data:
- # SQL执行成功但无数据,生成空结果卡片,让前端显示空结果提示并可查看SQL
- logger.info("📊 SQL执行成功但无数据,生成空结果卡片")
+            # SQL succeeded but returned no rows: still return the table template (temp01) so the front end can show the table structure and pagination controls
+ logger.info("📊 SQL执行成功但无数据,生成空表格")
+            # Try to extract column names from the SQL
+ columns = []
+ generated_sql = state.get("generated_sql", "") or state.get("filtered_sql", "")
+ if generated_sql:
+ try:
+                    # Simple extraction: take everything between SELECT and FROM
+                    import re
+                    select_match = re.search(r'SELECT\s+(.*?)\s+FROM', generated_sql, re.IGNORECASE | re.DOTALL)
+                    if select_match:
+                        select_clause = select_match.group(1)
+                        # Naive comma split; does not handle commas inside function calls or subqueries
+                        col_parts = [c.strip() for c in select_clause.split(',')]
+                        for col in col_parts:
+                            # Prefer the alias when present (match AS case-insensitively, unlike the old case-sensitive split)
+                            if re.search(r'\sAS\s', col, re.IGNORECASE):
+                                alias = re.split(r'\sAS\s', col, flags=re.IGNORECASE)[-1].strip().strip('`').strip('"').strip("'")
+                                columns.append(alias)
+                            else:
+                                # Keep the part after the last dot (table.column -> column)
+                                col_name = col.strip().strip('`').strip('"').strip("'")
+                                if '.' in col_name:
+                                    col_name = col_name.split('.')[-1]
+                                columns.append(col_name)
+ except Exception as e:
+ logger.warning(f"从 SQL 提取列名失败: {e}")
+
state["render_data"] = {
- "template_code": "temp05",
- "columns": [],
+            "template_code": "temp01", # temp01 renders a table instead of the empty-result card
+            "columns": columns,
"data": [],
}
else:
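The regex-and-split extraction above is intentionally naive and mis-splits projections that contain commas (e.g. `COALESCE(a, b)`). Since the code base already depends on sqlglot (see `extract_table_names_sqlglot` in `data_render_antv.py`), a parser-based variant could look like the sketch below; `extract_select_columns` is a hypothetical helper, not part of this PR:

```python
import sqlglot

def extract_select_columns(sql: str, dialect: str = "mysql") -> list:
    """Best-effort extraction of projected column names via the sqlglot AST."""
    try:
        tree = sqlglot.parse_one(sql, read=dialect)
    except Exception:
        return []
    columns = []
    for proj in getattr(tree, "selects", []):
        # alias_or_name returns the alias if present, else the bare column name;
        # fall back to the rendered expression for unaliased items like COUNT(*)
        columns.append(proj.alias_or_name or proj.sql(dialect=dialect))
    return columns
```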
diff --git a/agent/text2sql/database/db_service.py b/agent/text2sql/database/db_service.py
index f4f37617..0bd84222 100644
--- a/agent/text2sql/database/db_service.py
+++ b/agent/text2sql/database/db_service.py
@@ -12,9 +12,9 @@
import os
import re
import time
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Tuple, Optional, Set
from threading import Lock
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
import faiss
import jieba
@@ -33,7 +33,7 @@
from model.db_models import TAiModel, TDsPermission, TDsRules
from model.datasource_models import DatasourceTable, DatasourceField
from agent.text2sql.permission.permission_retriever import get_user_permission_filters
-from sqlalchemy import select
+from sqlalchemy import select, update
# 日志配置
logger = logging.getLogger(__name__)
@@ -44,6 +44,13 @@
# 返回表数量配置(可配置,默认 6 个)
TABLE_RETURN_COUNT = int(os.getenv("TABLE_RETURN_COUNT", "6"))
+# Skip reranking when the candidate-table count is at or below this value; the BM25+vector order is kept to save time (default 3)
+RERANK_SKIP_WHEN_CANDIDATES_LE = int(os.getenv("RERANK_SKIP_WHEN_CANDIDATES_LE", "3"))
+# Vector-retrieval top_k (default 12; larger values increase fusion cost)
+VECTOR_RETRIEVE_TOP_K = int(os.getenv("VECTOR_RETRIEVE_TOP_K", "12"))
+
+# Thread count for parallel table-schema loading (0 = no parallelism, load serially)
+TABLE_FETCH_MAX_WORKERS = int(os.getenv("TABLE_FETCH_MAX_WORKERS", "8"))
# 缓存配置
_table_info_cache: Dict[Tuple[int, Optional[int]], Tuple[Dict[str, Dict], float]] = {}
@@ -51,6 +58,43 @@
CACHE_TTL = int(os.getenv("TABLE_INFO_CACHE_TTL", "300")) # 缓存有效期(秒),默认5分钟
+def _fetch_one_table_via_inspector(engine, table_name: str, column_permissions: Dict[str, Set[str]]) -> Tuple[str, Optional[Dict]]:
+ """
+    Fetch a single table's schema on its own connection (for parallel callers).
+    Returns (table_name, table_info_dict), or (table_name, None) to skip the table.
+ """
+ try:
+ with engine.connect() as conn:
+ insp = inspect(conn)
+ columns = {}
+ for col in insp.get_columns(table_name):
+ if table_name in column_permissions and col["name"] not in column_permissions[table_name]:
+ continue
+ columns[col["name"]] = {
+ "type": str(col["type"]),
+ "comment": str(col["comment"] or ""),
+ }
+ if not columns:
+ return table_name, None
+ foreign_keys = [
+ f"{fk['constrained_columns'][0]} -> {fk['referred_table']}.{fk['referred_columns'][0]}"
+ for fk in insp.get_foreign_keys(table_name)
+ ]
+ try:
+ info = insp.get_table_comment(table_name)
+ comment = (info.get("text") or info.get("comment") or "") if isinstance(info, dict) else (info or "")
+ except Exception:
+ comment = ""
+ return table_name, {
+ "columns": columns,
+ "foreign_keys": foreign_keys,
+ "table_comment": str(comment).strip() if comment else "",
+ }
+ except Exception as e:
+ logger.debug(f"读取表 {table_name} 结构失败: {e}")
+ return table_name, None
+
+
# 嵌入模型配置
def get_embedding_model_config():
"""
@@ -467,39 +511,26 @@ def _fetch_all_table_info(self, user_id: Optional[int] = None, use_cache: bool =
logger.warning(f"⚠️ 获取列权限失败: {e}", exc_info=True)
table_info = {}
- for table_name in table_names:
- try:
- columns = {}
- for col in inspector.get_columns(table_name):
- # 权限过滤:如果配置了列权限,只返回有权限的字段
- if table_name in column_permissions:
- if col["name"] not in column_permissions[table_name]:
- continue
-
- columns[col["name"]] = {
- "type": str(col["type"]),
- "comment": str(col["comment"] or ""),
- }
-
- # 如果过滤后没有字段,跳过该表
- if not columns:
- logger.debug(f"⚠️ 表 {table_name} 无可用字段(权限过滤后),跳过")
- continue
-
- foreign_keys = [
- f"{fk['constrained_columns'][0]} -> {fk['referred_table']}.{fk['referred_columns'][0]}"
- for fk in inspector.get_foreign_keys(table_name)
- ]
-
- table_comment = self._get_table_comment(table_name)
-
- table_info[table_name] = {
- "columns": columns,
- "foreign_keys": foreign_keys,
- "table_comment": table_comment,
+ workers = TABLE_FETCH_MAX_WORKERS
+ if workers > 0 and len(table_names) > 1:
+        # Load multi-table schemas in parallel to cut total latency
+ workers = min(workers, len(table_names), (os.cpu_count() or 4) * 2)
+ logger.info(f"🔀 使用 {workers} 个线程并行加载表结构")
+ with ThreadPoolExecutor(max_workers=workers) as executor:
+ futures = {
+ executor.submit(_fetch_one_table_via_inspector, self._engine, tn, column_permissions): tn
+ for tn in table_names
}
- except Exception as e:
- logger.error(f"❌ 读取表 {table_name} 结构失败: {e}")
+ for future in as_completed(futures):
+ _name, info = future.result()
+ if info is not None:
+ table_info[_name] = info
+ else:
+        # Serial loading (for TABLE_FETCH_MAX_WORKERS=0 or a single table)
+ for table_name in table_names:
+ _name, info = _fetch_one_table_via_inspector(self._engine, table_name, column_permissions)
+ if info is not None:
+ table_info[_name] = info
elapsed = time.time() - start_time
logger.info(f"✅ 成功加载 {len(table_info)} 张表,耗时 {elapsed:.2f}s")
@@ -588,6 +619,31 @@ def _fetch_table_info_from_metadata(self, user_id: Optional[int], use_cache: boo
return table_info
+ def _get_table_names_in_metadata(self) -> Set[str]:
+ """
+        Return the table names present in t_datasource_table for this datasource (uppercased for matching).
+        Also includes the "short name" (the table part of schema.table) so inspector-style schema.table names can match.
+ """
+ if not self._datasource_id:
+ return set()
+ try:
+ with db_pool.get_session() as session:
+ tables = (
+ session.query(DatasourceTable.table_name)
+ .filter(DatasourceTable.ds_id == self._datasource_id)
+ .all()
+ )
+ out = set()
+ for row in tables:
+ key = str(row[0]).upper()
+ out.add(key)
+ if "." in key:
+ out.add(key.split(".")[-1])
+ return out
+ except Exception as e:
+ logger.warning(f"⚠️ 获取元数据表名失败: {e}")
+ return set()
+
def _get_precomputed_embeddings(self, table_info: Dict[str, Dict]) -> Tuple[Optional[np.ndarray], List[str], List[str]]:
"""
尝试从数据库获取预计算的 embedding。
@@ -608,8 +664,13 @@ def _get_precomputed_embeddings(self, table_info: Dict[str, Dict]) -> Tuple[Opti
.all()
)
- # 构建表名到表的映射(不区分大小写,兼容 Oracle 等会返回大写表名的数据库)
- table_map = {str(table.table_name).upper(): table for table in tables}
+        # Build a case-insensitive table-name map; also register the "short name" so schema.table and bare table names match each other
+ table_map = {}
+ for table in tables:
+ key = str(table.table_name).upper()
+ table_map[key] = table
+ if "." in key:
+ table_map[key.split(".")[-1]] = table
# 收集有预计算 embedding 的表
precomputed_embeddings = []
@@ -617,8 +678,8 @@ def _get_precomputed_embeddings(self, table_info: Dict[str, Dict]) -> Tuple[Opti
missing_table_names = []
for table_name, info in table_info.items():
- # 统一按大写匹配,避免 T_ALARM_INFO / t_alarm_info 不一致导致无法命中
- table = table_map.get(str(table_name).upper())
+ key_upper = str(table_name).upper()
+ table = table_map.get(key_upper) or (table_map.get(key_upper.split(".")[-1]) if "." in key_upper else None)
# 检查是否有 embedding 字段(通过 hasattr 检查,避免字段不存在时报错)
if table and hasattr(table, 'embedding') and table.embedding:
try:
@@ -689,71 +750,171 @@ def _create_embeddings_with_dashscope(self, texts: List[str]) -> np.ndarray:
logger.info(f"✅ 离线模型嵌入生成完成,耗时 {time.time() - start_time:.2f}s,维度: {embedding_dim}")
return embeddings
- # 使用在线模型
- logger.info(f"🌐 调用在线嵌入模型 {self.embedding_model_name}...")
+    # Online model: send batched requests to cut latency, at most 25 texts per call to avoid API rate limits
+ logger.info(f"🌐 调用在线嵌入模型 {self.embedding_model_name}(批量)...")
start_time = time.time()
+ EMBEDDING_BATCH_SIZE = 25
embeddings = []
- for doc in texts:
+ for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
+ batch = texts[i : i + EMBEDDING_BATCH_SIZE]
try:
- response = self.embedding_client.embeddings.create(model=self.embedding_model_name, input=doc)
- embeddings.append(response.data[0].embedding)
+ response = self.embedding_client.embeddings.create(
+ model=self.embedding_model_name, input=batch
+ )
+ for item in response.data:
+ embeddings.append(item.embedding)
except Exception as e:
- logger.error(f"❌ 在线模型嵌入生成失败 ({doc[:30]}...): {e}")
- embeddings.append(np.zeros(1024)) # 占位符
-
+ logger.warning(f"⚠️ 在线模型批量嵌入失败 (batch {i}-{i+len(batch)}): {e},逐条回退")
+ for doc in batch:
+ try:
+ r = self.embedding_client.embeddings.create(
+ model=self.embedding_model_name, input=doc
+ )
+ embeddings.append(r.data[0].embedding)
+ except Exception as e2:
+ logger.error(f"❌ 单条嵌入失败: {e2}")
+ embeddings.append(np.zeros(1024))
+ if len(embeddings) != len(texts):
+ logger.warning(f"⚠️ 嵌入数量与输入不一致: {len(embeddings)} vs {len(texts)}")
embeddings = np.array(embeddings).astype("float32")
faiss.normalize_L2(embeddings)
logger.info(f"✅ 在线模型嵌入生成完成,耗时 {time.time() - start_time:.2f}s")
return embeddings
+ def _save_table_embeddings_to_db(self, table_names: List[str], embeddings: np.ndarray) -> None:
+ """
+        Persist table-schema embeddings to t_datasource_table.embedding.
+        table_names and embeddings correspond row by row.
+ """
+ if not self._datasource_id or not table_names or embeddings.size == 0:
+ return
+ try:
+ with db_pool.get_session() as session:
+ tables = (
+ session.query(DatasourceTable)
+ .filter(DatasourceTable.ds_id == self._datasource_id)
+ .all()
+ )
+ name_to_id = {str(t.table_name).upper(): t.id for t in tables}
+ saved = 0
+ for i, table_name in enumerate(table_names):
+ if i >= len(embeddings):
+ break
+ key_upper = str(table_name).strip().upper()
+ table_id = name_to_id.get(key_upper)
+ if not table_id and "." in key_upper:
+ key_short = key_upper.split(".")[-1]
+ table_id = name_to_id.get(key_short)
+ if not table_id:
+ continue
+ emb_list = embeddings[i].tolist() if hasattr(embeddings[i], "tolist") else list(embeddings[i])
+ stmt = (
+ update(DatasourceTable)
+ .where(DatasourceTable.id == table_id)
+ .values(embedding=json.dumps(emb_list))
+ )
+ session.execute(stmt)
+ saved += 1
+ session.commit()
+ logger.info(f"✅ 已保存 {saved}/{len(table_names)} 张表的 embedding 到数据库")
+ except Exception as e:
+ logger.warning(f"⚠️ 保存表 embedding 到数据库失败: {e}", exc_info=True)
+
def _initialize_vector_index(self, table_info: Dict[str, Dict]):
"""
- 初始化 FAISS 向量索引:从数据库读取预计算的 embedding 并构建内存索引。
- 仅使用预计算的 embedding,不在检索时做实时计算。
+        Initialize the FAISS vector index: prefer precomputed embeddings from the database;
+        if none exist, or some tables are missing, compute table-schema embeddings on the fly, persist them, then build the index.
"""
if self._index_initialized:
return
- # 构建新索引
logger.info("🏗️ 开始构建向量索引(从数据库读取 embedding)...")
start_time = time.time()
- # 记录所有表名和语料(用于 BM25 等)
- self._table_names = list(table_info.keys())
- self._corpus = [self._build_document(name, info) for name, info in table_info.items()]
+        # Only index (and persist embeddings for) tables already present in t_datasource_table; tables not yet synced to metadata would be recomputed every run and could not be saved
+ tables_in_metadata = self._get_table_names_in_metadata()
+ def _in_metadata(name: str) -> bool:
+ u = str(name).strip().upper()
+ return u in tables_in_metadata or (u.split(".")[-1] if "." in u else u) in tables_in_metadata
+ table_info_indexed = {k: v for k, v in table_info.items() if _in_metadata(k)}
+ skipped_count = len(table_info) - len(table_info_indexed)
+ if skipped_count > 0:
+ logger.info(f"📋 向量索引仅针对元数据中已存在的表:共 {len(table_info_indexed)} 张,跳过 {skipped_count} 张(仅存在于实时库、未在元数据中)")
+ if not table_info_indexed:
+ logger.warning("⚠️ 元数据中无表记录,无法构建向量索引,仅使用 BM25")
+ self._faiss_index = None
+ self._index_initialized = True
+ return
+
+ self._table_names = list(table_info_indexed.keys())
+ self._corpus = [self._build_document(name, info) for name, info in table_info_indexed.items()]
- # 从数据库获取预计算的 embedding(不会做任何实时计算)
precomputed_embeddings, precomputed_table_names, missing_table_names = self._get_precomputed_embeddings(
- table_info
+ table_info_indexed
)
- # 如果没有任何预计算 embedding,则禁用向量索引(仅使用 BM25)
+        # Case 1: no precomputed embeddings at all → compute embeddings for every table, write them to the DB, then build the index
if precomputed_embeddings is None or len(precomputed_table_names) == 0:
- logger.warning("⚠️ 未找到任何预计算的表结构 embedding,向量检索将被禁用,仅使用 BM25")
- self._faiss_index = None
+ logger.warning("⚠️ 未找到任何预计算的表结构 embedding,将实时计算并写入数据库后构建向量索引")
+ try:
+ embeddings = self._create_embeddings_with_dashscope(self._corpus)
+ if embeddings is not None and embeddings.size > 0:
+ self._save_table_embeddings_to_db(self._table_names, embeddings)
+ dimension = embeddings.shape[1]
+ self._faiss_index = faiss.IndexFlatIP(dimension)
+ self._faiss_index.add(embeddings)
+ elapsed = time.time() - start_time
+ logger.info(f"🎉 向量索引构建完成(已计算并保存 {len(self._table_names)} 张表),耗时 {elapsed:.2f}s")
+ else:
+ logger.warning("⚠️ 实时计算 embedding 失败,向量检索将被禁用,仅使用 BM25")
+ self._faiss_index = None
+ except Exception as e:
+ logger.warning(f"⚠️ 实时计算表 embedding 失败: {e},向量检索将被禁用", exc_info=True)
+ self._faiss_index = None
self._index_initialized = True
return
-        # 如果存在缺失的 embedding,为避免索引和表顺序不一致,这里直接禁用向量检索
+        # Case 2: some tables lack a precomputed embedding → compute only the missing ones and merge (order follows self._table_names)
if len(missing_table_names) > 0:
logger.warning(
- f"⚠️ 共有 {len(missing_table_names)} 张表缺少预计算 embedding,"
- "为保证索引与表顺序一致,本次禁用向量检索,仅使用 BM25"
+ f"⚠️ 共有 {len(missing_table_names)} 张表缺少预计算 embedding,将实时计算缺失部分并合并后构建索引"
)
- self._faiss_index = None
+ try:
+ missing_corpus = [self._build_document(name, table_info_indexed[name]) for name in missing_table_names]
+ missing_embeddings = self._create_embeddings_with_dashscope(missing_corpus)
+ if missing_embeddings is None or missing_embeddings.size == 0:
+ raise ValueError("缺失表的 embedding 计算失败")
+ name_to_precomputed_idx = {name: i for i, name in enumerate(precomputed_table_names)}
+ name_to_missing_idx = {name: i for i, name in enumerate(missing_table_names)}
+ merged_list = []
+ for name in self._table_names:
+ if name in name_to_precomputed_idx:
+ merged_list.append(precomputed_embeddings[name_to_precomputed_idx[name]])
+ else:
+ merged_list.append(missing_embeddings[name_to_missing_idx[name]])
+ embeddings = np.array(merged_list).astype("float32")
+ faiss.normalize_L2(embeddings)
+ self._save_table_embeddings_to_db(missing_table_names, missing_embeddings)
+ dimension = embeddings.shape[1]
+ self._faiss_index = faiss.IndexFlatIP(dimension)
+ self._faiss_index.add(embeddings)
+ elapsed = time.time() - start_time
+ logger.info(f"🎉 向量索引构建完成(已补全 {len(missing_table_names)} 张表),共 {len(self._table_names)} 张表,耗时 {elapsed:.2f}s")
+ except Exception as e:
+ logger.warning(f"⚠️ 补全缺失 embedding 失败: {e},本次禁用向量检索", exc_info=True)
+ self._faiss_index = None
self._index_initialized = True
return
- # 此时说明所有表都存在预计算 embedding,顺序与 self._table_names 一致
+        # Case 3: all tables have precomputed embeddings; build the index directly
embeddings = precomputed_embeddings
-
if embeddings.size == 0:
logger.error("❌ 无法生成嵌入,索引构建失败")
+ self._index_initialized = True
return
- # 初始化 FAISS 索引(仅在内存中)
dimension = embeddings.shape[1]
- self._faiss_index = faiss.IndexFlatIP(dimension) # 内积 = 余弦相似度
+ self._faiss_index = faiss.IndexFlatIP(dimension)
self._faiss_index.add(embeddings)
elapsed = time.time() - start_time
@@ -766,7 +927,7 @@ def _retrieve_by_vector(self, query: str, top_k: int = 10) -> List[int]:
优先使用在线模型,如果没有配置则使用离线模型。
"""
if not self._faiss_index:
- logger.error("❌ 向量索引未初始化")
+ logger.warning("⚠️ 向量索引未初始化(可能无预计算 embedding 且实时计算未执行或失败),跳过向量检索,仅使用 BM25")
return []
try:
@@ -845,9 +1006,65 @@ def _rrf_fusion(bm25_indices: List[int], vector_indices: List[int], k: int = 60)
sorted_indices = sorted(scores.items(), key=lambda x: -x[1])
return [idx for idx, _ in sorted_indices]
+ def _get_dashscope_rerank_url_and_payload(
+ self, query: str, documents: List[str]
+    ) -> Tuple[str, dict]:
+        """
+        Return the correct DashScope rerank URL and request body for the configured model.
+        - qwen3-rerank: POST /compatible-api/v1/reranks with a flat request body.
+        - gte-rerank-v2 / qwen3-vl-rerank: POST /api/v1/services/rerank/text-rerank/text-rerank with an input/parameters body.
+        """
+ base_url = (self.rerank_base_url or "").strip().rstrip("/")
+ model = (self.rerank_model_name or "").strip().lower()
+ is_dashscope_domain = "dashscope.aliyuncs.com" in base_url or "aliyuncs" in base_url
+ has_rerank_path = "/rerank" in base_url or "/reranks" in base_url
+
+        # If the configured base_url is only a domain or lacks the rerank path, complete it per the official docs
+ if is_dashscope_domain and not has_rerank_path:
+ from urllib.parse import urlparse
+ parsed = urlparse(base_url if base_url.startswith("http") else f"https://{base_url}")
+ domain = f"{parsed.scheme or 'https'}://{parsed.netloc}"
+ if "qwen3-rerank" in model:
+ effective_url = f"{domain}/compatible-api/v1/reranks"
+ payload = {
+ "model": self.rerank_model_name,
+ "query": query,
+ "documents": documents,
+ "top_n": len(documents),
+ "instruct": "Given a web search query, retrieve relevant passages that answer the query.",
+ }
+ else:
+ effective_url = f"{domain}/api/v1/services/rerank/text-rerank/text-rerank"
+ payload = {
+ "model": self.rerank_model_name,
+ "input": {"query": query, "documents": documents},
+ "parameters": {"top_n": len(documents), "return_documents": False},
+ }
+ return effective_url, payload
+
+        # Use the configured full URL; choose the body format by model
+ if "qwen3-rerank" in model:
+ payload = {
+ "model": self.rerank_model_name,
+ "query": query,
+ "documents": documents,
+ "top_n": len(documents),
+ "instruct": "Given a web search query, retrieve relevant passages that answer the query.",
+ }
+ elif "aliyuncs" in base_url or "qwen" in model or "gte-rerank" in model:
+ payload = {
+ "model": self.rerank_model_name,
+ "input": {"query": query, "documents": documents},
+ "parameters": {"top_n": len(documents), "return_documents": False},
+ }
+ else:
+ payload = {"query": query, "documents": documents}
+ return self.rerank_base_url, payload
+
def _rerank_with_dashscope(self, query: str, candidate_tables: Dict[str, Dict]) -> List[Tuple[str, float]]:
"""
使用 DashScope 重排 API 对候选表进行重排序。
+        Compatible with qwen3-rerank (/compatible-api/v1/reranks) and gte-rerank-v2 etc. (/api/v1/services/rerank/...).
"""
if not self.USE_RERANKER:
logger.debug("⏭️ Reranker 已禁用或配置不完整,跳过重排序")
@@ -866,75 +1083,60 @@ def _rerank_with_dashscope(self, query: str, candidate_tables: Dict[str, Dict])
logger.info(f"🔁 调用重排模型 {self.rerank_model_name} 进行重排序...")
- # 根据API类型选择不同的请求结构
- if "aliyuncs" in self.rerank_base_url or "Qwen" in self.rerank_model_name:
- # 阿里云 DashScope 格式
- payload = {
- "model": self.rerank_model_name,
- "input": {"query": query, "documents": documents},
- "parameters": {"top_n": len(documents), "return_documents": False},
- }
+        # Resolve the correct URL and request body (avoids 404s: DashScope paths and bodies differ per model)
+ if "aliyuncs" in (self.rerank_base_url or "") or (self.rerank_model_name or "").lower().startswith(("qwen", "gte")):
+ effective_url, payload = self._get_dashscope_rerank_url_and_payload(query, documents)
else:
- # 其他格式(如本地模型或通用rerank API)
+ effective_url = self.rerank_base_url
payload = {"query": query, "documents": documents}
- # 设置请求头
headers = {"Authorization": f"Bearer {self.rerank_api_key}", "Content-Type": "application/json"}
-
- # 调用重排 API
- response = requests.post(self.rerank_base_url, headers=headers, json=payload, timeout=30)
+ response = requests.post(effective_url, headers=headers, json=payload, timeout=15)
# 检查响应状态
if response.status_code != 200:
logger.warning(f"⚠️ Rerank API 调用失败: {response.status_code} - {response.text}")
return [(name, 1.0) for name in candidate_tables.keys()]
- # 解析响应
+        # Parse the response (both DashScope endpoints return output.results)
result_data = response.json()
- # 根据API类型解析响应
- if "aliyuncs" in self.rerank_base_url or "Qwen" in self.rerank_model_name:
- # 阿里云格式响应
- if "output" in result_data and "results" in result_data["output"]:
- results = []
- for item in result_data["output"]["results"]:
+ if "output" in result_data and "results" in result_data["output"]:
+ results = []
+ for item in result_data["output"]["results"]:
+ idx = item["index"]
+ score = item["relevance_score"]
+ table_name = next(name for name, text in name_to_text.items() if text == documents[idx])
+ results.append((table_name, score))
+ results.sort(key=lambda x: x[1], reverse=True)
+ logger.info("✅ Rerank 完成")
+ return results
+        # Generic format (non-DashScope)
+ if "results" in result_data:
+ results = []
+ for item in result_data["results"]:
+ if "index" in item and "relevance_score" in item:
idx = item["index"]
score = item["relevance_score"]
+ if "document" in item and "text" in item["document"]:
+ doc_text = item["document"]["text"]
+ table_name = next(name for name, text in name_to_text.items() if text == doc_text)
+ else:
+ table_name = next(name for name, text in name_to_text.items() if text == documents[idx])
+ results.append((table_name, score))
+ results.sort(key=lambda x: x[1], reverse=True)
+ logger.info("✅ Rerank 完成")
+ return results
+ if isinstance(result_data, list):
+ results = []
+ for i, item in enumerate(result_data):
+ if isinstance(item, dict) and "index" in item:
+ idx = item["index"]
+ score = item.get("score", 1.0 - i * 0.01)
table_name = next(name for name, text in name_to_text.items() if text == documents[idx])
results.append((table_name, score))
-
- results.sort(key=lambda x: x[1], reverse=True)
- logger.info("✅ Rerank 完成")
- return results
- else:
- # 通用格式响应 - 假设直接返回排序结果
- if "results" in result_data:
- results = []
- for item in result_data["results"]:
- if "index" in item and "relevance_score" in item: # 使用relevance_score
- idx = item["index"]
- score = item["relevance_score"] # 使用relevance_score字段
- # 从document对象中提取文本
- if "document" in item and "text" in item["document"]:
- doc_text = item["document"]["text"]
- table_name = next(name for name, text in name_to_text.items() if text == doc_text)
- else:
- table_name = next(name for name, text in name_to_text.items() if text == documents[idx])
- results.append((table_name, score))
- results.sort(key=lambda x: x[1], reverse=True)
- logger.info("✅ Rerank 完成")
- return results
- elif isinstance(result_data, list):
- # 假设直接返回了排序后的索引列表
- results = []
- for i, item in enumerate(result_data):
- if isinstance(item, dict) and "index" in item:
- idx = item["index"]
- score = item.get("score", 1.0 - i * 0.01) # 默认分数递减
- table_name = next(name for name, text in name_to_text.items() if text == documents[idx])
- results.append((table_name, score))
- logger.info("✅ Rerank 完成")
- return results
+ logger.info("✅ Rerank 完成")
+ return results
logger.warning("⚠️ Rerank API 返回格式异常")
return [(name, 1.0) for name in candidate_tables.keys()]
@@ -1177,42 +1379,46 @@ def get_table_schema(self, state: AgentState) -> AgentState:
# 确保 user_query 也在返回的 state 中(虽然它应该已经在初始 state 中了)
state["user_query"] = user_query
- # 初始化向量索引
+        # Initialize the vector index (it filters to metadata-backed tables internally and fills self._table_names)
self._initialize_vector_index(all_table_info)
- # 混合检索 - 并行执行 BM25 和向量检索以提高性能
+    # Restrict hybrid retrieval to the same table set as the vector index, so BM25 indices cannot go out of bounds against self._table_names
+ table_info_for_retrieval = {k: all_table_info[k] for k in self._table_names if k in all_table_info}
+ n_tables = len(self._table_names)
+ if n_tables == 0:
+ state["db_info"] = {}
+ return state
+
logger.info("🔍 开始混合检索:BM25 + 向量检索(并行执行)")
- # 使用线程池并行执行 BM25 和向量检索
with ThreadPoolExecutor(max_workers=2) as executor:
- bm25_future = executor.submit(self._retrieve_by_bm25, all_table_info, user_query)
- vector_future = executor.submit(self._retrieve_by_vector, user_query, 20)
+ bm25_future = executor.submit(self._retrieve_by_bm25, table_info_for_retrieval, user_query)
+ vector_future = executor.submit(self._retrieve_by_vector, user_query, VECTOR_RETRIEVE_TOP_K)
- # 等待两个任务完成
bm25_top_indices = bm25_future.result()
vector_top_indices = vector_future.result()
logger.info(f"📊 BM25检索返回 {len(bm25_top_indices)} 个结果")
logger.info(f"🔗 向量检索返回 {len(vector_top_indices)} 个结果")
- # 过滤:仅保留同时在 BM25 前 50 和向量结果中的表
valid_bm25_set = set(bm25_top_indices[:50])
candidate_indices = [idx for idx in vector_top_indices if idx in valid_bm25_set]
logger.info(f"🎯 初步筛选后保留 {len(candidate_indices)} 个候选表")
if not candidate_indices:
- candidate_indices = bm25_top_indices[:TABLE_RETURN_COUNT] # 降级
+ candidate_indices = bm25_top_indices[:TABLE_RETURN_COUNT]
logger.info(f"⚠️ 候选表为空,降级使用BM25前{TABLE_RETURN_COUNT}个结果")
fused_indices = self._rrf_fusion(bm25_top_indices, candidate_indices, k=60)
logger.info(f"🔄 RRF融合后得到 {len(fused_indices)} 个结果")
- # 评分筛选
selected_indices = []
for idx in fused_indices:
- bm25_rank = bm25_top_indices.index(idx) + 1 if idx in bm25_top_indices else len(all_table_info) + 1
+ if idx < 0 or idx >= n_tables:
+ continue
+ bm25_rank = bm25_top_indices.index(idx) + 1 if idx in bm25_top_indices else n_tables + 1
vector_rank = (
- vector_top_indices.index(idx) + 1 if idx in vector_top_indices else len(all_table_info) + 1
+ vector_top_indices.index(idx) + 1 if idx in vector_top_indices else n_tables + 1
)
score = 1 / (60 + bm25_rank) + 1 / (60 + vector_rank)
if score >= 0.01 and len(selected_indices) < 10:
@@ -1221,9 +1427,12 @@ def get_table_schema(self, state: AgentState) -> AgentState:
candidate_table_names = [self._table_names[i] for i in selected_indices]
candidate_table_info = {name: all_table_info[name] for name in candidate_table_names}
- # 重排序
- reranked_results = self._rerank_with_dashscope(user_query, candidate_table_info)
- final_table_names = [name for name, _ in reranked_results][:TABLE_RETURN_COUNT] # 取 top N(可配置)
+        # Skip reranking when there are few candidates, saving roughly 1–3 s
+ if len(candidate_table_info) <= RERANK_SKIP_WHEN_CANDIDATES_LE:
+ reranked_results = [(name, 1.0) for name in candidate_table_names]
+ else:
+ reranked_results = self._rerank_with_dashscope(user_query, candidate_table_info)
+ final_table_names = [name for name, _ in reranked_results][:TABLE_RETURN_COUNT]
# 去重
final_table_names = list(dict.fromkeys(final_table_names))
@@ -1274,12 +1483,36 @@ def execute_sql(self, state: AgentState) -> AgentState:
sql_to_execute = state.get("filtered_sql") or state.get("generated_sql", "")
sql_to_execute = sql_to_execute.strip() if sql_to_execute else ""
- if not sql_to_execute:
- error_msg = "SQL 为空,无法执行"
+    # Handle the placeholder left when SQL generation fails, so an invalid statement is never executed
+ if not sql_to_execute or sql_to_execute == "No SQL query generated":
+ error_msg = (
+ "SQL 为空,无法执行"
+ if not sql_to_execute
+ else "SQL 未成功生成,已跳过执行"
+ )
logger.warning(error_msg)
- state["execution_result"] = ExecutionResult(success=False, error=error_msg)
+ state["execution_result"] = ExecutionResult(
+ success=False,
+ error=error_msg,
+ )
return state
+    # Preview-mode row cap (display only; callers such as export can disable it explicitly with state['preview_limit_rows'] = 0)
+    try:
+        default_preview_limit = int(os.getenv("SQL_EXEC_PREVIEW_LIMIT", "100"))
+    except ValueError:
+        default_preview_limit = 100
+    preview_limit = state.get("preview_limit_rows", default_preview_limit)
+
+ sql_for_execution = sql_to_execute
+ if preview_limit and preview_limit > 0 and " limit " not in sql_to_execute.lower():
+ base_sql = sql_to_execute.rstrip(";")
+ sql_for_execution = f"SELECT * FROM ({base_sql}) AS t_preview LIMIT {preview_limit}"
+ logger.info(
+ f"🔎 预览模式:对 SQL 应用 LIMIT {preview_limit} 行(不影响导出等显式关闭预览限制的场景)"
+ )
+
logger.info("▶️ 执行 SQL 语句")
# 记录使用的SQL类型(用于调试)
if state.get("filtered_sql"):
@@ -1299,14 +1532,14 @@ def execute_sql(self, state: AgentState) -> AgentState:
logger.info(f"使用原生驱动执行 SQL(数据源类型: {self._datasource_type})")
config = DatasourceConfigUtil.decrypt_config(self._datasource_config)
result_data = DatasourceConnectionUtil.execute_query(
- self._datasource_type, config, sql_to_execute
+ self._datasource_type, config, sql_for_execution
)
state["execution_result"] = ExecutionResult(success=True, data=result_data)
logger.info(f"✅ SQL 执行成功(原生驱动),返回 {len(result_data)} 条记录")
else:
# 对于 SQLAlchemy 驱动的数据库,使用 engine 执行
with self._engine.connect() as connection:
- result = connection.execute(text(sql_to_execute))
+ result = connection.execute(text(sql_for_execution))
result_data = result.fetchall()
columns = result.keys()
frame = pd.DataFrame(result_data, columns=columns)
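For reference, the Reciprocal Rank Fusion scoring that `_rrf_fusion` and the re-scoring loop above both use (score = Σ 1/(k + rank), with k = 60) reduces to a few lines; a minimal sketch:

```python
def rrf_fusion(bm25_indices, vector_indices, k=60):
    """Each ranking contributes 1/(k + rank), so items ranked highly by either retriever rise."""
    scores = {}
    for ranking in (bm25_indices, vector_indices):
        for rank, idx in enumerate(ranking, start=1):
            scores[idx] = scores.get(idx, 0.0) + 1.0 / (k + rank)
    return [idx for idx, _ in sorted(scores.items(), key=lambda x: -x[1])]
```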
diff --git a/agent/text2sql/sql/generator.py b/agent/text2sql/sql/generator.py
index 53799d41..ab228edf 100644
--- a/agent/text2sql/sql/generator.py
+++ b/agent/text2sql/sql/generator.py
@@ -22,6 +22,32 @@
logger = logging.getLogger(__name__)
+# Keywords signalling that the user explicitly wants all data / an export / no row cap; on a hit, generated SQL is not forced to include a LIMIT
+_FULL_DATA_OR_EXPORT_KEYWORDS = (
+ "不要限制",
+ "不限制条数",
+ "不要使用限制",
+ "无需限制",
+ "全部数据",
+ "所有数据",
+ "所有明细",
+ "全部明细",
+ "可以导出",
+ "需要导出",
+ "导出",
+ "无限制",
+ "不限条数",
+ "去掉限制",
+)
+
+
+def _user_wants_full_data_or_export(user_query: str) -> bool:
+    """Return True if the user explicitly asked to see all data or to export (no row cap)."""
+ if not (user_query and user_query.strip()):
+ return False
+ q = user_query.strip()
+ return any(kw in q for kw in _FULL_DATA_OR_EXPORT_KEYWORDS)
+
def sql_generate(state: AgentState) -> AgentState:
"""
@@ -69,43 +95,29 @@ def sql_generate(state: AgentState) -> AgentState:
except Exception as e:
logger.warning(f"获取数据源信息失败: {e},使用默认值")
- # 表关系补充:在 SQL 生成阶段补充缺失的关联表,并生成外键关系信息
- # 这样可以在 SQL 生成时根据实际需要补充关联表,而不是在检索阶段就补充
- try:
- from agent.text2sql.database.db_service import DatabaseService
- user_id = state.get("user_id")
-
- # 创建 DatabaseService 实例用于表关系补充
- db_service = DatabaseService(datasource_id)
-
- # 获取所有表信息(用于补充关联表,使用缓存避免重复查询)
- all_table_info = db_service._fetch_all_table_info(user_id=user_id, use_cache=True)
-
- # 获取当前选中的表名
- selected_table_names = list(db_info.keys())
-
- # 补充关联表
- supplemented_table_names = db_service.supplement_related_tables(
- selected_table_names,
- all_table_info
- )
-
- # 无论是否补充新表,都用 all_table_info 中的数据(包含 foreign_keys)重建 db_info
- # 这样可以确保已有表的外键关系也能传递到提示词中
- new_db_info = {}
- for table_name in supplemented_table_names:
- if table_name in all_table_info:
- new_db_info[table_name] = all_table_info[table_name]
- else:
- # 兜底:保持原来的表信息
- if table_name in db_info:
+    # Related-table supplementation: only fetch full table info and supplement related tables for multi-table queries; single-table queries use the db_info already in state to save time
+ if len(db_info) > 1:
+ try:
+ from agent.text2sql.database.db_service import DatabaseService
+ user_id = state.get("user_id")
+ db_service = DatabaseService(datasource_id)
+ all_table_info = db_service._fetch_all_table_info(user_id=user_id, use_cache=True)
+ selected_table_names = list(db_info.keys())
+ supplemented_table_names = db_service.supplement_related_tables(
+ selected_table_names,
+ all_table_info
+ )
+ new_db_info = {}
+ for table_name in supplemented_table_names:
+ if table_name in all_table_info:
+ new_db_info[table_name] = all_table_info[table_name]
+ elif table_name in db_info:
new_db_info[table_name] = db_info[table_name]
-
- db_info = new_db_info
- state["db_info"] = db_info
- logger.debug(f"表关系补充完成,当前 db_info 包含 {len(db_info)} 张表")
- except Exception as e:
- logger.warning(f"表关系补充失败: {e},继续使用原始表列表", exc_info=True)
+ db_info = new_db_info
+ state["db_info"] = db_info
+ logger.debug(f"表关系补充完成,当前 db_info 包含 {len(db_info)} 张表")
+ except Exception as e:
+ logger.warning(f"表关系补充失败: {e},继续使用原始表列表", exc_info=True)
# 格式化 schema 为 M-Schema 格式(包含补充的关联表)
schema_str = format_schema_to_m_schema(
@@ -122,36 +134,45 @@ def sql_generate(state: AgentState) -> AgentState:
# 使用 PromptBuilder 构建提示词
prompt_builder = PromptBuilder()
- # RAG 增强检索:检索术语和训练示例
+    # RAG-augmented retrieval: fetch terminologies and training examples (with a timeout so it cannot slow down the whole flow)
+ terminologies = ""
+ data_training = ""
try:
- from agent.text2sql.rag.terminology_retriever import retrieve_terminologies
- import asyncio
-
- # 检索术语(同步调用)
- terminologies = retrieve_terminologies(
- question=state["user_query"],
- datasource_id=datasource_id,
- oid=1, # 默认组织ID,后续可以从用户信息获取
- top_k=10,
- )
-
- # 检索训练示例
- from agent.text2sql.rag.training_retriever import retrieve_training_examples
- data_training = retrieve_training_examples(
- question=state["user_query"],
- datasource_id=datasource_id,
- oid=1, # 默认组织ID,后续可以从用户信息获取
- top_k=5,
- )
- except Exception as e:
- logger.warning(f"RAG 检索失败: {e},使用空字符串")
- terminologies = ""
- data_training = ""
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
+ RAG_TIMEOUT = int(__import__("os").getenv("SQL_RAG_TIMEOUT", "4"))
+
+ def _do_rag():
+ from agent.text2sql.rag.terminology_retriever import retrieve_terminologies
+ from agent.text2sql.rag.training_retriever import retrieve_training_examples
+ t = retrieve_terminologies(
+ question=state["user_query"],
+ datasource_id=datasource_id,
+ oid=1,
+ top_k=5,
+ )
+ d = retrieve_training_examples(
+ question=state["user_query"],
+ datasource_id=datasource_id,
+ oid=1,
+ top_k=3,
+ )
+ return t, d
+
+        # Avoid `with ThreadPoolExecutor(...)`: its __exit__ calls shutdown(wait=True),
+        # which would block on a timed-out RAG call and defeat the timeout
+        ex = ThreadPoolExecutor(max_workers=1)
+        try:
+            f = ex.submit(_do_rag)
+            terminologies, data_training = f.result(timeout=RAG_TIMEOUT)
+        finally:
+            ex.shutdown(wait=False)
+    except FuturesTimeoutError:
+        logger.debug("RAG 检索超时,跳过术语与训练示例")
+    except Exception as e:
+        logger.warning(f"RAG 检索失败: {e},使用空字符串")
custom_prompt = "" # 自定义提示词(暂时为空)
error_msg = "" # 错误消息(暂时为空)
# 获取系统提示词和用户提示词
+    # Do not force a LIMIT when the user explicitly wants all data / an export; otherwise default to a 40-row cap
+ enable_query_limit = not _user_wants_full_data_or_export(state.get("user_query", ""))
system_prompt, user_prompt = prompt_builder.build_sql_prompt(
db_type=db_type,
schema=schema_str,
@@ -161,7 +182,7 @@ def sql_generate(state: AgentState) -> AgentState:
terminologies=terminologies,
data_training=data_training,
custom_prompt=custom_prompt,
- enable_query_limit=True, # 启用查询限制
+ enable_query_limit=enable_query_limit,
error_msg=error_msg,
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
change_title=False, # 暂时不生成对话标题
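The timeout-bounded RAG call above generalizes to a small utility; a sketch under the same assumption (`call_with_timeout` is hypothetical, not part of this PR):

```python
from concurrent.futures import ThreadPoolExecutor

def call_with_timeout(fn, timeout_s: float, default=None):
    """Run fn() in a worker thread; give up and return default after timeout_s seconds."""
    ex = ThreadPoolExecutor(max_workers=1)
    try:
        return ex.submit(fn).result(timeout=timeout_s)
    except Exception:
        return default
    finally:
        # shutdown(wait=False): a plain `with` block would join the stuck
        # worker on exit and defeat the timeout
        ex.shutdown(wait=False)
```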
diff --git a/common/llm_util.py b/common/llm_util.py
index 39b8677b..18196b5d 100644
--- a/common/llm_util.py
+++ b/common/llm_util.py
@@ -73,13 +73,28 @@ def _get_openai():
f"[ERROR] Failed to import ChatOpenAI, please check langchain-openai/langsmith/opentelemetry installation: {e}"
)
raise
-
+    # Build extra parameters that force thinking mode off, to suit non-streaming requests.
+    # Some backends read the flag from default_headers, others from extra_body,
+    # so set it in both places to avoid errors when thinking is on by default.
+    extra_headers = {
+        # HTTP header values must be strings; compatibility layers that need a
+        # boolean pick it up from extra_body below instead
+        "enable_thinking": "false",
+    }
+
+    # Some backends (e.g. volcengine, some vLLM setups) expect the flag in the request body
+    extra_body = {
+        "enable_thinking": False,
+    }
return ChatOpenAI(
model=model_name,
temperature=temperature,
base_url=model_base_url,
api_key=model_api_key or "empty", # Ensure not None
timeout=timeout, # 设置超时时间(秒)
+        # Key change: pass the extra parameters that disable thinking
+ default_headers=extra_headers,
+ extra_body=extra_body,
)
def _get_ollama():
diff --git a/common/local_embedding.py b/common/local_embedding.py
index 0f6d0cf1..2b2a2561 100644
--- a/common/local_embedding.py
+++ b/common/local_embedding.py
@@ -179,7 +179,8 @@ def _get_local_embedding_model():
return None
except Exception as e:
- logger.error(f"Failed to load local embedding model: {e}", exc_info=True)
+        # Log only a brief error when the local embedding model fails to load, instead of flooding the console with stack traces
+ logger.warning("Failed to load local embedding model: %s", e)
return None
diff --git a/controllers/export_api.py b/controllers/export_api.py
new file mode 100644
index 00000000..9481f516
--- /dev/null
+++ b/controllers/export_api.py
@@ -0,0 +1,122 @@
+import logging
+
+from sanic import Blueprint, Request, response
+from sanic_ext import openapi
+
+from common.exception import MyException
+from common.token_decorator import check_token
+from common.param_parser import parse_params
+from constants.code_enum import SysCodeEnum
+from common.res_decorator import async_json_resp
+from model.schemas import ExportTableRequest, TablePageRequest, get_schema
+from services.export_table_service import run_sql_for_export, run_sql_page, data_to_csv
+from services.user_service import decode_jwt_token
+
+
+logger = logging.getLogger(__name__)
+
+bp = Blueprint("exportApi", url_prefix="/export")
+
+
+@bp.post("/table_csv")
+@openapi.summary("导出表格数据为 CSV")
+@openapi.description("根据 SQL 与数据源执行查询(应用权限过滤),并返回 CSV 文件。")
+@openapi.tag("数据导出")
+@openapi.body(
+ {
+ "application/json": {
+ "schema": get_schema(ExportTableRequest),
+ }
+ },
+ description="导出请求体",
+ required=True,
+)
+@check_token
+@parse_params
+async def export_table_csv(request: Request, body: ExportTableRequest):
+ """
+    Table-data export endpoint (CSV)
+    - Uses the same datasource and permission system as data Q&A;
+    - The SQL may be an uncapped query; all matching rows are exported;
+    - Returns a text/csv response the front end can download directly.
+ """
+ try:
+ token = request.headers.get("Authorization")
+ if token and token.startswith("Bearer "):
+ token = token.split(" ")[1]
+ user_dict = await decode_jwt_token(token)
+ user_id = user_dict.get("id", 1)
+
+ sql = body.sql
+ datasource_id = body.datasource_id
+ filename = (body.filename or "export").strip() or "export"
+
+ data, err = run_sql_for_export(sql, datasource_id=datasource_id, user_id=user_id)
+ if err is not None:
+ raise MyException(SysCodeEnum.c_9999, msg=err)
+
+ csv_body = data_to_csv(data or [])
+ headers = {
+ "Content-Type": "text/csv; charset=utf-8",
+ "Content-Disposition": f'attachment; filename="{filename}.csv"',
+ }
+ return response.text(csv_body, headers=headers)
+ except MyException:
+ raise
+ except Exception as e:
+ logger.error(f"导出表格数据失败: {e}")
+ raise MyException(SysCodeEnum.c_9999)
+
+
+@bp.post("/table_page")
+@openapi.summary("表格数据分页查询")
+@openapi.description("根据 SQL 与数据源执行分页查询(应用权限过滤),返回当前页数据。")
+@openapi.tag("数据导出")
+@openapi.body(
+ {
+ "application/json": {
+ "schema": get_schema(TablePageRequest),
+ }
+ },
+ description="分页查询请求体",
+ required=True,
+)
+@openapi.response(
+ 200,
+ {
+ "application/json": {
+ "schema": {"type": "object"},
+ }
+ },
+ description="分页查询结果",
+)
+@check_token
+@async_json_resp
+@parse_params
+async def table_page(request: Request, body: TablePageRequest):
+ """
+    Table-data pagination endpoint (JSON)
+    - Backs NDataTable paging on the front end;
+    - Applies the same permission system as data Q&A;
+    - Does not force a LIMIT onto the original SQL; pagination wraps it in an outer subquery.
+ """
+ token = request.headers.get("Authorization")
+ if token and token.startswith("Bearer "):
+ token = token.split(" ")[1]
+
+ user_dict = await decode_jwt_token(token)
+ user_id = user_dict.get("id", 1)
+
+ data, err = run_sql_page(
+ body.sql,
+ datasource_id=body.datasource_id,
+ user_id=user_id,
+ page=body.page,
+ size=body.size,
+ )
+ if err is not None:
+ raise MyException(SysCodeEnum.c_9999, msg=err)
+ return data
+
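A quick smoke test of the two endpoints; host, port, token, and table name are placeholders (the `/sanic` prefix matches the proxy path the front end uses):

```python
import requests

BASE = "http://localhost:8000/sanic/export"  # host/port are assumptions
HEADERS = {"Authorization": "Bearer <token>", "Content-Type": "application/json"}

page = requests.post(f"{BASE}/table_page", headers=HEADERS, json={
    "sql": "SELECT * FROM t_demo", "datasource_id": 1, "page": 1, "size": 20,
}).json()

csv_resp = requests.post(f"{BASE}/table_csv", headers=HEADERS, json={
    "sql": "SELECT * FROM t_demo", "datasource_id": 1, "filename": "demo",
})
with open("demo.csv", "wb") as f:
    f.write(csv_resp.content)
```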
diff --git a/controllers/llm_chat_api.py b/controllers/llm_chat_api.py
index 9a4ae8e3..3fdfea71 100644
--- a/controllers/llm_chat_api.py
+++ b/controllers/llm_chat_api.py
@@ -16,6 +16,7 @@
DifyGetSuggestedResponse,
StopChatRequest,
StopChatResponse,
+ ExportTableRequest,
get_schema,
)
diff --git a/model/schemas.py b/model/schemas.py
index f51b552c..bf203858 100644
--- a/model/schemas.py
+++ b/model/schemas.py
@@ -484,6 +484,23 @@ class StopChatResponse(BaseResponse):
data: Dict[str, str] = Field(description="停止结果")
+class ExportTableRequest(BaseModel):
+ """表格数据导出请求(数据问答结果导出为 CSV)"""
+
+ sql: str = Field(description="要执行的 SQL(将应用权限过滤)")
+ datasource_id: int = Field(description="数据源 ID")
+ filename: Optional[str] = Field(None, description="下载文件名(不含扩展名),默认 export.csv")
+
+
+class TablePageRequest(BaseModel):
+ """表格数据分页请求(基于 SQL 的分页查询)"""
+
+ sql: str = Field(description="要执行的 SQL(将应用权限过滤)")
+ datasource_id: int = Field(description="数据源 ID")
+ page: int = Field(1, description="页码(从 1 开始)")
+ size: int = Field(20, description="每页条数")
+
+
# ==================== 文件服务相关模型 ====================
class ReadFileRequest(BaseModel):
"""读取文件请求"""
diff --git a/services/datasource_service.py b/services/datasource_service.py
index d966fc27..da5b0b77 100644
--- a/services/datasource_service.py
+++ b/services/datasource_service.py
@@ -418,12 +418,36 @@ def _compute_and_save_table_embeddings_batch(session: Session, items: List[Dict[
# 获取 embedding 客户端(支持在线/离线模型切换)
embedding_client, model_name = DatasourceService._get_embedding_client()
+ all_embeddings = []
try:
if embedding_client and model_name:
# 使用在线模型批量计算
logger.info(f"批量计算 {len(docs)} 个表的 embedding(在线模型: {model_name})...")
- response = embedding_client.embeddings.create(model=model_name, input=docs)
- data = response.data or []
+            # DashScope rejects embedding batches larger than 25
+            # (InternalError.Algo.InvalidParameter: Value error, batch size is invalid, it should not be larger than 25.: input.contents),
+            # so compute embeddings in small batches instead of one request for all docs
+            MAX_BATCH_SIZE = 25
+            for i in range(0, len(docs), MAX_BATCH_SIZE):
+                batch_docs = docs[i: i + MAX_BATCH_SIZE]
+                try:
+                    # One request per batch; response.data is a list of embedding objects
+                    response = embedding_client.embeddings.create(model=model_name, input=batch_docs)
+                    batch_embeddings = [item.embedding for item in response.data]
+                    all_embeddings.extend(batch_embeddings)
+                except Exception as e:
+                    # A single failed batch aborts the whole computation; switch to `continue` here to skip failed batches instead
+                    logger.error(f"批量计算 embedding 失败,批次范围 {i}-{i + len(batch_docs)}: {e}")
+                    raise
+
+            data = all_embeddings or []
if len(data) != len(tables_for_embedding):
logger.warning(
@@ -461,6 +485,7 @@ def _compute_and_save_table_embeddings_batch(session: Session, items: List[Dict[
except Exception as e:
logger.error(f"批量计算表 embedding 失败: {e}", exc_info=True)
+
@staticmethod
def update_datasource(session: Session, ds_id: int, data: Dict[str, Any]) -> Optional[Datasource]:
"""更新数据源"""
diff --git a/services/embedding_service.py b/services/embedding_service.py
index f74f4e6d..a67d653a 100644
--- a/services/embedding_service.py
+++ b/services/embedding_service.py
@@ -85,8 +85,11 @@ async def generate_embedding(text: str) -> Optional[List[float]]:
return response.data[0].embedding
except Exception as e:
- traceback.print_exc()
- logger.warning(f"Failed to generate embedding with online model: {e}, falling back to local CPU model")
+        # Fall back to the local model when online embedding fails, without dumping the full traceback to the console
+ logger.warning(
+ "Failed to generate embedding with online model: %s, falling back to local CPU model",
+ e,
+ )
# 在线模型失败时,回退到本地模型
from common.local_embedding import generate_embedding_local
return await generate_embedding_local(text)
diff --git a/services/export_table_service.py b/services/export_table_service.py
new file mode 100644
index 00000000..6210494c
--- /dev/null
+++ b/services/export_table_service.py
@@ -0,0 +1,217 @@
+"""
+Table-data export service
+Executes a query for a given SQL + datasource (with permission filtering) and returns data exportable as CSV/Excel.
+"""
+
+import io
+import logging
+from typing import List, Dict, Any, Tuple, Optional
+
+from sqlalchemy import text
+
+from agent.text2sql.database.db_service import DatabaseService
+from agent.text2sql.permission.filter_injector import permission_filter_injector
+from agent.text2sql.analysis.data_render_antv import extract_table_names_sqlglot
+from services.datasource_service import DatasourceService
+from model.db_connection_pool import get_db_pool
+from common.datasource_util import DatasourceConfigUtil, DatasourceConnectionUtil, DB, ConnectType
+
+logger = logging.getLogger(__name__)
+
+
+def run_sql_for_export(
+ sql: str,
+ datasource_id: int,
+ user_id: int,
+) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
+ """
+    Execute the SQL (with permission filtering) and return the rows for export.
+
+    Args:
+        sql: original SQL (may be an uncapped query with no LIMIT)
+        datasource_id: datasource ID
+        user_id: current user ID (for permission filtering)
+
+    Returns:
+        (data_list, error_message). On success data_list is a list of dicts; on failure data_list is None and error_message holds the error.
+ """
+ if not sql or not sql.strip() or sql.strip() == "No SQL query generated":
+ return None, "SQL 为空,无法导出"
+
+ sql = sql.strip()
+ db_type = "mysql"
+ try:
+ with get_db_pool().get_session() as session:
+ ds = DatasourceService.get_datasource_by_id(session, datasource_id)
+ if not ds:
+ return None, "数据源不存在"
+ db_type = ds.type or "mysql"
+ except Exception as e:
+ logger.warning(f"获取数据源类型失败: {e}")
+ table_names = extract_table_names_sqlglot(sql, db_type)
+ db_info = {t: {} for t in table_names} if table_names else {}
+
+ state = {
+ "generated_sql": sql,
+ "filtered_sql": None,
+ "datasource_id": datasource_id,
+ "user_id": user_id,
+ "db_info": db_info,
+ "used_tables": table_names,
+        # Export explicitly disables the preview cap so the full SQL runs
+ "preview_limit_rows": 0,
+ }
+ try:
+ permission_filter_injector(state)
+ except Exception as e:
+ logger.warning(f"权限过滤失败: {e}", exc_info=True)
+ db_service = DatabaseService(datasource_id)
+ db_service.execute_sql(state)
+ result = state.get("execution_result")
+ if not result:
+ return None, "执行结果为空"
+ if not result.success:
+ return None, result.error or "执行失败"
+ data = result.data
+ if not data:
+ return [], None
+ if isinstance(data, list):
+ return data, None
+ return list(data), None
+
+
+def run_sql_page(
+ sql: str,
+ datasource_id: int,
+ user_id: int,
+ page: int,
+ size: int,
+) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
+ """
+    Paged SQL query backing front-end table paging.
+    - Applies permission filtering;
+    - Does not append LIMIT to the original SQL; pagination wraps it in a subquery;
+    - Returns {total, page, size, rows}.
+ """
+ if not sql or not sql.strip() or sql.strip() == "No SQL query generated":
+ return None, "SQL 为空,无法查询"
+
+ if page <= 0:
+ page = 1
+ if size <= 0:
+ size = 20
+
+ sql = sql.strip()
+ db_type = "mysql"
+ try:
+ with get_db_pool().get_session() as session:
+ ds = DatasourceService.get_datasource_by_id(session, datasource_id)
+ if not ds:
+ return None, "数据源不存在"
+ db_type = ds.type or "mysql"
+ except Exception as e:
+ logger.warning(f"获取数据源类型失败: {e}")
+
+ table_names = extract_table_names_sqlglot(sql, db_type)
+ db_info = {t: {} for t in table_names} if table_names else {}
+
+    # Obtain filtered_sql via the permission-injection node (preview cap disabled so pagination runs over the full SQL)
+ state: Dict[str, Any] = {
+ "generated_sql": sql,
+ "filtered_sql": None,
+ "datasource_id": datasource_id,
+ "user_id": user_id,
+ "db_info": db_info,
+ "used_tables": table_names,
+ "preview_limit_rows": 0,
+ }
+ try:
+ permission_filter_injector(state)
+ except Exception as e:
+ logger.warning(f"权限过滤失败: {e}", exc_info=True)
+
+ filtered_sql = state.get("filtered_sql") or state.get("generated_sql", "")
+ if not filtered_sql:
+ return None, "权限过滤后 SQL 为空"
+
+ base_sql = filtered_sql.strip().rstrip(";")
+ offset = (page - 1) * size
+
+ count_sql = f"SELECT COUNT(*) AS cnt FROM ({base_sql}) AS t_count"
+ page_sql = f"SELECT * FROM ({base_sql}) AS t_page LIMIT {size} OFFSET {offset}"
+
+    # Use the datasource info held by DatabaseService to determine the driver type
+ db_service = DatabaseService(datasource_id)
+ rows: List[Dict[str, Any]] = []
+ total: int = 0
+
+ try:
+ use_native_driver = False
+ if db_service._datasource_type and datasource_id:
+ db_enum = DB.get_db(db_service._datasource_type, default_if_none=True)
+ use_native_driver = db_enum.connect_type == ConnectType.py_driver
+
+ if use_native_driver and db_service._datasource_config:
+            # Native driver: execute via DatasourceConnectionUtil
+ config = DatasourceConfigUtil.decrypt_config(db_service._datasource_config)
+ count_result = DatasourceConnectionUtil.execute_query(
+ db_service._datasource_type, config, count_sql
+ )
+ if count_result and isinstance(count_result, list):
+ first_row = count_result[0]
+ total = int(first_row.get("cnt") or 0)
+ page_result = DatasourceConnectionUtil.execute_query(
+ db_service._datasource_type, config, page_sql
+ )
+ rows = page_result or []
+ else:
+            # SQLAlchemy-driven databases
+ if not db_service._engine:
+ return None, "数据源未正确初始化(缺少 SQLAlchemy engine)"
+ with db_service._engine.connect() as connection:
+ count_result = connection.execute(text(count_sql)).fetchone()
+ if count_result is not None:
+                    # count(*) may be accessible via index 0 or the 'cnt' key
+ total = int(getattr(count_result, "cnt", count_result[0]))
+
+ result = connection.execute(text(page_sql))
+ result_rows = result.fetchall()
+ columns = list(result.keys())
+ for row in result_rows:
+ row_dict: Dict[str, Any] = {}
+ for i, col in enumerate(columns):
+ row_dict[col] = row[i]
+ rows.append(row_dict)
+ except Exception as e:
+ logger.error(f"分页查询失败: {e}", exc_info=True)
+ return None, f"分页查询失败: {e}"
+
+ return {
+ "total": total,
+ "page": page,
+ "size": size,
+ "rows": rows,
+ }, None
+
+
+def data_to_csv(data: List[Dict[str, Any]], add_bom: bool = True) -> str:
+    """Convert a list of dicts to a CSV string (header = every key seen, in order; a BOM helps Excel detect UTF-8)."""
+ import csv
+ if not data:
+ return ""
+ all_keys = []
+ seen = set()
+ for row in data:
+ for k in row:
+ if k not in seen:
+ seen.add(k)
+ all_keys.append(k)
+ out = io.StringIO()
+ writer = csv.DictWriter(out, fieldnames=all_keys, extrasaction="ignore")
+ writer.writeheader()
+ for row in data:
+ writer.writerow(row)
+ body = out.getvalue()
+ if add_bom:
+ body = "\ufeff" + body
+ return body
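One caveat on the wrapper above: `LIMIT {size} OFFSET {offset}` is MySQL/PostgreSQL syntax and fails on backends such as Oracle or SQL Server. A hedged sketch of a dialect-aware variant using sqlglot (`paginate_sql` is hypothetical; note tsql additionally requires an ORDER BY for OFFSET):

```python
import sqlglot

def paginate_sql(base_sql: str, page: int, size: int, dialect: str = "mysql") -> str:
    """Wrap base_sql for pagination, transpiled into the target dialect."""
    offset = max(page - 1, 0) * size
    wrapped = f"SELECT * FROM ({base_sql.rstrip(';')}) AS t_page LIMIT {size} OFFSET {offset}"
    # Oracle becomes OFFSET ... ROWS FETCH NEXT ... ROWS ONLY, etc.
    return sqlglot.transpile(wrapped, read="mysql", write=dialect)[0]
```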
diff --git a/web/src/api/index.ts b/web/src/api/index.ts
index 0c3d6ac9..07de1501 100644
--- a/web/src/api/index.ts
+++ b/web/src/api/index.ts
@@ -1,6 +1,7 @@
// import { mockEventStreamText } from '@/data'
// import { currentHost } from '@/utils/location'
// import request from '@/utils/request'
+import { useUserStore } from '@/store/business/userStore'
/**
* Event Stream 调用大模型接口 Ollama3 (Fetch 调用)
@@ -199,6 +200,77 @@ export async function fead_back(chat_id, rating) {
return fetch(req)
}
+/**
+ * Export table data as CSV
+ * @param sql
+ * @param datasource_id
+ * @param filename file name without extension; defaults to "export"
+ */
+export async function export_table_csv(sql: string, datasource_id: number, filename?: string) {
+ const userStore = useUserStore()
+ const token = userStore.getUserToken()
+ const url = new URL(`${location.origin}/sanic/export/table_csv`)
+
+ const req = new Request(url, {
+ mode: 'cors',
+ method: 'post',
+ headers: {
+ 'Content-Type': 'application/json',
+ 'Authorization': `Bearer ${token}`,
+ },
+ body: JSON.stringify({
+ sql,
+ datasource_id,
+ filename,
+ }),
+ })
+
+ const res = await fetch(req)
+ if (!res.ok) {
+ throw new Error(`导出失败,状态码:${res.status}`)
+ }
+
+ const blob = await res.blob()
+ const objectUrl = URL.createObjectURL(blob)
+ const a = document.createElement('a')
+ a.href = objectUrl
+ a.download = `${(filename || 'export').trim() || 'export'}.csv`
+ document.body.appendChild(a)
+ a.click()
+ document.body.removeChild(a)
+ URL.revokeObjectURL(objectUrl)
+}
+
+/**
+ * Paged table query (SQL-based)
+ * @param sql
+ * @param datasource_id
+ * @param page
+ * @param size
+ */
+export async function table_page(sql: string, datasource_id: number, page = 1, size = 20) {
+ const userStore = useUserStore()
+ const token = userStore.getUserToken()
+ const url = new URL(`${location.origin}/sanic/export/table_page`)
+
+ const req = new Request(url, {
+ mode: 'cors',
+ method: 'post',
+ headers: {
+ 'Content-Type': 'application/json',
+ 'Authorization': `Bearer ${token}`,
+ },
+ body: JSON.stringify({
+ sql,
+ datasource_id,
+ page,
+ size,
+ }),
+ })
+
+ return fetch(req)
+}
+
/**
* 问题建议
* @param chat_id
diff --git a/web/src/components/MarkdownPreview/index.vue b/web/src/components/MarkdownPreview/index.vue
index d9ed7439..0eb6d0c4 100644
--- a/web/src/components/MarkdownPreview/index.vue
+++ b/web/src/components/MarkdownPreview/index.vue
@@ -55,6 +55,7 @@ interface Props {
recommended_questions?: string[]
} | null
recordId?: number // 记录ID,用于查询SQL语句
+  datasourceId?: number // datasource ID, used for export
}
// 解构 props
@@ -722,6 +723,7 @@ const currentQaOption = computed(() => {
:chart-data="props.chartData"
:record-id="props.recordId"
:qa-type="props.qaType"
+ :datasource-id="props.datasourceId"
@chart-rendered="() => onChartCompletedReader()"
@table-rendered="() => onTableCompletedReader()"
/>
diff --git a/web/src/components/MarkdownPreview/markdown-antv.vue b/web/src/components/MarkdownPreview/markdown-antv.vue
index 6e63ef96..5668d624 100644
--- a/web/src/components/MarkdownPreview/markdown-antv.vue
+++ b/web/src/components/MarkdownPreview/markdown-antv.vue
@@ -1,5 +1,5 @@