Skip to content

Commit fbff4ed

Browse files
committed
多次请求后选出最正确的sql
1 parent 2e799b5 commit fbff4ed

File tree

9 files changed

+209
-17
lines changed

9 files changed

+209
-17
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
.idea
1+
.idea
2+
/cmd/text2sql/go.work.sum

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func main() {
3838
DbLink: "root:password@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local",
3939
Try: 5, // 失败时的重试次数
4040
ShouldRun: true, // 是否执行生成的SQL
41+
Times: 3, // 同时生成3个SQL,选择最合适的一个
4142
}
4243

4344
// 将中文问题转换为SQL并执行
@@ -58,6 +59,10 @@ func main() {
5859
- `ShouldRun`: 是否执行生成的SQL语句
5960
- 设置为`true`时会执行SQL并返回结果
6061
- 设置为`false`时只返回生成的SQL语句
62+
- `Times`: 并发生成的候选SQL数量(可选,默认1)
63+
- 取值范围:1-10(为0时按1处理,大于10时按10处理)
64+
- 数值越大,生成的SQL候选数越多,选择最合适的SQL的准确率越高
65+
- 注意:数值越大消耗的token也越多
6166

6267
## 命令行工具
6368
[命令行工具](./cmd/text2sql/README.md)

cmd/text2sql/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@ text2sql还提供了命令行工具,可以直接在终端中使用:
77
go install github.com/wangle201210/text2sql/cmd/text2sql@latest
88

99
# 使用示例
10-
text2sql -l "root:password@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local" -q "王五的openid" -t 5 -r
10+
text2sql -l "root:password@tcp(127.0.0.1:3306)/database" -q "王五的openid" -t 5 -r -n 3
1111
```
1212

1313
参数说明:
1414
- `-l`: MySQL数据库连接信息(必填)
1515
- `-q`: 查询语句描述(必填)
1616
- `-t`: 失败后重试次数(可选,默认5次)
1717
- `-r`: 是否执行生成的SQL(可选,默认false)
18+
- `-n`: 并发生成的候选SQL数量(可选,默认1)
19+
- 取值范围:1-10(为0时按1处理,大于10时按10处理)
20+
- 数值越大,生成的SQL候选数越多,选择最合适的SQL的准确率越高
21+
- 注意:数值越大消耗的token也越多

cmd/text2sql/go.work

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
go 1.23.1
2+
3+
use .
4+
5+
replace github.com/wangle201210/text2sql => ../../

cmd/text2sql/main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ func main() {
1616
question string
1717
try int
1818
run bool
19+
times int
1920
)
2021

2122
flag.StringVar(&link, "l", "", "mysql链接信息")
2223
flag.StringVar(&question, "q", "", "查询语句描述")
2324
flag.IntVar(&try, "t", 5, "失败后重试次数")
2425
flag.BoolVar(&run, "r", false, "获取到sql后是否执行")
26+
flag.IntVar(&times, "n", 1, "同时生成SQL的次数(1-10之间),数值越大准确率越高,但消耗的token也越多")
2527
flag.Parse()
2628
if link == "" || question == "" {
2729
fmt.Println("请输入正确参数")
@@ -32,6 +34,7 @@ func main() {
3234
DbLink: link,
3335
Try: try,
3436
ShouldRun: run,
37+
Times: times,
3538
}
3639
sql, result, err := ts.Do(question)
3740
if err != nil {

eino/eino.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,16 @@ import (
66

77
func GetSQL(ddl, question string) (sql string) {
88
ctx := context.Background()
9-
messages := createMessagesFromTemplate(ddl, question)
9+
messages := ddl2sqlMessages(ddl, question)
10+
cm := createOpenAIChatModel(ctx)
11+
result := generate(ctx, cm, messages)
12+
sql = result.Content
13+
return
14+
}
15+
16+
func ChoiceSQL(sqls, ddl, question string) (sql string) {
17+
ctx := context.Background()
18+
messages := choiceSqlMessages(sqls, ddl, question)
1019
cm := createOpenAIChatModel(ctx)
1120
result := generate(ctx, cm, messages)
1221
sql = result.Content

eino/message.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ import (
88
"github.com/cloudwego/eino/schema"
99
)
1010

11+
const (
12+
limit = 1000
13+
role = "You are a MySQL expert."
14+
)
15+
1116
func createTemplate() prompt.ChatTemplate {
1217
// 创建模板,使用 FString 格式
1318
return prompt.FromMessages(schema.FString,
@@ -33,18 +38,37 @@ func createTemplate() prompt.ChatTemplate {
3338
)
3439
}
3540

36-
func createMessagesFromTemplate(ddl, question string) []*schema.Message {
41+
func ddl2sqlMessages(ddl, question string) []*schema.Message {
3742
template := createTemplate()
3843
// 使用模板生成消息
3944
messages, err := template.Format(context.Background(), map[string]any{
40-
"role": "You are a MySQL expert.",
45+
"role": role,
4146
"question": question,
4247
"ddl": ddl,
43-
"limit": 10,
48+
"limit": limit,
4449
"chat_history": []*schema.Message{},
4550
})
4651
if err != nil {
4752
log.Fatalf("format template failed: %v\n", err)
4853
}
4954
return messages
5055
}
56+
57+
func choiceSqlMessages(sqls, ddl, question string) []*schema.Message {
58+
template := createTemplate()
59+
// 使用模板生成消息
60+
messages, err := template.Format(context.Background(), map[string]any{
61+
"role": role,
62+
"question": "Select the most suitable SQL output from the above SQL statements",
63+
"ddl": ddl,
64+
"limit": limit,
65+
"chat_history": []*schema.Message{
66+
schema.UserMessage(question),
67+
schema.AssistantMessage(sqls, nil),
68+
},
69+
})
70+
if err != nil {
71+
log.Fatalf("format template failed: %v\n", err)
72+
}
73+
return messages
74+
}

text2sql.go

Lines changed: 151 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,179 @@ package text2sql
22

33
import (
44
"log"
5+
"strings"
6+
"sync"
7+
"unicode"
58

69
"github.com/wangle201210/text2sql/eino"
710
"github.com/wangle201210/text2sql/mysql"
811
)
912

1013
type Text2sql struct {
11-
DbLink string
12-
Try int // 如果失败,尝试多少次
13-
ShouldRun bool // 生成后是否需要run
14+
DbLink string
15+
Try int // 如果失败,尝试多少次
16+
ShouldRun bool // 生成后是否需要run
17+
Times int // 选填,1-10之间,同时生成*个sql后选出最合适的个,注意这个数越大消耗的token越多
18+
db *mysql.Db
19+
ddl string
20+
question string
21+
sqls []string
22+
removeSqls []string
1423
}
1524

1625
func (x *Text2sql) Do(question string) (sql string, runResult map[string]interface{}, err error) {
26+
if x.Times == 0 {
27+
x.Times = 1 // 至少要运行一次
28+
}
29+
if x.Times > 10 {
30+
x.Times = 10 // 最多10次
31+
}
1732
db := &mysql.Db{
1833
DataSourceName: x.DbLink,
1934
}
2035
if err = db.Init(); err != nil {
2136
return
2237
}
23-
ddl := db.GetDdl()
24-
sql = eino.GetSQL(ddl, question)
25-
err = db.CheckSQL(sql)
38+
x.db = db
39+
x.ddl = db.GetDdl()
40+
x.question = question
41+
if err = x.getAllSql(); err != nil {
42+
return
43+
}
44+
sql, err = x.choice()
45+
if err != nil {
46+
return
47+
}
48+
if x.ShouldRun {
49+
runResult = db.DoSQL(sql)
50+
}
51+
return
52+
}
53+
54+
func (x *Text2sql) removeWhitespace() {
55+
x.removeSqls = []string{}
56+
for _, sql := range x.sqls {
57+
x.removeSqls = append(x.removeSqls, removeWhitespace(sql))
58+
}
59+
}
60+
61+
func (x *Text2sql) uniqSql() {
62+
list := map[string]int{}
63+
for i, sql := range x.removeSqls {
64+
list[sql] = i // 存下最后一次出现的角标
65+
}
66+
var sqls []string
67+
for _, i := range list {
68+
sqls = append(sqls, x.sqls[i])
69+
}
70+
x.sqls = sqls
71+
}
72+
73+
func (x *Text2sql) findMostCommonSql() string {
74+
// 创建一个 map 来记录每个去空白字符后的字符串出现次数
75+
countMap := make(map[string]int)
76+
for _, s := range x.removeSqls {
77+
countMap[s]++
78+
}
79+
80+
// 找到出现次数最多的字符串
81+
var mostCommonString string
82+
maxCount := 0
83+
for cleanedString, count := range countMap {
84+
if count > maxCount {
85+
mostCommonString = cleanedString
86+
maxCount = count
87+
}
88+
}
89+
// 数量 > 1/2就直接使用
90+
if float64(maxCount) <= float64(len(x.sqls))/2 {
91+
return ""
92+
}
93+
// 返回该字符串在原数组中的第一个索引
94+
for i, s := range x.removeSqls {
95+
if s == mostCommonString {
96+
return x.sqls[i]
97+
}
98+
}
99+
return ""
100+
}
101+
102+
// 从多个sql中选择一个最适合的
103+
func (x *Text2sql) choice() (sql string, err error) {
104+
// 移除空白后进行对比
105+
x.removeWhitespace()
106+
// 如果有一个sql数量超过一半就直接使用该sql
107+
if commonSql := x.findMostCommonSql(); len(commonSql) > 0 {
108+
return commonSql, nil
109+
}
110+
// 去重
111+
x.uniqSql()
112+
// 如果只有一个了就直接返回
113+
if len(x.sqls) == 1 {
114+
sql = x.sqls[0]
115+
return
116+
}
117+
// 从候选项里面选一个
118+
var sqls string
119+
for _, s := range x.sqls {
120+
sqls += s
121+
sqls += "\n"
122+
}
123+
sql = eino.ChoiceSQL(sqls, x.ddl, x.question)
124+
err = x.db.CheckSQL(sql)
26125
try := x.Try
27126
for err != nil && try > 0 {
28-
log.Printf("try: %d, err: %v\n", x.Try-try, err)
29127
try--
30-
sql = eino.GetSQL(ddl, question)
31-
err = db.CheckSQL(sql)
128+
log.Printf("try: %d, err: %v\n", x.Try-try, err)
129+
sql = eino.ChoiceSQL(sqls, x.ddl, x.question)
130+
err = x.db.CheckSQL(sql)
32131
}
33132
if err != nil {
34133
return
35134
}
36-
if x.ShouldRun {
37-
runResult = db.DoSQL(sql)
135+
return
136+
}
137+
138+
func (x *Text2sql) getAllSql() (err error) {
139+
wg := &sync.WaitGroup{}
140+
141+
for i := 0; i < x.Times; i++ {
142+
wg.Add(1)
143+
// 循环生成多次,取最相关的一次
144+
go func() {
145+
defer wg.Done()
146+
if onceSql, err := x.once(); err == nil {
147+
x.sqls = append(x.sqls, onceSql)
148+
}
149+
}()
150+
}
151+
wg.Wait()
152+
return
153+
}
154+
155+
func (x *Text2sql) once() (sql string, err error) {
156+
sql = eino.GetSQL(x.ddl, x.question)
157+
err = x.db.CheckSQL(sql)
158+
try := x.Try
159+
for err != nil && try > 0 {
160+
try--
161+
log.Printf("try: %d, err: %v\n", x.Try-try, err)
162+
sql = eino.GetSQL(x.ddl, x.question)
163+
err = x.db.CheckSQL(sql)
164+
}
165+
if err != nil {
166+
return
38167
}
39168
return
40169
}
170+
171+
// removeWhitespace returns input with every backtick and every Unicode
// whitespace character (as reported by unicode.IsSpace) removed. The result is
// used as a normalized key for comparing SQL statements that differ only in
// formatting or identifier quoting.
func removeWhitespace(input string) string {
	// strings.Map drops a rune when the mapping returns a negative value,
	// which strips both character classes in a single pass.
	return strings.Map(func(r rune) rune {
		if r == '`' || unicode.IsSpace(r) {
			return -1
		}
		return r
	}, input)
}

text2sql_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ func TestText2sql(t *testing.T) {
77
DbLink: "root:@tcp(127.0.0.1:3306)/note?charset=utf8mb4&parseTime=True&loc=Local",
88
Try: 5,
99
ShouldRun: true,
10+
Times: 2,
1011
}
1112
sql, res, err := data.Do("王五在2025年1月上旬的餐饮食品类别总额")
1213
if err != nil {

0 commit comments

Comments
 (0)