@@ -2,39 +2,179 @@ package text2sql
22
33import (
44 "log"
5+ "strings"
6+ "sync"
7+ "unicode"
58
69 "github.com/wangle201210/text2sql/eino"
710 "github.com/wangle201210/text2sql/mysql"
811)
912
1013type Text2sql struct {
11- DbLink string
12- Try int // 如果失败,尝试多少次
13- ShouldRun bool // 生成后是否需要run
14+ DbLink string
15+ Try int // 如果失败,尝试多少次
16+ ShouldRun bool // 生成后是否需要run
17+ Times int // 选填,1-10之间,同时生成*个sql后选出最合适的个,注意这个数越大消耗的token越多
18+ db * mysql.Db
19+ ddl string
20+ question string
21+ sqls []string
22+ removeSqls []string
1423}
1524
1625func (x * Text2sql ) Do (question string ) (sql string , runResult map [string ]interface {}, err error ) {
26+ if x .Times == 0 {
27+ x .Times = 1 // 至少要运行一次
28+ }
29+ if x .Times > 10 {
30+ x .Times = 10 // 最多10次
31+ }
1732 db := & mysql.Db {
1833 DataSourceName : x .DbLink ,
1934 }
2035 if err = db .Init (); err != nil {
2136 return
2237 }
23- ddl := db .GetDdl ()
24- sql = eino .GetSQL (ddl , question )
25- err = db .CheckSQL (sql )
38+ x .db = db
39+ x .ddl = db .GetDdl ()
40+ x .question = question
41+ if err = x .getAllSql (); err != nil {
42+ return
43+ }
44+ sql , err = x .choice ()
45+ if err != nil {
46+ return
47+ }
48+ if x .ShouldRun {
49+ runResult = db .DoSQL (sql )
50+ }
51+ return
52+ }
53+
54+ func (x * Text2sql ) removeWhitespace () {
55+ x .removeSqls = []string {}
56+ for _ , sql := range x .sqls {
57+ x .removeSqls = append (x .removeSqls , removeWhitespace (sql ))
58+ }
59+ }
60+
61+ func (x * Text2sql ) uniqSql () {
62+ list := map [string ]int {}
63+ for i , sql := range x .removeSqls {
64+ list [sql ] = i // 存下最后一次出现的角标
65+ }
66+ var sqls []string
67+ for _ , i := range list {
68+ sqls = append (sqls , x .sqls [i ])
69+ }
70+ x .sqls = sqls
71+ }
72+
73+ func (x * Text2sql ) findMostCommonSql () string {
74+ // 创建一个 map 来记录每个去空白字符后的字符串出现次数
75+ countMap := make (map [string ]int )
76+ for _ , s := range x .removeSqls {
77+ countMap [s ]++
78+ }
79+
80+ // 找到出现次数最多的字符串
81+ var mostCommonString string
82+ maxCount := 0
83+ for cleanedString , count := range countMap {
84+ if count > maxCount {
85+ mostCommonString = cleanedString
86+ maxCount = count
87+ }
88+ }
89+ // 数量 > 1/2就直接使用
90+ if float64 (maxCount ) <= float64 (len (x .sqls ))/ 2 {
91+ return ""
92+ }
93+ // 返回该字符串在原数组中的第一个索引
94+ for i , s := range x .removeSqls {
95+ if s == mostCommonString {
96+ return x .sqls [i ]
97+ }
98+ }
99+ return ""
100+ }
101+
102+ // 从多个sql中选择一个最适合的
103+ func (x * Text2sql ) choice () (sql string , err error ) {
104+ // 移除空白后进行对比
105+ x .removeWhitespace ()
106+ // 如果有一个sql数量超过一半就直接使用该sql
107+ if commonSql := x .findMostCommonSql (); len (commonSql ) > 0 {
108+ return commonSql , nil
109+ }
110+ // 去重
111+ x .uniqSql ()
112+ // 如果只有一个了就直接返回
113+ if len (x .sqls ) == 1 {
114+ sql = x .sqls [0 ]
115+ return
116+ }
117+ // 从候选项里面选一个
118+ var sqls string
119+ for _ , s := range x .sqls {
120+ sqls += s
121+ sqls += "\n "
122+ }
123+ sql = eino .ChoiceSQL (sqls , x .ddl , x .question )
124+ err = x .db .CheckSQL (sql )
26125 try := x .Try
27126 for err != nil && try > 0 {
28- log .Printf ("try: %d, err: %v\n " , x .Try - try , err )
29127 try --
30- sql = eino .GetSQL (ddl , question )
31- err = db .CheckSQL (sql )
128+ log .Printf ("try: %d, err: %v\n " , x .Try - try , err )
129+ sql = eino .ChoiceSQL (sqls , x .ddl , x .question )
130+ err = x .db .CheckSQL (sql )
32131 }
33132 if err != nil {
34133 return
35134 }
36- if x .ShouldRun {
37- runResult = db .DoSQL (sql )
135+ return
136+ }
137+
138+ func (x * Text2sql ) getAllSql () (err error ) {
139+ wg := & sync.WaitGroup {}
140+
141+ for i := 0 ; i < x .Times ; i ++ {
142+ wg .Add (1 )
143+ // 循环生成多次,取最相关的一次
144+ go func () {
145+ defer wg .Done ()
146+ if onceSql , err := x .once (); err == nil {
147+ x .sqls = append (x .sqls , onceSql )
148+ }
149+ }()
150+ }
151+ wg .Wait ()
152+ return
153+ }
154+
155+ func (x * Text2sql ) once () (sql string , err error ) {
156+ sql = eino .GetSQL (x .ddl , x .question )
157+ err = x .db .CheckSQL (sql )
158+ try := x .Try
159+ for err != nil && try > 0 {
160+ try --
161+ log .Printf ("try: %d, err: %v\n " , x .Try - try , err )
162+ sql = eino .GetSQL (x .ddl , x .question )
163+ err = x .db .CheckSQL (sql )
164+ }
165+ if err != nil {
166+ return
38167 }
39168 return
40169}
170+
171+ func removeWhitespace (input string ) string {
172+ input = strings .ReplaceAll (input , "`" , "" )
173+ var result []rune
174+ for _ , r := range input {
175+ if ! unicode .IsSpace (r ) { // 如果字符不是空白字符
176+ result = append (result , r )
177+ }
178+ }
179+ return string (result )
180+ }
0 commit comments