|
1 | 1 | package mysql |
2 | 2 |
|
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "errors" |
| 6 | + "strings" |
| 7 | + "time" |
| 8 | +) |
| 9 | + |
3 | 10 | import ( |
4 | 11 | "fmt" |
5 | 12 |
|
6 | 13 | "gorm.io/driver/mysql" |
7 | 14 | "gorm.io/gorm" |
8 | 15 | ) |
9 | 16 |
|
| 17 | +// Db 封装数据库操作的结构体 |
10 | 18 | type Db struct { |
11 | | - DataSourceName string |
12 | | - db *gorm.DB |
13 | | - ddl []string |
| 19 | + DataSourceName string // 数据库连接字信息 |
| 20 | + db *gorm.DB // gorm数据库实例 |
| 21 | + ddl []string // 数据库表结构定义列表 |
14 | 22 | } |
15 | 23 |
|
16 | | -func (x *Db) Init() (err error) { |
| 24 | +// Init 初始化数据库连接并获取表结构 |
| 25 | +func (x *Db) Init() error { |
| 26 | + if x.DataSourceName == "" { |
| 27 | + return errors.New("数据库连接字信息不能为空") |
| 28 | + } |
| 29 | + |
| 30 | + // 初始化数据库连接 |
17 | 31 | db, err := gorm.Open(mysql.Open(x.DataSourceName), &gorm.Config{}) |
18 | 32 | if err != nil { |
19 | | - return |
| 33 | + return fmt.Errorf("连接数据库失败: %w", err) |
20 | 34 | } |
21 | 35 | x.db = db |
| 36 | + |
| 37 | + // 获取表结构 |
22 | 38 | ddls, err := x.getDdl() |
23 | 39 | if err != nil { |
24 | | - return |
| 40 | + return fmt.Errorf("获取表结构失败: %w", err) |
25 | 41 | } |
26 | 42 | x.ddl = ddls |
27 | | - return |
| 43 | + return nil |
28 | 44 | } |
29 | 45 |
|
30 | | -func (x *Db) getDdl() (ddls []string, err error) { |
| 46 | +// getDdl 获取所有表的建表语句 |
| 47 | +func (x *Db) getDdl() ([]string, error) { |
| 48 | + if x.db == nil { |
| 49 | + return nil, errors.New("数据库连接未初始化") |
| 50 | + } |
| 51 | + |
| 52 | + // 获取所有表名 |
31 | 53 | list, err := x.db.Migrator().GetTables() |
| 54 | + if err != nil { |
| 55 | + return nil, fmt.Errorf("获取表名列表失败: %w", err) |
| 56 | + } |
| 57 | + |
| 58 | + // 获取每个表的建表语句 |
| 59 | + ddls := make([]string, 0, len(list)) |
32 | 60 | for _, tableName := range list { |
33 | 61 | result := map[string]interface{}{} |
34 | | - err = x.db.Raw(fmt.Sprintf("SHOW CREATE TABLE %s", tableName)).Scan(result).Error |
35 | | - if err != nil { |
36 | | - return |
| 62 | + if err := x.db.Raw(fmt.Sprintf("SHOW CREATE TABLE %s", tableName)).Scan(&result).Error; err != nil { |
| 63 | + return nil, fmt.Errorf("获取表%s的结构失败: %w", tableName, err) |
| 64 | + } |
| 65 | + |
| 66 | + createTable, ok := result["Create Table"].(string) |
| 67 | + if !ok { |
| 68 | + return nil, fmt.Errorf("表%s的结构格式异常", tableName) |
37 | 69 | } |
38 | | - ddls = append(ddls, result["Create Table"].(string)) |
| 70 | + ddls = append(ddls, createTable) |
39 | 71 | } |
40 | | - return |
| 72 | + return ddls, nil |
41 | 73 | } |
42 | 74 |
|
43 | | -func (x *Db) GetDdl() (ddl string) { |
| 75 | +// GetDdl 获取所有表结构的字符串表示 |
| 76 | +func (x *Db) GetDdl() string { |
| 77 | + if len(x.ddl) == 0 { |
| 78 | + return "" |
| 79 | + } |
| 80 | + |
| 81 | + // 使用strings.Builder优化字符串拼接 |
| 82 | + var builder strings.Builder |
44 | 83 | for _, s := range x.ddl { |
45 | | - ddl += s |
46 | | - ddl += ";\n" |
| 84 | + builder.WriteString(s) |
| 85 | + builder.WriteString(";\n") |
47 | 86 | } |
48 | | - return |
| 87 | + return builder.String() |
49 | 88 | } |
50 | 89 |
|
| 90 | +// CheckSQL 检查SQL语句的语法正确性 |
51 | 91 | func (x *Db) CheckSQL(sql string) error { |
| 92 | + if x.db == nil { |
| 93 | + return errors.New("数据库连接未初始化") |
| 94 | + } |
| 95 | + if sql == "" { |
| 96 | + return errors.New("SQL语句不能为空") |
| 97 | + } |
| 98 | + |
| 99 | + // 使用EXPLAIN验证SQL语句 |
52 | 100 | return x.db.Exec(fmt.Sprintf("EXPLAIN %s", sql)).Error |
53 | 101 | } |
54 | 102 |
|
55 | | -func (x *Db) DoSQL(sql string) (res map[string]interface{}) { |
56 | | - tx := x.db.Begin() |
| 103 | +// DoSQL 执行SQL查询并返回结果 |
| 104 | +func (x *Db) DoSQL(sql string) (res map[string]interface{}, err error) { |
| 105 | + // 创建只读事务 |
| 106 | + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
| 107 | + defer cancel() |
| 108 | + |
| 109 | + tx := x.db.WithContext(ctx).Begin() |
| 110 | + defer func() { |
| 111 | + if r := recover(); r != nil { |
| 112 | + tx.Rollback() |
| 113 | + } |
| 114 | + }() |
| 115 | + |
| 116 | + // 设置只读模式 |
57 | 117 | tx.Set("gorm:query_option", "FOR READ ONLY") |
| 118 | + |
| 119 | + // 执行查询 |
58 | 120 | res = map[string]interface{}{} |
59 | | - tx.Raw(sql).Scan(res) |
60 | | - tx.Commit() |
| 121 | + if err = tx.Raw(sql).Scan(&res).Error; err != nil { |
| 122 | + tx.Rollback() |
| 123 | + return |
| 124 | + } |
| 125 | + |
| 126 | + // 提交事务 |
| 127 | + if err = tx.Commit().Error; err != nil { |
| 128 | + return |
| 129 | + } |
| 130 | + |
61 | 131 | return |
62 | 132 | } |
0 commit comments