Skip to content

Commit b60ba44

Browse files
committed
1. 优化回答格式
2. 使用view提高准确度 3. 优化提示词
1 parent 5d5ad20 commit b60ba44

File tree

5 files changed

+121
-6
lines changed

5 files changed

+121
-6
lines changed

eino/eino.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package eino
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
)
78

@@ -40,3 +41,25 @@ func ChoiceSQL(sqls, ddl, question string) (sql string, err error) {
4041
}
4142
return result.Content, nil
4243
}
44+
45+
func PrettyRes(sql, question string, runResult []map[string]interface{}) (res string, err error) {
46+
ctx := context.Background()
47+
marshal, err := json.Marshal(runResult)
48+
if err != nil {
49+
return "", err
50+
}
51+
messages, err := prettyMessages(question, sql, string(marshal))
52+
if err != nil {
53+
return "", err
54+
}
55+
56+
cm, err := createOpenAIChatModel(ctx)
57+
if err != nil {
58+
return "", err
59+
}
60+
result, err := generate(ctx, cm, messages)
61+
if err != nil {
62+
return "", fmt.Errorf("优化回答失败: %w", err)
63+
}
64+
return result.Content, nil
65+
}

eino/message.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ func createTemplate() prompt.ChatTemplate {
3131
"Pay attention to use CURDATE() function to get the current date, if the question involves \"today\"."+
3232
"Can only perform queries and does not accept any modification or deletion functions."+
3333
"Use the following table schema info to create your SQL query:\n{ddl}\n"+
34+
"I have simplified the syntax of view to be similar to table syntax."+
35+
"If the view can meet the needs, use the view first, otherwise use the table"+
36+
"Think about it step by step:\n1. Determine the tables to be joined...\n2. Identify the filter conditions...\n3. Choose the aggregation method..."+
3437
"The returned content can only contain SQL statements, without explanations or other information, and should not be labeled with SQL tags.",
3538
),
3639
schema.MessagesPlaceholder("chat_history", true),
@@ -84,3 +87,36 @@ func choiceSqlMessages(sqls, ddl, question string) ([]*schema.Message, error) {
8487
}
8588
return messages, nil
8689
}
90+
91+
// createAnswerTemplate 创建并返回一个配置好的聊天模板
92+
func createAnswerTemplate() prompt.ChatTemplate {
93+
// 创建模板,使用 FString 格式
94+
return prompt.FromMessages(schema.FString,
95+
// 系统消息模板
96+
schema.SystemMessage("{role}"+
97+
"Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question."+
98+
"The SQL statement obtained is: {sql}。 "+
99+
"The result obtained after executing SQL is: {result}。"+
100+
"Answer user questions based on the executed SQL and the obtained answers.",
101+
),
102+
schema.MessagesPlaceholder("chat_history", true),
103+
// 用户消息模板
104+
schema.UserMessage("Question: {question}"),
105+
)
106+
}
107+
108+
// prettyMessages 优化回答
109+
func prettyMessages(question, sql, answer string) ([]*schema.Message, error) {
110+
template := createAnswerTemplate()
111+
data := map[string]any{
112+
"role": role,
113+
"question": question,
114+
"sql": sql,
115+
"result": answer,
116+
}
117+
messages, err := formatMessages(template, data)
118+
if err != nil {
119+
return nil, err
120+
}
121+
return messages, nil
122+
}

mysql/mysql.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ type Db struct {
1919
DataSourceName string // 数据库连接字信息
2020
db *gorm.DB // gorm数据库实例
2121
ddl []string // 数据库表结构定义列表
22+
23+
OnlyView bool // 是否只使用 视图
24+
OnlyTable bool // 是否只使用 表
2225
}
2326

2427
// Init 初始化数据库连接并获取表结构
@@ -62,12 +65,37 @@ func (x *Db) getDdl() ([]string, error) {
6265
if err := x.db.Raw(fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName)).Scan(&result).Error; err != nil {
6366
return nil, fmt.Errorf("获取表%s的结构失败: %w", tableName, err)
6467
}
65-
66-
createTable, ok := result["Create Table"].(string)
67-
if !ok {
68-
return nil, fmt.Errorf("表%s的结构格式异常", tableName)
68+
if d, e := result["Create Table"]; e {
69+
if x.OnlyView {
70+
continue
71+
}
72+
createTable := d.(string)
73+
ddls = append(ddls, createTable)
74+
continue
75+
}
76+
if v, e := result["View"]; e {
77+
if x.OnlyTable {
78+
continue
79+
}
80+
var columns []ColumnMeta
81+
// 其实table也可以这样使用
82+
if err := x.db.Raw(fmt.Sprintf("SHOW FULL COLUMNS FROM %s", v)).Scan(&columns).Error; err != nil {
83+
return nil, fmt.Errorf("获取表%s的结构失败: %w", tableName, err)
84+
}
85+
createTable := fmt.Sprintf("CREATE VIEW `%s` (", v)
86+
for _, column := range columns {
87+
createTable += fmt.Sprintf("\n`%s` %s", column.Field, column.Type)
88+
if column.Comment != "" {
89+
createTable += fmt.Sprintf(" COMMENT '%s'", column.Comment)
90+
}
91+
createTable += ","
92+
}
93+
createTable = createTable[:len(createTable)-1]
94+
createTable += ")"
95+
ddls = append(ddls, createTable)
96+
continue
6997
}
70-
ddls = append(ddls, createTable)
98+
return nil, fmt.Errorf("获取表%s的结构失败: %w", tableName, err)
7199
}
72100
return ddls, nil
73101
}

mysql/view.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package mysql
2+
3+
// ColumnMeta 定义列的元数据
4+
type ColumnMeta struct {
5+
Field string // 列名
6+
Type string // 数据类型
7+
Key string
8+
Comment string // 备注,很重要
9+
}

text2sql.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,25 @@ type Text2sql struct {
2424
question string // 用户输入的问题
2525
sqls []string // 生成的SQL列表
2626
removeSqls []string // 去除空白后的SQL列表
27+
28+
OnlyView bool // 是否只使用 视图
29+
OnlyTable bool // 是否只使用 表
30+
}
31+
32+
func (x *Text2sql) Pretty(question string) (sql string, runResult string, err error) {
33+
var res []map[string]interface{}
34+
sql, res, err = x.Do(question)
35+
if len(res) == 0 {
36+
runResult = ""
37+
return
38+
}
39+
// 优化回答
40+
runResult, err = eino.PrettyRes(sql, question, res)
41+
if err != nil {
42+
err = fmt.Errorf("优化回答失败: %w", err)
43+
return
44+
}
45+
return
2746
}
2847

2948
// Do 执行文本到SQL的转换过程
@@ -48,7 +67,7 @@ func (x *Text2sql) Do(question string) (sql string, runResult []map[string]inter
4867
}
4968

5069
// 初始化数据库连接
51-
db := &mysql.Db{DataSourceName: x.DbLink}
70+
db := &mysql.Db{DataSourceName: x.DbLink, OnlyView: x.OnlyView, OnlyTable: x.OnlyTable}
5271
if err = db.Init(); err != nil {
5372
err = fmt.Errorf("初始化数据库失败: %w", err)
5473
return

0 commit comments

Comments
 (0)