Skip to content

Commit b8b22bb

Browse files
committed
优化代码结构
1 parent fbff4ed commit b8b22bb

File tree

6 files changed

+317
-123
lines changed

6 files changed

+317
-123
lines changed

eino/eino.go

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,41 @@ package eino
22

33
import (
44
"context"
5+
"fmt"
56
)
67

7-
func GetSQL(ddl, question string) (sql string) {
8+
func GetSQL(ddl, question string) (sql string, err error) {
89
ctx := context.Background()
9-
messages := ddl2sqlMessages(ddl, question)
10-
cm := createOpenAIChatModel(ctx)
11-
result := generate(ctx, cm, messages)
12-
sql = result.Content
13-
return
10+
messages, err := ddl2sqlMessages(ddl, question)
11+
if err != nil {
12+
return "", err
13+
}
14+
15+
cm, err := createOpenAIChatModel(ctx)
16+
if err != nil {
17+
return "", err
18+
}
19+
result, err := generate(ctx, cm, messages)
20+
if err != nil {
21+
return "", fmt.Errorf("生成SQL失败: %w", err)
22+
}
23+
return result.Content, nil
1424
}
1525

16-
func ChoiceSQL(sqls, ddl, question string) (sql string) {
26+
func ChoiceSQL(sqls, ddl, question string) (sql string, err error) {
1727
ctx := context.Background()
18-
messages := choiceSqlMessages(sqls, ddl, question)
19-
cm := createOpenAIChatModel(ctx)
20-
result := generate(ctx, cm, messages)
21-
sql = result.Content
22-
return
28+
messages, err := choiceSqlMessages(sqls, ddl, question)
29+
if err != nil {
30+
return "", err
31+
}
32+
33+
cm, err := createOpenAIChatModel(ctx)
34+
if err != nil {
35+
return "", err
36+
}
37+
result, err := generate(ctx, cm, messages)
38+
if err != nil {
39+
return "", fmt.Errorf("选择SQL失败: %w", err)
40+
}
41+
return result.Content, nil
2342
}

eino/generate.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@ package eino
22

33
import (
44
"context"
5-
"log"
5+
"fmt"
66

77
"github.com/cloudwego/eino/components/model"
88
"github.com/cloudwego/eino/schema"
99
)
1010

11-
func generate(ctx context.Context, llm model.ChatModel, in []*schema.Message) *schema.Message {
12-
result, err := llm.Generate(ctx, in)
11+
func generate(ctx context.Context, llm model.ChatModel, in []*schema.Message) (message *schema.Message, err error) {
12+
message, err = llm.Generate(ctx, in)
1313
if err != nil {
14-
log.Fatalf("llm generate failed: %v", err)
14+
err = fmt.Errorf("llm generate failed: %v", err)
15+
return
1516
}
16-
return result
17+
return
1718
}

eino/message.go

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package eino
22

33
import (
44
"context"
5-
"log"
5+
"fmt"
66

77
"github.com/cloudwego/eino/components/prompt"
88
"github.com/cloudwego/eino/schema"
@@ -13,6 +13,7 @@ const (
1313
role = "You are a MySQL expert."
1414
)
1515

16+
// createTemplate 创建并返回一个配置好的聊天模板
1617
func createTemplate() prompt.ChatTemplate {
1718
// 创建模板,使用 FString 格式
1819
return prompt.FromMessages(schema.FString,
@@ -38,26 +39,36 @@ func createTemplate() prompt.ChatTemplate {
3839
)
3940
}
4041

41-
func ddl2sqlMessages(ddl, question string) []*schema.Message {
42+
// formatMessages 格式化消息并处理错误
43+
func formatMessages(template prompt.ChatTemplate, data map[string]any) ([]*schema.Message, error) {
44+
messages, err := template.Format(context.Background(), data)
45+
if err != nil {
46+
return nil, fmt.Errorf("格式化模板失败: %w", err)
47+
}
48+
return messages, nil
49+
}
50+
51+
// ddl2sqlMessages 将DDL和问题转换为消息列表
52+
func ddl2sqlMessages(ddl, question string) ([]*schema.Message, error) {
4253
template := createTemplate()
43-
// 使用模板生成消息
44-
messages, err := template.Format(context.Background(), map[string]any{
54+
data := map[string]any{
4555
"role": role,
4656
"question": question,
4757
"ddl": ddl,
4858
"limit": limit,
4959
"chat_history": []*schema.Message{},
50-
})
60+
}
61+
messages, err := formatMessages(template, data)
5162
if err != nil {
52-
log.Fatalf("format template failed: %v\n", err)
63+
return nil, err
5364
}
54-
return messages
65+
return messages, nil
5566
}
5667

57-
func choiceSqlMessages(sqls, ddl, question string) []*schema.Message {
68+
// choiceSqlMessages 生成SQL选择消息列表
69+
func choiceSqlMessages(sqls, ddl, question string) ([]*schema.Message, error) {
5870
template := createTemplate()
59-
// 使用模板生成消息
60-
messages, err := template.Format(context.Background(), map[string]any{
71+
data := map[string]any{
6172
"role": role,
6273
"question": "Select the most suitable SQL output from the above SQL statements",
6374
"ddl": ddl,
@@ -66,9 +77,10 @@ func choiceSqlMessages(sqls, ddl, question string) []*schema.Message {
6677
schema.UserMessage(question),
6778
schema.AssistantMessage(sqls, nil),
6879
},
69-
})
80+
}
81+
messages, err := formatMessages(template, data)
7082
if err != nil {
71-
log.Fatalf("format template failed: %v\n", err)
83+
return nil, err
7284
}
73-
return messages
85+
return messages, nil
7486
}

eino/model.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,42 @@ package eino
22

33
import (
44
"context"
5+
"fmt"
56
"log"
67
"os"
78

89
"github.com/cloudwego/eino-ext/components/model/openai"
910
"github.com/cloudwego/eino/components/model"
1011
)
1112

12-
func createOpenAIChatModel(ctx context.Context) model.ChatModel {
13-
m := os.Getenv("OPENAI_MODEL_NAME")
14-
if m == "" {
15-
log.Fatalf("请在环境变量中设置你的 OPENAI_MODEL_NAME")
13+
// createOpenAIChatModel 创建OpenAI聊天模型实例
14+
func createOpenAIChatModel(ctx context.Context) (model.ChatModel, error) {
15+
// 验证必需的环境变量
16+
modelName := os.Getenv("OPENAI_MODEL_NAME")
17+
if modelName == "" {
18+
return nil, fmt.Errorf("环境变量OPENAI_MODEL_NAME未设置")
1619
}
17-
apikey := os.Getenv("OPENAI_API_KEY")
18-
if apikey == "" {
19-
log.Fatalf("请在环境变量中设置你的 OPENAI_API_KEY")
20+
21+
apiKey := os.Getenv("OPENAI_API_KEY")
22+
if apiKey == "" {
23+
return nil, fmt.Errorf("环境变量OPENAI_API_KEY未设置")
24+
}
25+
26+
// 获取可选的baseURL
27+
baseURL := os.Getenv("OPENAI_BASE_URL")
28+
if baseURL == "" {
29+
log.Println("未设置OPENAI_BASE_URL,将使用默认API地址")
2030
}
31+
32+
// 创建聊天模型
2133
chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{
22-
Model: os.Getenv("OPENAI_MODEL_NAME"),
23-
APIKey: os.Getenv("OPENAI_API_KEY"),
24-
BaseURL: os.Getenv("OPENAI_BASE_URL"),
34+
Model: modelName,
35+
APIKey: apiKey,
36+
BaseURL: baseURL,
2537
})
2638
if err != nil {
27-
log.Fatalf("create openai chat model failed, err=%v", err)
39+
return nil, fmt.Errorf("创建OpenAI聊天模型失败: %w", err)
2840
}
29-
return chatModel
41+
42+
return chatModel, nil
3043
}

mysql/mysql.go

Lines changed: 91 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,132 @@
11
package mysql
22

3+
import (
4+
"context"
5+
"errors"
6+
"strings"
7+
"time"
8+
)
9+
310
import (
411
"fmt"
512

613
"gorm.io/driver/mysql"
714
"gorm.io/gorm"
815
)
916

17+
// Db 封装数据库操作的结构体
1018
type Db struct {
11-
DataSourceName string
12-
db *gorm.DB
13-
ddl []string
19+
DataSourceName string // 数据库连接字信息
20+
db *gorm.DB // gorm数据库实例
21+
ddl []string // 数据库表结构定义列表
1422
}
1523

16-
func (x *Db) Init() (err error) {
24+
// Init 初始化数据库连接并获取表结构
25+
func (x *Db) Init() error {
26+
if x.DataSourceName == "" {
27+
return errors.New("数据库连接字信息不能为空")
28+
}
29+
30+
// 初始化数据库连接
1731
db, err := gorm.Open(mysql.Open(x.DataSourceName), &gorm.Config{})
1832
if err != nil {
19-
return
33+
return fmt.Errorf("连接数据库失败: %w", err)
2034
}
2135
x.db = db
36+
37+
// 获取表结构
2238
ddls, err := x.getDdl()
2339
if err != nil {
24-
return
40+
return fmt.Errorf("获取表结构失败: %w", err)
2541
}
2642
x.ddl = ddls
27-
return
43+
return nil
2844
}
2945

30-
func (x *Db) getDdl() (ddls []string, err error) {
46+
// getDdl 获取所有表的建表语句
47+
func (x *Db) getDdl() ([]string, error) {
48+
if x.db == nil {
49+
return nil, errors.New("数据库连接未初始化")
50+
}
51+
52+
// 获取所有表名
3153
list, err := x.db.Migrator().GetTables()
54+
if err != nil {
55+
return nil, fmt.Errorf("获取表名列表失败: %w", err)
56+
}
57+
58+
// 获取每个表的建表语句
59+
ddls := make([]string, 0, len(list))
3260
for _, tableName := range list {
3361
result := map[string]interface{}{}
34-
err = x.db.Raw(fmt.Sprintf("SHOW CREATE TABLE %s", tableName)).Scan(result).Error
35-
if err != nil {
36-
return
62+
if err := x.db.Raw(fmt.Sprintf("SHOW CREATE TABLE %s", tableName)).Scan(&result).Error; err != nil {
63+
return nil, fmt.Errorf("获取表%s的结构失败: %w", tableName, err)
64+
}
65+
66+
createTable, ok := result["Create Table"].(string)
67+
if !ok {
68+
return nil, fmt.Errorf("表%s的结构格式异常", tableName)
3769
}
38-
ddls = append(ddls, result["Create Table"].(string))
70+
ddls = append(ddls, createTable)
3971
}
40-
return
72+
return ddls, nil
4173
}
4274

43-
func (x *Db) GetDdl() (ddl string) {
75+
// GetDdl 获取所有表结构的字符串表示
76+
func (x *Db) GetDdl() string {
77+
if len(x.ddl) == 0 {
78+
return ""
79+
}
80+
81+
// 使用strings.Builder优化字符串拼接
82+
var builder strings.Builder
4483
for _, s := range x.ddl {
45-
ddl += s
46-
ddl += ";\n"
84+
builder.WriteString(s)
85+
builder.WriteString(";\n")
4786
}
48-
return
87+
return builder.String()
4988
}
5089

90+
// CheckSQL 检查SQL语句的语法正确性
5191
func (x *Db) CheckSQL(sql string) error {
92+
if x.db == nil {
93+
return errors.New("数据库连接未初始化")
94+
}
95+
if sql == "" {
96+
return errors.New("SQL语句不能为空")
97+
}
98+
99+
// 使用EXPLAIN验证SQL语句
52100
return x.db.Exec(fmt.Sprintf("EXPLAIN %s", sql)).Error
53101
}
54102

55-
func (x *Db) DoSQL(sql string) (res map[string]interface{}) {
56-
tx := x.db.Begin()
103+
// DoSQL 执行SQL查询并返回结果
104+
func (x *Db) DoSQL(sql string) (res map[string]interface{}, err error) {
105+
// 创建只读事务
106+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
107+
defer cancel()
108+
109+
tx := x.db.WithContext(ctx).Begin()
110+
defer func() {
111+
if r := recover(); r != nil {
112+
tx.Rollback()
113+
}
114+
}()
115+
116+
// 设置只读模式
57117
tx.Set("gorm:query_option", "FOR READ ONLY")
118+
119+
// 执行查询
58120
res = map[string]interface{}{}
59-
tx.Raw(sql).Scan(res)
60-
tx.Commit()
121+
if err = tx.Raw(sql).Scan(&res).Error; err != nil {
122+
tx.Rollback()
123+
return
124+
}
125+
126+
// 提交事务
127+
if err = tx.Commit().Error; err != nil {
128+
return
129+
}
130+
61131
return
62132
}

0 commit comments

Comments
 (0)