Skip to content

Commit 722d48a

Browse files
authored
Merge pull request #1 from wangle201210/feat/loop
Feat/loop
2 parents 2e799b5 + b8b22bb commit 722d48a

File tree

12 files changed

+460
-74
lines changed

12 files changed

+460
-74
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
.idea
1+
.idea
2+
/cmd/text2sql/go.work.sum

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func main() {
3838
DbLink: "root:password@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local",
3939
Try: 5, // 失败时的重试次数
4040
ShouldRun: true, // 是否执行生成的SQL
41+
Times: 3, // 同时生成3个SQL,选择最合适的一个
4142
}
4243

4344
// 将中文问题转换为SQL并执行
@@ -58,6 +59,10 @@ func main() {
5859
- `ShouldRun`: 是否执行生成的SQL语句
5960
- 设置为`true`时会执行SQL并返回结果
6061
- 设置为`false`时只返回生成的SQL语句
62+
- `Times`: 同时生成SQL的次数(可选,默认1)
63+
- 取值范围:1-10
64+
- 数值越大,生成的SQL候选数越多,选择最合适的SQL的准确率越高
65+
- 注意:数值越大消耗的token也越多
6166

6267
## 命令行工具
6368
[命令行工具](./cmd/text2sql/README.md)

cmd/text2sql/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@ text2sql还提供了命令行工具,可以直接在终端中使用:
77
go install github.com/wangle201210/text2sql/cmd/text2sql@latest
88

99
# 使用示例
10-
text2sql -l "root:password@tcp(127.0.0.1:3306)/database?charset=utf8mb4&parseTime=True&loc=Local" -q "王五的openid" -t 5 -r
10+
text2sql -l "root:password@tcp(127.0.0.1:3306)/database" -q "王五的openid" -t 5 -r -n 3
1111
```
1212

1313
参数说明:
1414
- `-l`: MySQL数据库连接信息(必填)
1515
- `-q`: 查询语句描述(必填)
1616
- `-t`: 失败后重试次数(可选,默认5次)
1717
- `-r`: 是否执行生成的SQL(可选,默认false)
18+
- `-n`: 同时生成SQL的次数(可选,默认1)
19+
- 取值范围:1-10
20+
- 数值越大,生成的SQL候选数越多,选择最合适的SQL的准确率越高
21+
- 注意:数值越大消耗的token也越多

cmd/text2sql/go.work

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
go 1.23.1
2+
3+
use .
4+
5+
replace github.com/wangle201210/text2sql => ../../

cmd/text2sql/main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ func main() {
1616
question string
1717
try int
1818
run bool
19+
times int
1920
)
2021

2122
flag.StringVar(&link, "l", "", "mysql链接信息")
2223
flag.StringVar(&question, "q", "", "查询语句描述")
2324
flag.IntVar(&try, "t", 5, "失败后重试次数")
2425
flag.BoolVar(&run, "r", false, "获取到sql后是否执行")
26+
flag.IntVar(&times, "n", 1, "同时生成SQL的次数(1-10之间),数值越大准确率越高,但消耗的token也越多")
2527
flag.Parse()
2628
if link == "" || question == "" {
2729
fmt.Println("请输入正确参数")
@@ -32,6 +34,7 @@ func main() {
3234
DbLink: link,
3335
Try: try,
3436
ShouldRun: run,
37+
Times: times,
3538
}
3639
sql, result, err := ts.Do(question)
3740
if err != nil {

eino/eino.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,41 @@ package eino
22

33
import (
44
"context"
5+
"fmt"
56
)
67

7-
func GetSQL(ddl, question string) (sql string) {
8+
func GetSQL(ddl, question string) (sql string, err error) {
89
ctx := context.Background()
9-
messages := createMessagesFromTemplate(ddl, question)
10-
cm := createOpenAIChatModel(ctx)
11-
result := generate(ctx, cm, messages)
12-
sql = result.Content
13-
return
10+
messages, err := ddl2sqlMessages(ddl, question)
11+
if err != nil {
12+
return "", err
13+
}
14+
15+
cm, err := createOpenAIChatModel(ctx)
16+
if err != nil {
17+
return "", err
18+
}
19+
result, err := generate(ctx, cm, messages)
20+
if err != nil {
21+
return "", fmt.Errorf("生成SQL失败: %w", err)
22+
}
23+
return result.Content, nil
24+
}
25+
26+
func ChoiceSQL(sqls, ddl, question string) (sql string, err error) {
27+
ctx := context.Background()
28+
messages, err := choiceSqlMessages(sqls, ddl, question)
29+
if err != nil {
30+
return "", err
31+
}
32+
33+
cm, err := createOpenAIChatModel(ctx)
34+
if err != nil {
35+
return "", err
36+
}
37+
result, err := generate(ctx, cm, messages)
38+
if err != nil {
39+
return "", fmt.Errorf("选择SQL失败: %w", err)
40+
}
41+
return result.Content, nil
1442
}

eino/generate.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@ package eino
22

33
import (
44
"context"
5-
"log"
5+
"fmt"
66

77
"github.com/cloudwego/eino/components/model"
88
"github.com/cloudwego/eino/schema"
99
)
1010

11-
func generate(ctx context.Context, llm model.ChatModel, in []*schema.Message) *schema.Message {
12-
result, err := llm.Generate(ctx, in)
11+
func generate(ctx context.Context, llm model.ChatModel, in []*schema.Message) (message *schema.Message, err error) {
12+
message, err = llm.Generate(ctx, in)
1313
if err != nil {
14-
log.Fatalf("llm generate failed: %v", err)
14+
err = fmt.Errorf("llm generate failed: %v", err)
15+
return
1516
}
16-
return result
17+
return
1718
}

eino/message.go

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@ package eino
22

33
import (
44
"context"
5-
"log"
5+
"fmt"
66

77
"github.com/cloudwego/eino/components/prompt"
88
"github.com/cloudwego/eino/schema"
99
)
1010

11+
const (
12+
limit = 1000
13+
role = "You are a MySQL expert."
14+
)
15+
16+
// createTemplate 创建并返回一个配置好的聊天模板
1117
func createTemplate() prompt.ChatTemplate {
1218
// 创建模板,使用 FString 格式
1319
return prompt.FromMessages(schema.FString,
@@ -33,18 +39,48 @@ func createTemplate() prompt.ChatTemplate {
3339
)
3440
}
3541

36-
func createMessagesFromTemplate(ddl, question string) []*schema.Message {
42+
// formatMessages 格式化消息并处理错误
43+
func formatMessages(template prompt.ChatTemplate, data map[string]any) ([]*schema.Message, error) {
44+
messages, err := template.Format(context.Background(), data)
45+
if err != nil {
46+
return nil, fmt.Errorf("格式化模板失败: %w", err)
47+
}
48+
return messages, nil
49+
}
50+
51+
// ddl2sqlMessages 将DDL和问题转换为消息列表
52+
func ddl2sqlMessages(ddl, question string) ([]*schema.Message, error) {
3753
template := createTemplate()
38-
// 使用模板生成消息
39-
messages, err := template.Format(context.Background(), map[string]any{
40-
"role": "You are a MySQL expert.",
54+
data := map[string]any{
55+
"role": role,
4156
"question": question,
4257
"ddl": ddl,
43-
"limit": 10,
58+
"limit": limit,
4459
"chat_history": []*schema.Message{},
45-
})
60+
}
61+
messages, err := formatMessages(template, data)
62+
if err != nil {
63+
return nil, err
64+
}
65+
return messages, nil
66+
}
67+
68+
// choiceSqlMessages 生成SQL选择消息列表
69+
func choiceSqlMessages(sqls, ddl, question string) ([]*schema.Message, error) {
70+
template := createTemplate()
71+
data := map[string]any{
72+
"role": role,
73+
"question": "Select the most suitable SQL output from the above SQL statements",
74+
"ddl": ddl,
75+
"limit": limit,
76+
"chat_history": []*schema.Message{
77+
schema.UserMessage(question),
78+
schema.AssistantMessage(sqls, nil),
79+
},
80+
}
81+
messages, err := formatMessages(template, data)
4682
if err != nil {
47-
log.Fatalf("format template failed: %v\n", err)
83+
return nil, err
4884
}
49-
return messages
85+
return messages, nil
5086
}

eino/model.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,42 @@ package eino
22

33
import (
44
"context"
5+
"fmt"
56
"log"
67
"os"
78

89
"github.com/cloudwego/eino-ext/components/model/openai"
910
"github.com/cloudwego/eino/components/model"
1011
)
1112

12-
func createOpenAIChatModel(ctx context.Context) model.ChatModel {
13-
m := os.Getenv("OPENAI_MODEL_NAME")
14-
if m == "" {
15-
log.Fatalf("请在环境变量中设置你的 OPENAI_MODEL_NAME")
13+
// createOpenAIChatModel 创建OpenAI聊天模型实例
14+
func createOpenAIChatModel(ctx context.Context) (model.ChatModel, error) {
15+
// 验证必需的环境变量
16+
modelName := os.Getenv("OPENAI_MODEL_NAME")
17+
if modelName == "" {
18+
return nil, fmt.Errorf("环境变量OPENAI_MODEL_NAME未设置")
1619
}
17-
apikey := os.Getenv("OPENAI_API_KEY")
18-
if apikey == "" {
19-
log.Fatalf("请在环境变量中设置你的 OPENAI_API_KEY")
20+
21+
apiKey := os.Getenv("OPENAI_API_KEY")
22+
if apiKey == "" {
23+
return nil, fmt.Errorf("环境变量OPENAI_API_KEY未设置")
24+
}
25+
26+
// 获取可选的baseURL
27+
baseURL := os.Getenv("OPENAI_BASE_URL")
28+
if baseURL == "" {
29+
log.Println("未设置OPENAI_BASE_URL,将使用默认API地址")
2030
}
31+
32+
// 创建聊天模型
2133
chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
22-
Model: os.Getenv("OPENAI_MODEL_NAME"),
23-
APIKey: os.Getenv("OPENAI_API_KEY"),
24-
BaseURL: os.Getenv("OPENAI_BASE_URL"),
34+
Model: modelName,
35+
APIKey: apiKey,
36+
BaseURL: baseURL,
2537
})
2638
if err != nil {
27-
log.Fatalf("create openai chat model failed, err=%v", err)
39+
return nil, fmt.Errorf("创建OpenAI聊天模型失败: %w", err)
2840
}
29-
return chatModel
41+
42+
return chatModel, nil
3043
}

0 commit comments

Comments
 (0)