diff --git a/.env.dev b/.env.dev index 4a86b05b..bf52f1fc 100644 --- a/.env.dev +++ b/.env.dev @@ -12,7 +12,7 @@ MINIO_ENDPOINT=127.0.0.1:9000 SQLALCHEMY_DATABASE_URI=postgresql+psycopg2://aix_db:1@127.0.0.1:15432/aix_db # LangFuse 配置 默认关闭 (可选) -LANGFUSE_TRACING_ENABLED="true" +LANGFUSE_TRACING_ENABLED="false" LANGFUSE_SECRET_KEY = "sk-lf-4bf2a844-4a9c-4626-af69-0cae99bf2bfb" LANGFUSE_PUBLIC_KEY = "pk-lf-8aff3c29-3239-4a52-8028-bacc185f6c22" LANGFUSE_BASE_URL = "http://localhost:3000" diff --git a/agent/deepagent/output/medical_projects_monthly_report.html b/agent/deepagent/output/medical_projects_monthly_report.html new file mode 100644 index 00000000..ab66c210 --- /dev/null +++ b/agent/deepagent/output/medical_projects_monthly_report.html @@ -0,0 +1,329 @@
+ [Generated HTML report, 329 lines; tags and styles were lost in extraction. Recoverable content:]
+ 🏥 医疗项目月度使用数量分析报告
+ 数据统计时间范围:2025 年 1 月 - 2026 年 1 月 | 数据来源于医院 HIS 系统结算明细表
+ 统计卡片:统计月份数 10 | 医疗项目种类 500+ | 最高单月用量(输液泵)137,767 | 热门项目排名 Top 10
+ 📊 Top 10 医疗项目 - 最近 3 个月使用量趋势
+ 📈 各类别项目月度总使用量对比
+ 🔍 护理类 vs 检验类 vs 药品类 - 使用量分布
+ 📋 医疗项目月度使用量详细数据 (Top 50)
+ 表头:排名 | 项目名称 | 2026-01 | 2025-12 | 2025-11 | 2025-10 | 2025-09 | 总使用量
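The next three hunks apply the same guardrail: cap what is handed to the frontend or to the LLM instead of forwarding full result sets. A minimal sketch of that shared pattern, reusing the env-var names introduced in the diff (the helper functions themselves are illustrative, not part of the patch):

```python
import json
import os

# Names taken from the diff; defaults match the hunks below.
TABLE_PREVIEW_MAX_ROWS = int(os.getenv("TABLE_PREVIEW_MAX_ROWS", "100"))
SUMMARIZE_MAX_ROWS = int(os.getenv("SUMMARIZE_MAX_ROWS", "25"))
SUMMARIZE_MAX_CHARS = int(os.getenv("SUMMARIZE_MAX_CHARS", "12000"))

def preview_rows(rows: list) -> list:
    # Frontend table preview: slice a copy, never mutate the full result
    return rows[:TABLE_PREVIEW_MAX_ROWS]

def summarize_payload(rows: list) -> str:
    # LLM summary input: cap row count first, then serialized length
    s = json.dumps(rows[:SUMMARIZE_MAX_ROWS], ensure_ascii=False, indent=2)
    if len(s) > SUMMARIZE_MAX_CHARS:
        s = s[:SUMMARIZE_MAX_CHARS] + "\n\n...(truncated)"
    return s
```
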
+ + + + diff --git a/agent/text2sql/analysis/data_render_antv.py b/agent/text2sql/analysis/data_render_antv.py index 663171ea..76009eb7 100644 --- a/agent/text2sql/analysis/data_render_antv.py +++ b/agent/text2sql/analysis/data_render_antv.py @@ -4,6 +4,7 @@ """ import json import logging +import os import traceback from decimal import Decimal from datetime import datetime, date @@ -37,6 +38,9 @@ "excel": "postgres", # Excel 使用 PostgreSQL 规则 } +# 前端表格预览最大行数(避免一次性返回全部数据导致页面卡顿) +TABLE_PREVIEW_MAX_ROWS = int(os.getenv("TABLE_PREVIEW_MAX_ROWS", "100")) + def convert_value(v): """转换数据类型""" @@ -513,8 +517,11 @@ async def data_render_ant(state: AgentState): except Exception as e: logger.warning(f"获取数据源类型失败: {e},使用默认值 mysql") + # 视图侧仅预览前 TABLE_PREVIEW_MAX_ROWS 行,避免一次性渲染全部数据导致前端卡顿 + preview_data = data[:TABLE_PREVIEW_MAX_ROWS] + # 获取实际的列名(从第一条数据中提取) - actual_columns = list(data[0].keys()) if data else [] + actual_columns = list(preview_data[0].keys()) if preview_data else [] if not actual_columns: logger.warning("无法从数据中提取列名,跳过数据渲染") @@ -536,9 +543,9 @@ async def data_render_ant(state: AgentState): else: logger.warning(f"列名映射失败或返回空,使用原始列名。actual_columns={actual_columns[:3]}") - # 转换数据格式: 将英文列名映射为中文列名 + # 转换数据格式: 将英文列名映射为中文列名(仅对预览数据执行) formatted_data = [] - for row in data: + for row in preview_data: formatted_row = {} for col_name, value in row.items(): chinese_col_name = column_mapping.get(col_name, col_name) diff --git a/agent/text2sql/analysis/llm_summarizer.py b/agent/text2sql/analysis/llm_summarizer.py index 492e2ef8..994623ba 100644 --- a/agent/text2sql/analysis/llm_summarizer.py +++ b/agent/text2sql/analysis/llm_summarizer.py @@ -1,4 +1,5 @@ import logging +import os from datetime import datetime, date import json import re @@ -11,6 +12,11 @@ from agent.text2sql.template.prompt_builder import PromptBuilder logger = logging.getLogger(__name__) + +# 总结时最多带入的数据行数,避免 prompt 过大拖慢 LLM(可配置) +SUMMARIZE_MAX_ROWS = int(os.getenv("SUMMARIZE_MAX_ROWS", "25")) +# 总结时数据 JSON 最大字符数,超出则截断并注明(可配置) +SUMMARIZE_MAX_CHARS = int(os.getenv("SUMMARIZE_MAX_CHARS", "12000")) """ 大模型数据总结节点 """ @@ -65,14 +71,16 @@ def summarize(state: AgentState): prompt_builder = PromptBuilder() try: - # 获取数据结果 + # 获取数据结果,限制行数与长度以加快 LLM 总结 data_result = state["execution_result"].data - - # 如果数据是字典或列表,转换为JSON字符串 + if isinstance(data_result, list) and len(data_result) > SUMMARIZE_MAX_ROWS: + data_result = data_result[:SUMMARIZE_MAX_ROWS] if isinstance(data_result, (dict, list)): data_result_str = json.dumps(data_result, ensure_ascii=False, indent=2, cls=DecimalEncoder) else: data_result_str = str(data_result) + if len(data_result_str) > SUMMARIZE_MAX_CHARS: + data_result_str = data_result_str[:SUMMARIZE_MAX_CHARS] + "\n\n...(数据已截断,仅展示前一部分供总结)" # 获取当前时间 current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") diff --git a/agent/text2sql/analysis/parallel_collector.py b/agent/text2sql/analysis/parallel_collector.py index b469bd55..1ea07f00 100644 --- a/agent/text2sql/analysis/parallel_collector.py +++ b/agent/text2sql/analysis/parallel_collector.py @@ -80,7 +80,7 @@ def parallel_collect(state: AgentState, tasks: list[str] = None) -> AgentState: for task, future in futures.items(): try: - result_state = future.result(timeout=180) # 最多等待60秒 + result_state = future.result(timeout=60) # 单任务最多等待 60 秒,避免整体过慢 results[task] = result_state logger.info(f"✅ 任务完成: {task}") except Exception as e: diff --git a/agent/text2sql/analysis/unified_collector.py b/agent/text2sql/analysis/unified_collector.py index 9599b64b..07e790ce 100644 --- 
a/agent/text2sql/analysis/unified_collector.py +++ b/agent/text2sql/analysis/unified_collector.py @@ -44,11 +44,38 @@ async def unified_collect(state: AgentState) -> AgentState: # chart_config 为空,检查是否是因为查询结果为空 execution_result = state.get("execution_result") if execution_result and execution_result.success and not execution_result.data: - # SQL执行成功但无数据,生成空结果卡片,让前端显示空结果提示并可查看SQL - logger.info("📊 SQL执行成功但无数据,生成空结果卡片") + # SQL执行成功但无数据,仍然返回表格模板(temp01),这样前端可以显示表格结构和分页控件 + logger.info("📊 SQL执行成功但无数据,生成空表格") + # 尝试从 SQL 中提取列名 + columns = [] + generated_sql = state.get("generated_sql", "") or state.get("filtered_sql", "") + if generated_sql: + try: + # 简单的列名提取:从 SELECT 和 FROM 之间提取 + import re + select_match = re.search(r'SELECT\s+(.*?)\s+FROM', generated_sql, re.IGNORECASE | re.DOTALL) + if select_match: + select_clause = select_match.group(1) + # 分割列名(简单处理,不考虑复杂的子查询) + col_parts = [c.strip() for c in select_clause.split(',')] + for col in col_parts: + # 提取别名或列名(AS 匹配需不区分大小写,否则小写 as 会把整个表达式当作列名) + alias_parts = re.split(r'\s+AS\s+', col, flags=re.IGNORECASE) + if len(alias_parts) > 1: + alias = alias_parts[-1].strip().strip('`').strip('"').strip("'") + columns.append(alias) + else: + # 提取最后一个点号后面的部分(表名.列名 -> 列名) + col_name = col.strip().strip('`').strip('"').strip("'") + if '.' in col_name: + col_name = col_name.split('.')[-1] + columns.append(col_name) + except Exception as e: + logger.warning(f"从 SQL 提取列名失败: {e}") + state["render_data"] = { - "template_code": "temp05", - "columns": [], + "template_code": "temp01", # 改为 temp01,显示表格 + "columns": columns, "data": [], } else: diff --git a/agent/text2sql/database/db_service.py b/agent/text2sql/database/db_service.py index f4f37617..0bd84222 100644 --- a/agent/text2sql/database/db_service.py +++ b/agent/text2sql/database/db_service.py @@ -12,9 +12,9 @@ import os import re import time -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Tuple, Optional, Set from threading import Lock -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed import faiss import jieba @@ -33,7 +33,7 @@ from model.db_models import TAiModel, TDsPermission, TDsRules from model.datasource_models import DatasourceTable, DatasourceField from agent.text2sql.permission.permission_retriever import get_user_permission_filters -from sqlalchemy import select +from sqlalchemy import select, update # 日志配置 logger = logging.getLogger(__name__) @@ -44,6 +44,13 @@ # 返回表数量配置(可配置,默认 6 个) TABLE_RETURN_COUNT = int(os.getenv("TABLE_RETURN_COUNT", "6")) +# 候选表数量少于此值时跳过 Rerank,直接使用 BM25+向量顺序以节省耗时(默认 3) +RERANK_SKIP_WHEN_CANDIDATES_LE = int(os.getenv("RERANK_SKIP_WHEN_CANDIDATES_LE", "3")) +# 向量检索 top_k(默认 12,过大增加融合计算量) +VECTOR_RETRIEVE_TOP_K = int(os.getenv("VECTOR_RETRIEVE_TOP_K", "12")) + +# 表结构并行加载的线程数(0 表示不并行,使用串行) +TABLE_FETCH_MAX_WORKERS = int(os.getenv("TABLE_FETCH_MAX_WORKERS", "8")) # 缓存配置 _table_info_cache: Dict[Tuple[int, Optional[int]], Tuple[Dict[str, Dict], float]] = {} @@ -51,6 +58,43 @@ CACHE_TTL = int(os.getenv("TABLE_INFO_CACHE_TTL", "300")) # 缓存有效期(秒),默认5分钟 +def _fetch_one_table_via_inspector(engine, table_name: str, column_permissions: Dict[str, Set[str]]) -> Tuple[str, Optional[Dict]]: + """ + 在独立连接上拉取单表 schema(供并行调用)。 + 返回 (table_name, table_info_dict) 或 (table_name, None) 表示跳过。 + """ + try: + with engine.connect() as conn: + insp = inspect(conn) + columns = {} + for col in insp.get_columns(table_name): + if table_name in column_permissions and col["name"] not in column_permissions[table_name]: + continue + columns[col["name"]] = { + "type": str(col["type"]),
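If the regex extraction above proves brittle (subqueries, commas inside function calls), the projection list can be read off the AST instead. The repo already depends on sqlglot for extract_table_names_sqlglot, so a sketch along these lines is a plausible drop-in; the helper name and dialect handling are assumptions, not part of the diff:

```python
import sqlglot

def extract_select_columns(sql: str, dialect: str = "mysql") -> list[str]:
    """Hypothetical AST-based alternative to the regex column extraction."""
    try:
        expr = sqlglot.parse_one(sql, read=dialect)
        # alias_or_name yields the alias when present, else the bare column name
        return [e.alias_or_name for e in expr.selects]
    except Exception:
        return []  # fall back to the regex path on parse failure
```
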
+ "comment": str(col["comment"] or ""), + } + if not columns: + return table_name, None + foreign_keys = [ + f"{fk['constrained_columns'][0]} -> {fk['referred_table']}.{fk['referred_columns'][0]}" + for fk in insp.get_foreign_keys(table_name) + ] + try: + info = insp.get_table_comment(table_name) + comment = (info.get("text") or info.get("comment") or "") if isinstance(info, dict) else (info or "") + except Exception: + comment = "" + return table_name, { + "columns": columns, + "foreign_keys": foreign_keys, + "table_comment": str(comment).strip() if comment else "", + } + except Exception as e: + logger.debug(f"读取表 {table_name} 结构失败: {e}") + return table_name, None + + # 嵌入模型配置 def get_embedding_model_config(): """ @@ -467,39 +511,26 @@ def _fetch_all_table_info(self, user_id: Optional[int] = None, use_cache: bool = logger.warning(f"⚠️ 获取列权限失败: {e}", exc_info=True) table_info = {} - for table_name in table_names: - try: - columns = {} - for col in inspector.get_columns(table_name): - # 权限过滤:如果配置了列权限,只返回有权限的字段 - if table_name in column_permissions: - if col["name"] not in column_permissions[table_name]: - continue - - columns[col["name"]] = { - "type": str(col["type"]), - "comment": str(col["comment"] or ""), - } - - # 如果过滤后没有字段,跳过该表 - if not columns: - logger.debug(f"⚠️ 表 {table_name} 无可用字段(权限过滤后),跳过") - continue - - foreign_keys = [ - f"{fk['constrained_columns'][0]} -> {fk['referred_table']}.{fk['referred_columns'][0]}" - for fk in inspector.get_foreign_keys(table_name) - ] - - table_comment = self._get_table_comment(table_name) - - table_info[table_name] = { - "columns": columns, - "foreign_keys": foreign_keys, - "table_comment": table_comment, + workers = TABLE_FETCH_MAX_WORKERS + if workers > 0 and len(table_names) > 1: + # 并行加载多表 schema,减少总耗时 + workers = min(workers, len(table_names), (os.cpu_count() or 4) * 2) + logger.info(f"🔀 使用 {workers} 个线程并行加载表结构") + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit(_fetch_one_table_via_inspector, self._engine, tn, column_permissions): tn + for tn in table_names } - except Exception as e: - logger.error(f"❌ 读取表 {table_name} 结构失败: {e}") + for future in as_completed(futures): + _name, info = future.result() + if info is not None: + table_info[_name] = info + else: + # 串行加载(兼容 TABLE_FETCH_MAX_WORKERS=0 或单表) + for table_name in table_names: + _name, info = _fetch_one_table_via_inspector(self._engine, table_name, column_permissions) + if info is not None: + table_info[_name] = info elapsed = time.time() - start_time logger.info(f"✅ 成功加载 {len(table_info)} 张表,耗时 {elapsed:.2f}s") @@ -588,6 +619,31 @@ def _fetch_table_info_from_metadata(self, user_id: Optional[int], use_cache: boo return table_info + def _get_table_names_in_metadata(self) -> Set[str]: + """ + 返回当前数据源在 t_datasource_table 中存在的表名集合(大写,用于匹配)。 + 同时包含「短名」(即 schema.table 的 table 部分),以便与 inspector 返回的 schema.table 形式匹配。 + """ + if not self._datasource_id: + return set() + try: + with db_pool.get_session() as session: + tables = ( + session.query(DatasourceTable.table_name) + .filter(DatasourceTable.ds_id == self._datasource_id) + .all() + ) + out = set() + for row in tables: + key = str(row[0]).upper() + out.add(key) + if "." 
in key: + out.add(key.split(".")[-1]) + return out + except Exception as e: + logger.warning(f"⚠️ 获取元数据表名失败: {e}") + return set() + def _get_precomputed_embeddings(self, table_info: Dict[str, Dict]) -> Tuple[Optional[np.ndarray], List[str], List[str]]: """ 尝试从数据库获取预计算的 embedding。 @@ -608,8 +664,13 @@ def _get_precomputed_embeddings(self, table_info: Dict[str, Dict]) -> Tuple[Opti .all() ) - # 构建表名到表的映射(不区分大小写,兼容 Oracle 等会返回大写表名的数据库) - table_map = {str(table.table_name).upper(): table for table in tables} + # 构建表名到表的映射(不区分大小写;同时支持「短名」以便 schema.table 与 table 互匹配) + table_map = {} + for table in tables: + key = str(table.table_name).upper() + table_map[key] = table + if "." in key: + table_map[key.split(".")[-1]] = table # 收集有预计算 embedding 的表 precomputed_embeddings = [] @@ -617,8 +678,8 @@ def _get_precomputed_embeddings(self, table_info: Dict[str, Dict]) -> Tuple[Opti missing_table_names = [] for table_name, info in table_info.items(): - # 统一按大写匹配,避免 T_ALARM_INFO / t_alarm_info 不一致导致无法命中 - table = table_map.get(str(table_name).upper()) + key_upper = str(table_name).upper() + table = table_map.get(key_upper) or (table_map.get(key_upper.split(".")[-1]) if "." in key_upper else None) # 检查是否有 embedding 字段(通过 hasattr 检查,避免字段不存在时报错) if table and hasattr(table, 'embedding') and table.embedding: try: @@ -689,71 +750,171 @@ def _create_embeddings_with_dashscope(self, texts: List[str]) -> np.ndarray: logger.info(f"✅ 离线模型嵌入生成完成,耗时 {time.time() - start_time:.2f}s,维度: {embedding_dim}") return embeddings - # 使用在线模型 - logger.info(f"🌐 调用在线嵌入模型 {self.embedding_model_name}...") + # 使用在线模型(批量请求以降低延迟,单次最多 25 条避免 API 限流) + logger.info(f"🌐 调用在线嵌入模型 {self.embedding_model_name}(批量)...") start_time = time.time() + EMBEDDING_BATCH_SIZE = 25 embeddings = [] - for doc in texts: + for i in range(0, len(texts), EMBEDDING_BATCH_SIZE): + batch = texts[i : i + EMBEDDING_BATCH_SIZE] try: - response = self.embedding_client.embeddings.create(model=self.embedding_model_name, input=doc) - embeddings.append(response.data[0].embedding) + response = self.embedding_client.embeddings.create( + model=self.embedding_model_name, input=batch + ) + for item in response.data: + embeddings.append(item.embedding) except Exception as e: - logger.error(f"❌ 在线模型嵌入生成失败 ({doc[:30]}...): {e}") - embeddings.append(np.zeros(1024)) # 占位符 - + logger.warning(f"⚠️ 在线模型批量嵌入失败 (batch {i}-{i+len(batch)}): {e},逐条回退") + for doc in batch: + try: + r = self.embedding_client.embeddings.create( + model=self.embedding_model_name, input=doc + ) + embeddings.append(r.data[0].embedding) + except Exception as e2: + logger.error(f"❌ 单条嵌入失败: {e2}") + embeddings.append(np.zeros(1024)) + if len(embeddings) != len(texts): + logger.warning(f"⚠️ 嵌入数量与输入不一致: {len(embeddings)} vs {len(texts)}") embeddings = np.array(embeddings).astype("float32") faiss.normalize_L2(embeddings) logger.info(f"✅ 在线模型嵌入生成完成,耗时 {time.time() - start_time:.2f}s") return embeddings + def _save_table_embeddings_to_db(self, table_names: List[str], embeddings: np.ndarray) -> None: + """ + 将表结构 embedding 写入数据库 t_datasource_table.embedding。 + table_names 与 embeddings 逐行对应。 + """ + if not self._datasource_id or not table_names or embeddings.size == 0: + return + try: + with db_pool.get_session() as session: + tables = ( + session.query(DatasourceTable) + .filter(DatasourceTable.ds_id == self._datasource_id) + .all() + ) + name_to_id = {str(t.table_name).upper(): t.id for t in tables} + saved = 0 + for i, table_name in enumerate(table_names): + if i >= len(embeddings): + break + key_upper = 
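Both this file and services/datasource_service.py (further down) implement the same batch-then-fallback embedding call, driven by the DashScope cap of 25 items per request. A condensed sketch of the shared logic, where client is any OpenAI-compatible embeddings client and the 1024-dim placeholder mirrors the diff's assumption:

```python
import numpy as np

def embed_texts(client, model: str, texts: list[str], batch_size: int = 25) -> np.ndarray:
    vectors = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        try:
            resp = client.embeddings.create(model=model, input=batch)
            vectors.extend(item.embedding for item in resp.data)
        except Exception:
            # Degrade to one request per document so a bad item only costs itself
            for doc in batch:
                try:
                    r = client.embeddings.create(model=model, input=doc)
                    vectors.append(r.data[0].embedding)
                except Exception:
                    vectors.append([0.0] * 1024)  # placeholder, dimension assumed
    return np.asarray(vectors, dtype="float32")
```
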
str(table_name).strip().upper() + table_id = name_to_id.get(key_upper) + if not table_id and "." in key_upper: + key_short = key_upper.split(".")[-1] + table_id = name_to_id.get(key_short) + if not table_id: + continue + emb_list = embeddings[i].tolist() if hasattr(embeddings[i], "tolist") else list(embeddings[i]) + stmt = ( + update(DatasourceTable) + .where(DatasourceTable.id == table_id) + .values(embedding=json.dumps(emb_list)) + ) + session.execute(stmt) + saved += 1 + session.commit() + logger.info(f"✅ 已保存 {saved}/{len(table_names)} 张表的 embedding 到数据库") + except Exception as e: + logger.warning(f"⚠️ 保存表 embedding 到数据库失败: {e}", exc_info=True) + def _initialize_vector_index(self, table_info: Dict[str, Dict]): """ - 初始化 FAISS 向量索引:从数据库读取预计算的 embedding 并构建内存索引。 - 仅使用预计算的 embedding,不在检索时做实时计算。 + 初始化 FAISS 向量索引:优先从数据库读取预计算的 embedding; + 若没有任何预计算或部分表缺失,则实时计算表结构 embedding 并写入数据库,再构建索引。 """ if self._index_initialized: return - # 构建新索引 logger.info("🏗️ 开始构建向量索引(从数据库读取 embedding)...") start_time = time.time() - # 记录所有表名和语料(用于 BM25 等) - self._table_names = list(table_info.keys()) - self._corpus = [self._build_document(name, info) for name, info in table_info.items()] + # 仅对「已在 t_datasource_table 中存在」的表建索引并持久化 embedding,避免对未同步到元数据的表重复计算且无法保存 + tables_in_metadata = self._get_table_names_in_metadata() + def _in_metadata(name: str) -> bool: + u = str(name).strip().upper() + return u in tables_in_metadata or (u.split(".")[-1] if "." in u else u) in tables_in_metadata + table_info_indexed = {k: v for k, v in table_info.items() if _in_metadata(k)} + skipped_count = len(table_info) - len(table_info_indexed) + if skipped_count > 0: + logger.info(f"📋 向量索引仅针对元数据中已存在的表:共 {len(table_info_indexed)} 张,跳过 {skipped_count} 张(仅存在于实时库、未在元数据中)") + if not table_info_indexed: + logger.warning("⚠️ 元数据中无表记录,无法构建向量索引,仅使用 BM25") + self._faiss_index = None + self._index_initialized = True + return + + self._table_names = list(table_info_indexed.keys()) + self._corpus = [self._build_document(name, info) for name, info in table_info_indexed.items()] - # 从数据库获取预计算的 embedding(不会做任何实时计算) precomputed_embeddings, precomputed_table_names, missing_table_names = self._get_precomputed_embeddings( - table_info + table_info_indexed ) - # 如果没有任何预计算 embedding,则禁用向量索引(仅使用 BM25) + # 情况1:没有任何预计算 → 实时计算全部表 embedding,写入 DB,再建索引 if precomputed_embeddings is None or len(precomputed_table_names) == 0: - logger.warning("⚠️ 未找到任何预计算的表结构 embedding,向量检索将被禁用,仅使用 BM25") - self._faiss_index = None + logger.warning("⚠️ 未找到任何预计算的表结构 embedding,将实时计算并写入数据库后构建向量索引") + try: + embeddings = self._create_embeddings_with_dashscope(self._corpus) + if embeddings is not None and embeddings.size > 0: + self._save_table_embeddings_to_db(self._table_names, embeddings) + dimension = embeddings.shape[1] + self._faiss_index = faiss.IndexFlatIP(dimension) + self._faiss_index.add(embeddings) + elapsed = time.time() - start_time + logger.info(f"🎉 向量索引构建完成(已计算并保存 {len(self._table_names)} 张表),耗时 {elapsed:.2f}s") + else: + logger.warning("⚠️ 实时计算 embedding 失败,向量检索将被禁用,仅使用 BM25") + self._faiss_index = None + except Exception as e: + logger.warning(f"⚠️ 实时计算表 embedding 失败: {e},向量检索将被禁用", exc_info=True) + self._faiss_index = None self._index_initialized = True return # 如果存在缺失的 embedding,为避免索引和表顺序不一致,这里直接禁用向量检索 if len(missing_table_names) > 0: logger.warning( - f"⚠️ 共有 {len(missing_table_names)} 张表缺少预计算 embedding," - "为保证索引与表顺序一致,本次禁用向量检索,仅使用 BM25" + f"⚠️ 共有 {len(missing_table_names)} 张表缺少预计算 embedding,将实时计算缺失部分并合并后构建索引" ) - self._faiss_index = None + try: + missing_corpus = 
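All three branches of _initialize_vector_index end in the same index construction: L2-normalize the rows, then add them to an inner-product index, so dot product equals cosine similarity. The core, isolated:

```python
import faiss
import numpy as np

def build_cosine_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    embeddings = np.ascontiguousarray(embeddings, dtype="float32")
    faiss.normalize_L2(embeddings)                  # in-place row normalization
    index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product == cosine after normalization
    index.add(embeddings)
    return index
```
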
[self._build_document(name, table_info_indexed[name]) for name in missing_table_names] + missing_embeddings = self._create_embeddings_with_dashscope(missing_corpus) + if missing_embeddings is None or missing_embeddings.size == 0: + raise ValueError("缺失表的 embedding 计算失败") + name_to_precomputed_idx = {name: i for i, name in enumerate(precomputed_table_names)} + name_to_missing_idx = {name: i for i, name in enumerate(missing_table_names)} + merged_list = [] + for name in self._table_names: + if name in name_to_precomputed_idx: + merged_list.append(precomputed_embeddings[name_to_precomputed_idx[name]]) + else: + merged_list.append(missing_embeddings[name_to_missing_idx[name]]) + embeddings = np.array(merged_list).astype("float32") + faiss.normalize_L2(embeddings) + self._save_table_embeddings_to_db(missing_table_names, missing_embeddings) + dimension = embeddings.shape[1] + self._faiss_index = faiss.IndexFlatIP(dimension) + self._faiss_index.add(embeddings) + elapsed = time.time() - start_time + logger.info(f"🎉 向量索引构建完成(已补全 {len(missing_table_names)} 张表),共 {len(self._table_names)} 张表,耗时 {elapsed:.2f}s") + except Exception as e: + logger.warning(f"⚠️ 补全缺失 embedding 失败: {e},本次禁用向量检索", exc_info=True) + self._faiss_index = None self._index_initialized = True return - # 此时说明所有表都存在预计算 embedding,顺序与 self._table_names 一致 + # 情况3:全部有预计算,直接建索引 embeddings = precomputed_embeddings - if embeddings.size == 0: logger.error("❌ 无法生成嵌入,索引构建失败") + self._index_initialized = True return - # 初始化 FAISS 索引(仅在内存中) dimension = embeddings.shape[1] - self._faiss_index = faiss.IndexFlatIP(dimension) # 内积 = 余弦相似度 + self._faiss_index = faiss.IndexFlatIP(dimension) self._faiss_index.add(embeddings) elapsed = time.time() - start_time @@ -766,7 +927,7 @@ def _retrieve_by_vector(self, query: str, top_k: int = 10) -> List[int]: 优先使用在线模型,如果没有配置则使用离线模型。 """ if not self._faiss_index: - logger.error("❌ 向量索引未初始化") + logger.warning("⚠️ 向量索引未初始化(可能无预计算 embedding 且实时计算未执行或失败),跳过向量检索,仅使用 BM25") return [] try: @@ -845,9 +1006,65 @@ def _rrf_fusion(bm25_indices: List[int], vector_indices: List[int], k: int = 60) sorted_indices = sorted(scores.items(), key=lambda x: -x[1]) return [idx for idx, _ in sorted_indices] + def _get_dashscope_rerank_url_and_payload( + self, query: str, documents: List[str] + ) -> tuple: + """ + 根据模型类型返回 DashScope Rerank 的正确 URL 与请求体。 + - qwen3-rerank: 使用 /compatible-api/v1/reranks,请求体为扁平结构。 + - gte-rerank-v2 / qwen3-vl-rerank: 使用 /api/v1/services/rerank/text-rerank/text-rerank,请求体为 input/parameters。 + """ + base_url = (self.rerank_base_url or "").strip().rstrip("/") + model = (self.rerank_model_name or "").strip().lower() + is_dashscope_domain = "dashscope.aliyuncs.com" in base_url or "aliyuncs" in base_url + has_rerank_path = "/rerank" in base_url or "/reranks" in base_url + + # 若配置的 base_url 仅为域名或缺少 rerank 路径,则按官方文档补全 + if is_dashscope_domain and not has_rerank_path: + from urllib.parse import urlparse + parsed = urlparse(base_url if base_url.startswith("http") else f"https://{base_url}") + domain = f"{parsed.scheme or 'https'}://{parsed.netloc}" + if "qwen3-rerank" in model: + effective_url = f"{domain}/compatible-api/v1/reranks" + payload = { + "model": self.rerank_model_name, + "query": query, + "documents": documents, + "top_n": len(documents), + "instruct": "Given a web search query, retrieve relevant passages that answer the query.", + } + else: + effective_url = f"{domain}/api/v1/services/rerank/text-rerank/text-rerank" + payload = { + "model": self.rerank_model_name, + "input": {"query": query, 
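For reference, the Reciprocal Rank Fusion that _rrf_fusion implements scores each document as the sum of 1/(k + rank) over every ranking it appears in; k=60 is the conventional damping constant used here:

```python
def rrf_fusion(bm25_indices: list[int], vector_indices: list[int], k: int = 60) -> list[int]:
    scores: dict[int, float] = {}
    for ranking in (bm25_indices, vector_indices):
        for rank, idx in enumerate(ranking, start=1):
            scores[idx] = scores.get(idx, 0.0) + 1.0 / (k + rank)
    # Highest fused score first
    return [idx for idx, _ in sorted(scores.items(), key=lambda x: -x[1])]
```
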
"documents": documents}, + "parameters": {"top_n": len(documents), "return_documents": False}, + } + return effective_url, payload + + # 使用配置的完整 URL,按模型选择请求体格式 + if "qwen3-rerank" in model: + payload = { + "model": self.rerank_model_name, + "query": query, + "documents": documents, + "top_n": len(documents), + "instruct": "Given a web search query, retrieve relevant passages that answer the query.", + } + elif "aliyuncs" in base_url or "qwen" in model or "gte-rerank" in model: + payload = { + "model": self.rerank_model_name, + "input": {"query": query, "documents": documents}, + "parameters": {"top_n": len(documents), "return_documents": False}, + } + else: + payload = {"query": query, "documents": documents} + return self.rerank_base_url, payload + def _rerank_with_dashscope(self, query: str, candidate_tables: Dict[str, Dict]) -> List[Tuple[str, float]]: """ 使用 DashScope 重排 API 对候选表进行重排序。 + 兼容 qwen3-rerank(/compatible-api/v1/reranks)与 gte-rerank-v2 等(/api/v1/services/rerank/...)。 """ if not self.USE_RERANKER: logger.debug("⏭️ Reranker 已禁用或配置不完整,跳过重排序") @@ -866,75 +1083,60 @@ def _rerank_with_dashscope(self, query: str, candidate_tables: Dict[str, Dict]) logger.info(f"🔁 调用重排模型 {self.rerank_model_name} 进行重排序...") - # 根据API类型选择不同的请求结构 - if "aliyuncs" in self.rerank_base_url or "Qwen" in self.rerank_model_name: - # 阿里云 DashScope 格式 - payload = { - "model": self.rerank_model_name, - "input": {"query": query, "documents": documents}, - "parameters": {"top_n": len(documents), "return_documents": False}, - } + # 获取正确的 URL 与请求体(避免 404:DashScope 不同模型路径与 body 不同) + if "aliyuncs" in (self.rerank_base_url or "") or (self.rerank_model_name or "").lower().startswith(("qwen", "gte")): + effective_url, payload = self._get_dashscope_rerank_url_and_payload(query, documents) else: - # 其他格式(如本地模型或通用rerank API) + effective_url = self.rerank_base_url payload = {"query": query, "documents": documents} - # 设置请求头 headers = {"Authorization": f"Bearer {self.rerank_api_key}", "Content-Type": "application/json"} - - # 调用重排 API - response = requests.post(self.rerank_base_url, headers=headers, json=payload, timeout=30) + response = requests.post(effective_url, headers=headers, json=payload, timeout=15) # 检查响应状态 if response.status_code != 200: logger.warning(f"⚠️ Rerank API 调用失败: {response.status_code} - {response.text}") return [(name, 1.0) for name in candidate_tables.keys()] - # 解析响应 + # 解析响应(DashScope 两种端点均返回 output.results) result_data = response.json() - # 根据API类型解析响应 - if "aliyuncs" in self.rerank_base_url or "Qwen" in self.rerank_model_name: - # 阿里云格式响应 - if "output" in result_data and "results" in result_data["output"]: - results = [] - for item in result_data["output"]["results"]: + if "output" in result_data and "results" in result_data["output"]: + results = [] + for item in result_data["output"]["results"]: + idx = item["index"] + score = item["relevance_score"] + table_name = next(name for name, text in name_to_text.items() if text == documents[idx]) + results.append((table_name, score)) + results.sort(key=lambda x: x[1], reverse=True) + logger.info("✅ Rerank 完成") + return results + # 通用格式(非 DashScope) + if "results" in result_data: + results = [] + for item in result_data["results"]: + if "index" in item and "relevance_score" in item: idx = item["index"] score = item["relevance_score"] + if "document" in item and "text" in item["document"]: + doc_text = item["document"]["text"] + table_name = next(name for name, text in name_to_text.items() if text == doc_text) + else: + table_name = next(name for name, 
text in name_to_text.items() if text == documents[idx]) + results.append((table_name, score)) + results.sort(key=lambda x: x[1], reverse=True) + logger.info("✅ Rerank 完成") + return results + if isinstance(result_data, list): + results = [] + for i, item in enumerate(result_data): + if isinstance(item, dict) and "index" in item: + idx = item["index"] + score = item.get("score", 1.0 - i * 0.01) table_name = next(name for name, text in name_to_text.items() if text == documents[idx]) results.append((table_name, score)) - - results.sort(key=lambda x: x[1], reverse=True) - logger.info("✅ Rerank 完成") - return results - else: - # 通用格式响应 - 假设直接返回排序结果 - if "results" in result_data: - results = [] - for item in result_data["results"]: - if "index" in item and "relevance_score" in item: # 使用relevance_score - idx = item["index"] - score = item["relevance_score"] # 使用relevance_score字段 - # 从document对象中提取文本 - if "document" in item and "text" in item["document"]: - doc_text = item["document"]["text"] - table_name = next(name for name, text in name_to_text.items() if text == doc_text) - else: - table_name = next(name for name, text in name_to_text.items() if text == documents[idx]) - results.append((table_name, score)) - results.sort(key=lambda x: x[1], reverse=True) - logger.info("✅ Rerank 完成") - return results - elif isinstance(result_data, list): - # 假设直接返回了排序后的索引列表 - results = [] - for i, item in enumerate(result_data): - if isinstance(item, dict) and "index" in item: - idx = item["index"] - score = item.get("score", 1.0 - i * 0.01) # 默认分数递减 - table_name = next(name for name, text in name_to_text.items() if text == documents[idx]) - results.append((table_name, score)) - logger.info("✅ Rerank 完成") - return results + logger.info("✅ Rerank 完成") + return results logger.warning("⚠️ Rerank API 返回格式异常") return [(name, 1.0) for name in candidate_tables.keys()] @@ -1177,42 +1379,46 @@ def get_table_schema(self, state: AgentState) -> AgentState: # 确保 user_query 也在返回的 state 中(虽然它应该已经在初始 state 中了) state["user_query"] = user_query - # 初始化向量索引 + # 初始化向量索引(内部会过滤为仅元数据中存在的表,并写入 self._table_names) self._initialize_vector_index(all_table_info) - # 混合检索 - 并行执行 BM25 和向量检索以提高性能 + # 混合检索仅针对与向量索引一致的表集合,避免 BM25 索引与 self._table_names 不一致导致越界 + table_info_for_retrieval = {k: all_table_info[k] for k in self._table_names if k in all_table_info} + n_tables = len(self._table_names) + if n_tables == 0: + state["db_info"] = {} + return state + logger.info("🔍 开始混合检索:BM25 + 向量检索(并行执行)") - # 使用线程池并行执行 BM25 和向量检索 with ThreadPoolExecutor(max_workers=2) as executor: - bm25_future = executor.submit(self._retrieve_by_bm25, all_table_info, user_query) - vector_future = executor.submit(self._retrieve_by_vector, user_query, 20) + bm25_future = executor.submit(self._retrieve_by_bm25, table_info_for_retrieval, user_query) + vector_future = executor.submit(self._retrieve_by_vector, user_query, VECTOR_RETRIEVE_TOP_K) - # 等待两个任务完成 bm25_top_indices = bm25_future.result() vector_top_indices = vector_future.result() logger.info(f"📊 BM25检索返回 {len(bm25_top_indices)} 个结果") logger.info(f"🔗 向量检索返回 {len(vector_top_indices)} 个结果") - # 过滤:仅保留同时在 BM25 前 50 和向量结果中的表 valid_bm25_set = set(bm25_top_indices[:50]) candidate_indices = [idx for idx in vector_top_indices if idx in valid_bm25_set] logger.info(f"🎯 初步筛选后保留 {len(candidate_indices)} 个候选表") if not candidate_indices: - candidate_indices = bm25_top_indices[:TABLE_RETURN_COUNT] # 降级 + candidate_indices = bm25_top_indices[:TABLE_RETURN_COUNT] logger.info(f"⚠️ 候选表为空,降级使用BM25前{TABLE_RETURN_COUNT}个结果") fused_indices = 
self._rrf_fusion(bm25_top_indices, candidate_indices, k=60) logger.info(f"🔄 RRF融合后得到 {len(fused_indices)} 个结果") - # 评分筛选 selected_indices = [] for idx in fused_indices: - bm25_rank = bm25_top_indices.index(idx) + 1 if idx in bm25_top_indices else len(all_table_info) + 1 + if idx < 0 or idx >= n_tables: + continue + bm25_rank = bm25_top_indices.index(idx) + 1 if idx in bm25_top_indices else n_tables + 1 vector_rank = ( - vector_top_indices.index(idx) + 1 if idx in vector_top_indices else len(all_table_info) + 1 + vector_top_indices.index(idx) + 1 if idx in vector_top_indices else n_tables + 1 ) score = 1 / (60 + bm25_rank) + 1 / (60 + vector_rank) if score >= 0.01 and len(selected_indices) < 10: @@ -1221,9 +1427,12 @@ def get_table_schema(self, state: AgentState) -> AgentState: candidate_table_names = [self._table_names[i] for i in selected_indices] candidate_table_info = {name: all_table_info[name] for name in candidate_table_names} - # 重排序 - reranked_results = self._rerank_with_dashscope(user_query, candidate_table_info) - final_table_names = [name for name, _ in reranked_results][:TABLE_RETURN_COUNT] # 取 top N(可配置) + # 候选表较少时跳过 Rerank 以节省约 1–3s + if len(candidate_table_info) <= RERANK_SKIP_WHEN_CANDIDATES_LE: + reranked_results = [(name, 1.0) for name in candidate_table_names] + else: + reranked_results = self._rerank_with_dashscope(user_query, candidate_table_info) + final_table_names = [name for name, _ in reranked_results][:TABLE_RETURN_COUNT] # 去重 final_table_names = list(dict.fromkeys(final_table_names)) @@ -1274,12 +1483,36 @@ def execute_sql(self, state: AgentState) -> AgentState: sql_to_execute = state.get("filtered_sql") or state.get("generated_sql", "") sql_to_execute = sql_to_execute.strip() if sql_to_execute else "" - if not sql_to_execute: - error_msg = "SQL 为空,无法执行" + # 处理未成功生成 SQL 的占位符,避免执行无效语句 + if not sql_to_execute or sql_to_execute == "No SQL query generated": + error_msg = ( + "SQL 为空,无法执行" + if not sql_to_execute + else "SQL 未成功生成,已跳过执行" + ) logger.warning(error_msg) - state["execution_result"] = ExecutionResult(success=False, error=error_msg) + state["execution_result"] = ExecutionResult( + success=False, + error=error_msg, + ) return state + # 预览模式行数限制(仅用于页面展示,导出等场景可通过 state['preview_limit_rows']=0 显式关闭) + try: + from os import getenv as _getenv + default_preview_limit = int(_getenv("SQL_EXEC_PREVIEW_LIMIT", "100")) + except Exception: + default_preview_limit = 100 + preview_limit = state.get("preview_limit_rows", default_preview_limit) + + sql_for_execution = sql_to_execute + if preview_limit and preview_limit > 0 and " limit " not in sql_to_execute.lower(): + base_sql = sql_to_execute.rstrip(";") + sql_for_execution = f"SELECT * FROM ({base_sql}) AS t_preview LIMIT {preview_limit}" + logger.info( + f"🔎 预览模式:对 SQL 应用 LIMIT {preview_limit} 行(不影响导出等显式关闭预览限制的场景)" + ) + logger.info("▶️ 执行 SQL 语句") # 记录使用的SQL类型(用于调试) if state.get("filtered_sql"): @@ -1299,14 +1532,14 @@ def execute_sql(self, state: AgentState) -> AgentState: logger.info(f"使用原生驱动执行 SQL(数据源类型: {self._datasource_type})") config = DatasourceConfigUtil.decrypt_config(self._datasource_config) result_data = DatasourceConnectionUtil.execute_query( - self._datasource_type, config, sql_to_execute + self._datasource_type, config, sql_for_execution ) state["execution_result"] = ExecutionResult(success=True, data=result_data) logger.info(f"✅ SQL 执行成功(原生驱动),返回 {len(result_data)} 条记录") else: # 对于 SQLAlchemy 驱动的数据库,使用 engine 执行 with self._engine.connect() as connection: - result = 
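The preview wrapper in execute_sql reduces to a single transformation; note it assumes a LIMIT/OFFSET dialect (MySQL/PostgreSQL family) and that a bare " limit " substring check is good enough to detect an existing limit:

```python
def apply_preview_limit(sql: str, limit: int) -> str:
    # limit <= 0 (e.g. preview_limit_rows=0 for exports) disables the wrapper
    if not limit or limit <= 0 or " limit " in sql.lower():
        return sql
    return f"SELECT * FROM ({sql.rstrip(';')}) AS t_preview LIMIT {limit}"
```

For example, apply_preview_limit("SELECT a, b FROM t ORDER BY a", 100) yields SELECT * FROM (SELECT a, b FROM t ORDER BY a) AS t_preview LIMIT 100.
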
connection.execute(text(sql_to_execute)) + result = connection.execute(text(sql_for_execution)) result_data = result.fetchall() columns = result.keys() frame = pd.DataFrame(result_data, columns=columns) diff --git a/agent/text2sql/sql/generator.py b/agent/text2sql/sql/generator.py index 53799d41..ab228edf 100644 --- a/agent/text2sql/sql/generator.py +++ b/agent/text2sql/sql/generator.py @@ -22,6 +22,32 @@ logger = logging.getLogger(__name__) +# 用户明确要求「全部数据/导出/不限制条数」时的关键词,命中则生成 SQL 时不强制 LIMIT +_FULL_DATA_OR_EXPORT_KEYWORDS = ( + "不要限制", + "不限制条数", + "不要使用限制", + "无需限制", + "全部数据", + "所有数据", + "所有明细", + "全部明细", + "可以导出", + "需要导出", + "导出", + "无限制", + "不限条数", + "去掉限制", +) + + +def _user_wants_full_data_or_export(user_query: str) -> bool: + """判断用户是否明确要求查看全部数据或导出(不限制条数)。""" + if not (user_query and user_query.strip()): + return False + q = user_query.strip() + return any(kw in q for kw in _FULL_DATA_OR_EXPORT_KEYWORDS) + def sql_generate(state: AgentState) -> AgentState: """ @@ -69,43 +95,29 @@ def sql_generate(state: AgentState) -> AgentState: except Exception as e: logger.warning(f"获取数据源信息失败: {e},使用默认值") - # 表关系补充:在 SQL 生成阶段补充缺失的关联表,并生成外键关系信息 - # 这样可以在 SQL 生成时根据实际需要补充关联表,而不是在检索阶段就补充 - try: - from agent.text2sql.database.db_service import DatabaseService - user_id = state.get("user_id") - - # 创建 DatabaseService 实例用于表关系补充 - db_service = DatabaseService(datasource_id) - - # 获取所有表信息(用于补充关联表,使用缓存避免重复查询) - all_table_info = db_service._fetch_all_table_info(user_id=user_id, use_cache=True) - - # 获取当前选中的表名 - selected_table_names = list(db_info.keys()) - - # 补充关联表 - supplemented_table_names = db_service.supplement_related_tables( - selected_table_names, - all_table_info - ) - - # 无论是否补充新表,都用 all_table_info 中的数据(包含 foreign_keys)重建 db_info - # 这样可以确保已有表的外键关系也能传递到提示词中 - new_db_info = {} - for table_name in supplemented_table_names: - if table_name in all_table_info: - new_db_info[table_name] = all_table_info[table_name] - else: - # 兜底:保持原来的表信息 - if table_name in db_info: + # 表关系补充:多表时才拉取全量表信息并补充关联表;单表时直接用 state 中的 db_info 以节省耗时 + if len(db_info) > 1: + try: + from agent.text2sql.database.db_service import DatabaseService + user_id = state.get("user_id") + db_service = DatabaseService(datasource_id) + all_table_info = db_service._fetch_all_table_info(user_id=user_id, use_cache=True) + selected_table_names = list(db_info.keys()) + supplemented_table_names = db_service.supplement_related_tables( + selected_table_names, + all_table_info + ) + new_db_info = {} + for table_name in supplemented_table_names: + if table_name in all_table_info: + new_db_info[table_name] = all_table_info[table_name] + elif table_name in db_info: new_db_info[table_name] = db_info[table_name] - - db_info = new_db_info - state["db_info"] = db_info - logger.debug(f"表关系补充完成,当前 db_info 包含 {len(db_info)} 张表") - except Exception as e: - logger.warning(f"表关系补充失败: {e},继续使用原始表列表", exc_info=True) + db_info = new_db_info + state["db_info"] = db_info + logger.debug(f"表关系补充完成,当前 db_info 包含 {len(db_info)} 张表") + except Exception as e: + logger.warning(f"表关系补充失败: {e},继续使用原始表列表", exc_info=True) # 格式化 schema 为 M-Schema 格式(包含补充的关联表) schema_str = format_schema_to_m_schema( @@ -122,36 +134,45 @@ def sql_generate(state: AgentState) -> AgentState: # 使用 PromptBuilder 构建提示词 prompt_builder = PromptBuilder() - # RAG 增强检索:检索术语和训练示例 + # RAG 增强检索:检索术语和训练示例(带超时,避免拖慢整体) + terminologies = "" + data_training = "" try: - from agent.text2sql.rag.terminology_retriever import retrieve_terminologies - import asyncio - - # 检索术语(同步调用) - terminologies = retrieve_terminologies( 
- question=state["user_query"], - datasource_id=datasource_id, - oid=1, # 默认组织ID,后续可以从用户信息获取 - top_k=10, - ) - - # 检索训练示例 - from agent.text2sql.rag.training_retriever import retrieve_training_examples - data_training = retrieve_training_examples( - question=state["user_query"], - datasource_id=datasource_id, - oid=1, # 默认组织ID,后续可以从用户信息获取 - top_k=5, - ) - except Exception as e: - logger.warning(f"RAG 检索失败: {e},使用空字符串") - terminologies = "" - data_training = "" + import os + from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError + RAG_TIMEOUT = int(os.getenv("SQL_RAG_TIMEOUT", "4")) + + def _do_rag(): + from agent.text2sql.rag.terminology_retriever import retrieve_terminologies + from agent.text2sql.rag.training_retriever import retrieve_training_examples + t = retrieve_terminologies( + question=state["user_query"], + datasource_id=datasource_id, + oid=1, + top_k=5, + ) + d = retrieve_training_examples( + question=state["user_query"], + datasource_id=datasource_id, + oid=1, + top_k=3, + ) + return t, d + + # 不使用 with:with 退出时会等待仍在运行的任务,超时将形同虚设 + ex = ThreadPoolExecutor(max_workers=1) + try: + terminologies, data_training = ex.submit(_do_rag).result(timeout=RAG_TIMEOUT) + finally: + ex.shutdown(wait=False) + except FuturesTimeoutError: + logger.debug("RAG 检索超时,跳过术语与训练示例") + except Exception as e: + logger.warning(f"RAG 检索失败: {e},使用空字符串") custom_prompt = "" # 自定义提示词(暂时为空) error_msg = "" # 错误消息(暂时为空) # 获取系统提示词和用户提示词 + # 用户明确要求全部数据/导出时不强制 LIMIT,否则默认限制 40 条 + enable_query_limit = not _user_wants_full_data_or_export(state.get("user_query", "")) system_prompt, user_prompt = prompt_builder.build_sql_prompt( db_type=db_type, schema=schema_str, @@ -161,7 +182,7 @@ terminologies=terminologies, data_training=data_training, custom_prompt=custom_prompt, - enable_query_limit=True, # 启用查询限制 + enable_query_limit=enable_query_limit, error_msg=error_msg, current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), change_title=False, # 暂时不生成对话标题 diff --git a/common/llm_util.py b/common/llm_util.py index 39b8677b..18196b5d 100644 --- a/common/llm_util.py +++ b/common/llm_util.py @@ -73,13 +73,28 @@ def _get_openai(): f"[ERROR] Failed to import ChatOpenAI, please check langchain-openai/langsmith/opentelemetry installation: {e}" ) raise - + # 构造额外参数,强制关闭 thinking 模式以适配非流式请求 + # 某些后端通过 default_headers 传递,某些通过 extra_body + # 这里优先使用 default_headers 覆盖,防止默认开启 thinking 导致报错 + extra_headers = { + "enable_thinking": "false", + # 如果后端需要布尔值而不是字符串,可能需要特定处理,但 HTTP 头通常是字符串 + # 对于某些兼容层,可能需要放在 extra_body 中 + } + + # 针对部分后端(如 volcengine, some vLLM setups),参数可能在 body 里 + extra_body = { + "enable_thinking": False + } return ChatOpenAI( model=model_name, temperature=temperature, base_url=model_base_url, api_key=model_api_key or "empty", # Ensure not None timeout=timeout, # 设置超时时间(秒) + # 关键修改:传入额外参数以禁用 thinking + default_headers=extra_headers, + extra_body=extra_body, ) def _get_ollama(): diff --git a/common/local_embedding.py b/common/local_embedding.py index 0f6d0cf1..2b2a2561 100644 --- a/common/local_embedding.py +++ b/common/local_embedding.py @@ -179,7 +179,8 @@ def _get_local_embedding_model(): return None except Exception as e: - logger.error(f"Failed to load local embedding model: {e}", exc_info=True) + # 本地 embedding 模型加载失败时只输出简要错误,避免在控制台刷堆栈 + logger.warning("Failed to load local embedding model: %s", e) return None diff --git a/controllers/export_api.py b/controllers/export_api.py new file mode 100644 index 00000000..9481f516 --- /dev/null
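The RAG budget above only works because nothing waits on the abandoned worker: shutdown(wait=False) returns immediately, whereas the ThreadPoolExecutor context manager would block until the slow retrieval finished and silently defeat the timeout. The pattern generalizes to any optional enrichment step; a reusable sketch:

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeout

def with_timeout(fn, seconds: float, default):
    """Run fn() under a wall-clock budget; fall back to `default` on timeout."""
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        return executor.submit(fn).result(timeout=seconds)
    except FuturesTimeout:
        # The worker thread is not killed; it finishes in the background
        return default
    finally:
        executor.shutdown(wait=False)
```
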
+++ b/controllers/export_api.py @@ -0,0 +1,122 @@ +import logging + +from sanic import Blueprint, Request, response +from sanic_ext import openapi + +from common.exception import MyException +from common.token_decorator import check_token +from common.param_parser import parse_params +from constants.code_enum import SysCodeEnum +from common.res_decorator import async_json_resp +from model.schemas import ExportTableRequest, TablePageRequest, get_schema +from services.export_table_service import run_sql_for_export, run_sql_page, data_to_csv +from services.user_service import decode_jwt_token + + +logger = logging.getLogger(__name__) + +bp = Blueprint("exportApi", url_prefix="/export") + + +@bp.post("/table_csv") +@openapi.summary("导出表格数据为 CSV") +@openapi.description("根据 SQL 与数据源执行查询(应用权限过滤),并返回 CSV 文件。") +@openapi.tag("数据导出") +@openapi.body( + { + "application/json": { + "schema": get_schema(ExportTableRequest), + } + }, + description="导出请求体", + required=True, +) +@check_token +@parse_params +async def export_table_csv(request: Request, body: ExportTableRequest): + """ + 表格数据导出接口(CSV) + - 使用与数据问答相同的数据源与权限体系; + - SQL 可以为无限制条数查询,导出全部匹配数据; + - 返回 text/csv 响应,前端可直接触发下载。 + """ + try: + token = request.headers.get("Authorization") + if token and token.startswith("Bearer "): + token = token.split(" ")[1] + user_dict = await decode_jwt_token(token) + user_id = user_dict.get("id", 1) + + sql = body.sql + datasource_id = body.datasource_id + filename = (body.filename or "export").strip() or "export" + + data, err = run_sql_for_export(sql, datasource_id=datasource_id, user_id=user_id) + if err is not None: + raise MyException(SysCodeEnum.c_9999, msg=err) + + csv_body = data_to_csv(data or []) + headers = { + "Content-Type": "text/csv; charset=utf-8", + "Content-Disposition": f'attachment; filename="{filename}.csv"', + } + return response.text(csv_body, headers=headers) + except MyException: + raise + except Exception as e: + logger.error(f"导出表格数据失败: {e}") + raise MyException(SysCodeEnum.c_9999) + + +@bp.post("/table_page") +@openapi.summary("表格数据分页查询") +@openapi.description("根据 SQL 与数据源执行分页查询(应用权限过滤),返回当前页数据。") +@openapi.tag("数据导出") +@openapi.body( + { + "application/json": { + "schema": get_schema(TablePageRequest), + } + }, + description="分页查询请求体", + required=True, +) +@openapi.response( + 200, + { + "application/json": { + "schema": {"type": "object"}, + } + }, + description="分页查询结果", +) +@check_token +@async_json_resp +@parse_params +async def table_page(request: Request, body: TablePageRequest): + """ + 表格数据分页查询接口(JSON) + - 用于前端 NDataTable 翻页; + - 会应用与数据问答相同的权限体系; + - 不对原始 SQL 强制添加 LIMIT,分页逻辑在子查询外层完成。 + """ + token = request.headers.get("Authorization") + if token and token.startswith("Bearer "): + token = token.split(" ")[1] + + from services.user_service import decode_jwt_token + + user_dict = await decode_jwt_token(token) + user_id = user_dict.get("id", 1) + + data, err = run_sql_page( + body.sql, + datasource_id=body.datasource_id, + user_id=user_id, + page=body.page, + size=body.size, + ) + if err is not None: + raise MyException(SysCodeEnum.c_9999, msg=err) + return data + diff --git a/controllers/llm_chat_api.py b/controllers/llm_chat_api.py index 9a4ae8e3..3fdfea71 100644 --- a/controllers/llm_chat_api.py +++ b/controllers/llm_chat_api.py @@ -16,6 +16,7 @@ DifyGetSuggestedResponse, StopChatRequest, StopChatResponse, + ExportTableRequest, get_schema, ) diff --git a/model/schemas.py b/model/schemas.py index f51b552c..bf203858 100644 --- a/model/schemas.py +++ b/model/schemas.py @@ -484,6 
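A quick way to exercise the two endpoints above from a script; the /sanic prefix comes from the frontend calls near the end of this diff, while host, port, token, and table name are assumptions:

```python
import requests

BASE = "http://localhost:8000/sanic/export"  # host/port assumed
HEADERS = {"Authorization": "Bearer <token>"}
QUERY = {"sql": "SELECT * FROM t_demo", "datasource_id": 1}

# Paged JSON for the table view
page = requests.post(f"{BASE}/table_page", headers=HEADERS,
                     json={**QUERY, "page": 1, "size": 20}).json()

# Full CSV download (the server disables the preview LIMIT for exports)
csv_resp = requests.post(f"{BASE}/table_csv", headers=HEADERS,
                         json={**QUERY, "filename": "demo"})
with open("demo.csv", "wb") as f:
    f.write(csv_resp.content)
```
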
+484,23 @@ class StopChatResponse(BaseResponse): data: Dict[str, str] = Field(description="停止结果") +class ExportTableRequest(BaseModel): + """表格数据导出请求(数据问答结果导出为 CSV)""" + + sql: str = Field(description="要执行的 SQL(将应用权限过滤)") + datasource_id: int = Field(description="数据源 ID") + filename: Optional[str] = Field(None, description="下载文件名(不含扩展名),默认 export.csv") + + +class TablePageRequest(BaseModel): + """表格数据分页请求(基于 SQL 的分页查询)""" + + sql: str = Field(description="要执行的 SQL(将应用权限过滤)") + datasource_id: int = Field(description="数据源 ID") + page: int = Field(1, description="页码(从 1 开始)") + size: int = Field(20, description="每页条数") + + # ==================== 文件服务相关模型 ==================== class ReadFileRequest(BaseModel): """读取文件请求""" diff --git a/services/datasource_service.py b/services/datasource_service.py index d966fc27..da5b0b77 100644 --- a/services/datasource_service.py +++ b/services/datasource_service.py @@ -418,12 +418,25 @@ def _compute_and_save_table_embeddings_batch(session: Session, items: List[Dict[ # 获取 embedding 客户端(支持在线/离线模型切换) embedding_client, model_name = DatasourceService._get_embedding_client() + all_embeddings = [] try: if embedding_client and model_name: # 使用在线模型批量计算 logger.info(f"批量计算 {len(docs)} 个表的 embedding(在线模型: {model_name})...") - response = embedding_client.embeddings.create(model=model_name, input=docs) - data = response.data or [] + # 分批请求:DashScope 批量 embedding 单次最多 25 条,超过会报 InternalError.Algo.InvalidParameter(batch size is invalid, it should not be larger than 25) + MAX_BATCH_SIZE = 25 + for i in range(0, len(docs), MAX_BATCH_SIZE): + batch_docs = docs[i: i + MAX_BATCH_SIZE] + try: + # 对每个小批次发起请求;response.data 为与输入同序的 embedding 对象列表 + response = embedding_client.embeddings.create(model=model_name, input=batch_docs) + batch_embeddings = [item.embedding for item in response.data] + all_embeddings.extend(batch_embeddings) + except Exception as e: + logger.error(f"批量计算 embedding 失败,批次范围 {i}-{i + len(batch_docs)}: {e}") + raise # 若希望跳过失败批次,可改为 continue + + data = all_embeddings or [] if len(data) != len(tables_for_embedding): logger.warning( @@ -461,6 +474,7 @@ def _compute_and_save_table_embeddings_batch(session: Session, items: List[Dict[ except Exception as e: logger.error(f"批量计算表 embedding 失败: {e}", exc_info=True) + @staticmethod def update_datasource(session: Session, ds_id: int, data: Dict[str, Any]) -> Optional[Datasource]: """更新数据源""" diff --git a/services/embedding_service.py b/services/embedding_service.py index f74f4e6d..a67d653a 100644 --- a/services/embedding_service.py +++ b/services/embedding_service.py @@ -85,8 +85,11 @@ async def generate_embedding(text: str) -> Optional[List[float]]: return response.data[0].embedding except Exception as e: - traceback.print_exc() - logger.warning(f"Failed to generate embedding with online model: {e}, falling back to local CPU model") + # 在线 embedding 失败时降级为本地模型,避免在控制台打印完整异常堆栈 + logger.warning( + "Failed to generate embedding with online model: %s, falling back to local CPU model", + e, + ) # 在线模型失败时,回退到本地模型 from common.local_embedding import generate_embedding_local return await generate_embedding_local(text) diff --git a/services/export_table_service.py
b/services/export_table_service.py new file mode 100644 index 00000000..6210494c --- /dev/null +++ b/services/export_table_service.py @@ -0,0 +1,217 @@ +""" +表格数据导出服务 +根据 SQL + 数据源执行查询(含权限过滤),返回可导出为 CSV/Excel 的数据。 +""" + +import io +import logging +from typing import List, Dict, Any, Tuple, Optional + +from sqlalchemy import text + +from agent.text2sql.database.db_service import DatabaseService +from agent.text2sql.permission.filter_injector import permission_filter_injector +from agent.text2sql.analysis.data_render_antv import extract_table_names_sqlglot +from services.datasource_service import DatasourceService +from model.db_connection_pool import get_db_pool +from common.datasource_util import DatasourceConfigUtil, DatasourceConnectionUtil, DB, ConnectType + +logger = logging.getLogger(__name__) + + +def run_sql_for_export( + sql: str, + datasource_id: int, + user_id: int, +) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]: + """ + 执行 SQL(应用权限过滤)并返回结果数据,供导出使用。 + + Args: + sql: 原始 SQL(可为无 LIMIT 的查询) + datasource_id: 数据源 ID + user_id: 当前用户 ID(用于权限过滤) + + Returns: + (data_list, error_message)。成功时 data_list 为 list of dict,失败时 data_list 为 None、error_message 为错误信息。 + """ + if not sql or not sql.strip() or sql.strip() == "No SQL query generated": + return None, "SQL 为空,无法导出" + + sql = sql.strip() + db_type = "mysql" + try: + with get_db_pool().get_session() as session: + ds = DatasourceService.get_datasource_by_id(session, datasource_id) + if not ds: + return None, "数据源不存在" + db_type = ds.type or "mysql" + except Exception as e: + logger.warning(f"获取数据源类型失败: {e}") + table_names = extract_table_names_sqlglot(sql, db_type) + db_info = {t: {} for t in table_names} if table_names else {} + + state = { + "generated_sql": sql, + "filtered_sql": None, + "datasource_id": datasource_id, + "user_id": user_id, + "db_info": db_info, + "used_tables": table_names, + # 导出场景显式关闭预览限制,执行完整 SQL + "preview_limit_rows": 0, + } + try: + permission_filter_injector(state) + except Exception as e: + logger.warning(f"权限过滤失败: {e}", exc_info=True) + db_service = DatabaseService(datasource_id) + db_service.execute_sql(state) + result = state.get("execution_result") + if not result: + return None, "执行结果为空" + if not result.success: + return None, result.error or "执行失败" + data = result.data + if not data: + return [], None + if isinstance(data, list): + return data, None + return list(data), None + + +def run_sql_page( + sql: str, + datasource_id: int, + user_id: int, + page: int, + size: int, +) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """ + 带分页的 SQL 查询,用于前端表格翻页。 + - 会应用权限过滤; + - 不对原始 SQL 追加 LIMIT,而是通过子查询包装实现分页; + - 返回 {total, page, size, rows}。 + """ + if not sql or not sql.strip() or sql.strip() == "No SQL query generated": + return None, "SQL 为空,无法查询" + + if page <= 0: + page = 1 + if size <= 0: + size = 20 + + sql = sql.strip() + db_type = "mysql" + try: + with get_db_pool().get_session() as session: + ds = DatasourceService.get_datasource_by_id(session, datasource_id) + if not ds: + return None, "数据源不存在" + db_type = ds.type or "mysql" + except Exception as e: + logger.warning(f"获取数据源类型失败: {e}") + + table_names = extract_table_names_sqlglot(sql, db_type) + db_info = {t: {} for t in table_names} if table_names else {} + + # 通过权限注入节点获取 filtered_sql(关闭预览限制,保证分页在完整 SQL 上进行) + state: Dict[str, Any] = { + "generated_sql": sql, + "filtered_sql": None, + "datasource_id": datasource_id, + "user_id": user_id, + "db_info": db_info, + "used_tables": table_names, + "preview_limit_rows": 0, + } + try: + 
permission_filter_injector(state) + except Exception as e: + logger.warning(f"权限过滤失败: {e}", exc_info=True) + + filtered_sql = state.get("filtered_sql") or state.get("generated_sql", "") + if not filtered_sql: + return None, "权限过滤后 SQL 为空" + + base_sql = filtered_sql.strip().rstrip(";") + offset = (page - 1) * size + + count_sql = f"SELECT COUNT(*) AS cnt FROM ({base_sql}) AS t_count" + page_sql = f"SELECT * FROM ({base_sql}) AS t_page LIMIT {size} OFFSET {offset}" + + # 使用 DatabaseService 中的数据源信息来判断驱动类型 + db_service = DatabaseService(datasource_id) + rows: List[Dict[str, Any]] = [] + total: int = 0 + + try: + use_native_driver = False + if db_service._datasource_type and datasource_id: + db_enum = DB.get_db(db_service._datasource_type, default_if_none=True) + use_native_driver = db_enum.connect_type == ConnectType.py_driver + + if use_native_driver and db_service._datasource_config: + # 原生驱动:使用 DatasourceConnectionUtil 执行 + config = DatasourceConfigUtil.decrypt_config(db_service._datasource_config) + count_result = DatasourceConnectionUtil.execute_query( + db_service._datasource_type, config, count_sql + ) + if count_result and isinstance(count_result, list): + first_row = count_result[0] + total = int(first_row.get("cnt") or 0) + page_result = DatasourceConnectionUtil.execute_query( + db_service._datasource_type, config, page_sql + ) + rows = page_result or [] + else: + # SQLAlchemy 驱动 + if not db_service._engine: + return None, "数据源未正确初始化(缺少 SQLAlchemy engine)" + with db_service._engine.connect() as connection: + count_result = connection.execute(text(count_sql)).fetchone() + if count_result is not None: + # count(*) 可能通过索引 0 或键 'cnt' 访问 + total = int(getattr(count_result, "cnt", count_result[0])) + + result = connection.execute(text(page_sql)) + result_rows = result.fetchall() + columns = list(result.keys()) + for row in result_rows: + row_dict: Dict[str, Any] = {} + for i, col in enumerate(columns): + row_dict[col] = row[i] + rows.append(row_dict) + except Exception as e: + logger.error(f"分页查询失败: {e}", exc_info=True) + return None, f"分页查询失败: {e}" + + return { + "total": total, + "page": page, + "size": size, + "rows": rows, + }, None + + +def data_to_csv(data: List[Dict[str, Any]], add_bom: bool = True) -> str: + """将 list of dict 转为 CSV 字符串(表头为所有出现过的 key,BOM 便于 Excel 识别 UTF-8)。""" + import csv + if not data: + return "" + all_keys = [] + seen = set() + for row in data: + for k in row: + if k not in seen: + seen.add(k) + all_keys.append(k) + out = io.StringIO() + writer = csv.DictWriter(out, fieldnames=all_keys, extrasaction="ignore") + writer.writeheader() + for row in data: + writer.writerow(row) + body = out.getvalue() + if add_bom: + body = "\ufeff" + body + return body diff --git a/web/src/api/index.ts b/web/src/api/index.ts index 0c3d6ac9..07de1501 100644 --- a/web/src/api/index.ts +++ b/web/src/api/index.ts @@ -1,6 +1,7 @@ // import { mockEventStreamText } from '@/data' // import { currentHost } from '@/utils/location' // import request from '@/utils/request' +import { useUserStore } from '@/store/business/userStore' /** * Event Stream 调用大模型接口 Ollama3 (Fetch 调用) @@ -199,6 +200,77 @@ export async function fead_back(chat_id, rating) { return fetch(req) } +/** + * 导出表格数据为 CSV + * @param sql + * @param datasource_id + * @param filename 不含扩展名,默认为 export + */ +export async function export_table_csv(sql: string, datasource_id: number, filename?: string) { + const userStore = useUserStore() + const token = userStore.getUserToken() + const url = new 
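run_sql_page's pagination reduces to wrapping the permission-filtered SQL twice, once for the total count and once for the page window. Isolated below (LIMIT/OFFSET syntax, so MySQL/PostgreSQL-family sources):

```python
def build_page_queries(base_sql: str, page: int, size: int) -> tuple[str, str]:
    base = base_sql.strip().rstrip(";")
    page, size = max(page, 1), max(size, 1)
    offset = (page - 1) * size
    count_sql = f"SELECT COUNT(*) AS cnt FROM ({base}) AS t_count"
    page_sql = f"SELECT * FROM ({base}) AS t_page LIMIT {size} OFFSET {offset}"
    return count_sql, page_sql
```

Note that without an ORDER BY in the inner query, page boundaries are not guaranteed to be stable across requests; most databases may return rows in any order for the unordered subquery.
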
URL(`${location.origin}/sanic/export/table_csv`) + const req = new Request(url, { + mode: 'cors', + method: 'post', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${token}`, + }, + body: JSON.stringify({ + sql, + datasource_id, + filename, + }), + }) + + const res = await fetch(req) + if (!res.ok) { + throw new Error(`导出失败,状态码:${res.status}`) + } + + const blob = await res.blob() + const objectUrl = URL.createObjectURL(blob) + const a = document.createElement('a') + a.href = objectUrl + a.download = `${(filename || 'export').trim() || 'export'}.csv` + document.body.appendChild(a) + a.click() + document.body.removeChild(a) + URL.revokeObjectURL(objectUrl) +} + +/** + * 表格分页查询(基于 SQL) + * @param sql + * @param datasource_id + * @param page + * @param size + */ +export async function table_page(sql: string, datasource_id: number, page = 1, size = 20) { + const userStore = useUserStore() + const token = userStore.getUserToken() + const url = new URL(`${location.origin}/sanic/export/table_page`) + + const req = new Request(url, { + mode: 'cors', + method: 'post', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${token}`, + }, + body: JSON.stringify({ + sql, + datasource_id, + page, + size, + }), + }) + + return fetch(req) +} + /** * 问题建议 * @param chat_id diff --git a/web/src/components/MarkdownPreview/index.vue b/web/src/components/MarkdownPreview/index.vue index d9ed7439..0eb6d0c4 100644 --- a/web/src/components/MarkdownPreview/index.vue +++ b/web/src/components/MarkdownPreview/index.vue @@ -55,6 +55,7 @@ interface Props { recommended_questions?: string[] } | null recordId?: number // 记录ID,用于查询SQL语句 + datasourceId?: number // 数据源ID,用于导出 } // 解构 props @@ -722,6 +723,7 @@ const currentQaOption = computed(() => { :chart-data="props.chartData" :record-id="props.recordId" :qa-type="props.qaType" + :datasource-id="props.datasourceId" @chart-rendered="() => onChartCompletedReader()" @table-rendered="() => onTableCompletedReader()" /> diff --git a/web/src/components/MarkdownPreview/markdown-antv.vue b/web/src/components/MarkdownPreview/markdown-antv.vue index 6e63ef96..5668d624 100644 --- a/web/src/components/MarkdownPreview/markdown-antv.vue +++ b/web/src/components/MarkdownPreview/markdown-antv.vue @@ -1,5 +1,5 @@