11import logging
2+ import re
23import traceback
3-
4+ from typing import List
5+ import jieba
46import pandas as pd
57from sqlalchemy .inspection import inspect
68from sqlalchemy .sql .expression import text
7-
9+ from rank_bm25 import BM25Okapi
810from agent .text2sql .state .agent_state import AgentState , ExecutionResult
911from model .db_connection_pool import get_db_pool
1012
@@ -19,7 +21,37 @@ class DatabaseService:
1921
2022 def __init__ (self ):
2123 pass
22- # self.engine = db_pool.get_engine()
24+
25+ @staticmethod
26+ def _build_document (table_name : str , table_info : dict ) -> str :
27+ """
28+ 将表结构拼接成一段文本,用于匹配
29+ """
30+ parts = []
31+
32+ # 添加表名和注释
33+ table_comment = table_info .get ("table_comment" , "" )
34+ parts .append (f"{ table_name } { table_comment } " )
35+
36+ # 添加列信息
37+ for col_name , col_info in table_info .get ("columns" , {}).items ():
38+ col_comment = col_info .get ("comment" , "" )
39+ col_cn_name = col_info .get ("cn_name" , "" )
40+ parts .append (f"{ col_name } { col_cn_name } { col_comment } " )
41+
42+ return " " .join (parts )
43+
44+ @staticmethod
45+ def _tokenize_text (text : str ) -> List [str ]:
46+ """
47+ :param text
48+ 对文本进行分词
49+ """
50+ # 过滤掉标点符号和特殊字符,只保留中文、英文和数字
51+ filtered_text = re .sub (r"[^\u4e00-\u9fa5a-zA-Z0-9]" , " " , text )
52+ tokens = list (jieba .cut (filtered_text ))
53+ # 过滤空字符串
54+ return [token .strip () for token in tokens if token .strip ()]
2355
2456 @staticmethod
2557 def get_table_schema (state : AgentState ):
@@ -28,6 +60,7 @@ def get_table_schema(state: AgentState):
2860 :param state:
2961 :return:
3062 获取数据中所有表schema信息
63+ 使用BM25算法过滤出相关表信息
3164 :return: 表schema信息
3265 """
3366 try :
@@ -46,7 +79,60 @@ def get_table_schema(state: AgentState):
4679 ]
4780
4881 table_info [table_name ] = {"columns" : columns , "foreign_keys" : foreign_keys }
49- state ["db_info" ] = table_info
82+
83+ # 如果有用户查询,则根据查询过滤表信息
84+ user_query = state .get ("user_query" , "" )
85+ if user_query and table_info :
86+ # 构建表文档
87+ corpus = []
88+ table_names = []
89+ table_comments = []
90+ for table_name , info in table_info .items ():
91+ doc = DatabaseService ._build_document (table_name , info )
92+ corpus .append (doc )
93+ table_names .append (table_name )
94+ table_comments .append (info .get ("table_comment" , "" ))
95+
96+ # 对文档进行分词
97+ tokenized_corpus = [DatabaseService ._tokenize_text (doc ) for doc in corpus ]
98+
99+ # 使用BM25算法训练模型
100+ bm25 = BM25Okapi (tokenized_corpus )
101+
102+ # 对查询进行分词
103+ query_tokens = DatabaseService ._tokenize_text (user_query )
104+
105+ # 计算文档得分
106+ doc_scores = bm25 .get_scores (query_tokens )
107+
108+ # 优化算法:提高表注释匹配的优先级
109+ # 如果查询内容直接匹配表注释,则给予更高的权重
110+ for i , (table_comment , score ) in enumerate (zip (table_comments , doc_scores )):
111+ if score > 0 and table_comment :
112+ # 检查查询是否直接包含表注释中的关键词
113+ comment_tokens = DatabaseService ._tokenize_text (table_comment )
114+ query_text = "" .join (query_tokens )
115+ comment_text = "" .join (comment_tokens )
116+
117+ # 如果查询中包含表注释的关键内容,增加权重
118+ if comment_text and (comment_text in query_text or query_text in comment_text ):
119+ doc_scores [i ] *= 2 # 给予两倍权重
120+
121+ # 按得分排序,取前3个最相关的表
122+ top_indices = sorted (range (len (doc_scores )), key = lambda i : doc_scores [i ], reverse = True )[:3 ]
123+
124+ # 只保留最相关的表
125+ filtered_table_info = {
126+ table_names [idx ]: table_info [table_names [idx ]]
127+ for idx in top_indices
128+ if doc_scores [idx ] > 0 # 只保留得分大于0的表
129+ }
130+
131+ state ["db_info" ] = filtered_table_info
132+ else :
133+ state ["db_info" ] = table_info
134+
135+ logger .info (f"获取数据库表信息成功: { state .get ('db_info' )} " )
50136 except Exception as e :
51137 logger .error (f"获取数据库表信息失败: { e } " )
52138 state ["db_info" ] = {}
0 commit comments