Skip to content

Commit 12aa0b9

Browse files
a-shpakmkozhukh
authored andcommitted
[update] postgresql and dynamic fields
1 parent c049081 commit 12aa0b9

File tree

2 files changed

+263
-23
lines changed

2 files changed

+263
-23
lines changed

sql.go

Lines changed: 164 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ import (
66
"strings"
77
)
88

9+
const (
10+
DB_MYSQL DatabaseType = iota
11+
DB_POSTGRESQL
12+
)
13+
14+
type DatabaseType int
15+
16+
var DB DatabaseType = DB_MYSQL
17+
918
type Filter struct {
1019
Glue string `json:"glue"`
1120
Field string `json:"field"`
@@ -17,22 +26,30 @@ type Filter struct {
1726
type CustomOperation func(string, string, []interface{}) (string, []interface{}, error)
1827

1928
type SQLConfig struct {
20-
Whitelist map[string]bool
21-
Operations map[string]CustomOperation
29+
Whitelist map[string]bool
30+
Operations map[string]CustomOperation
31+
DynamicFields []DynamicField
32+
DynamicConfigName string
33+
}
34+
35+
type DynamicField struct {
36+
Key string `json:"key"`
37+
Type string `json:"type"`
2238
}
2339

2440
func FromJSON(text []byte) (Filter, error) {
2541
f := Filter{}
2642
err := json.Unmarshal(text, &f)
43+
2744
return f, err
2845
}
2946

3047
var NoValues = make([]interface{}, 0)
3148

32-
func inSQL(field string, data []interface{}) (string, []interface{}, error) {
49+
func inSQL(field string, data []interface{}, placeholder string) (string, []interface{}, error) {
3350
marks := make([]string, len(data))
3451
for i := range marks {
35-
marks[i] = "?"
52+
marks[i] = placeholder
3653
}
3754

3855
sql := fmt.Sprintf("%s IN(%s)", field, strings.Join(marks, ","))
@@ -45,65 +62,151 @@ func GetSQL(data Filter, config *SQLConfig) (string, []interface{}, error) {
4562
return "", nil, fmt.Errorf("field name is not in whitelist: %s", data.Field)
4663
}
4764

65+
ph, err := getPlaceholder()
66+
if err != nil {
67+
return "", nil, err
68+
}
69+
70+
var isDynamicField bool
71+
if DB == DB_POSTGRESQL {
72+
f := getDynamicField(config.DynamicFields, data.Field)
73+
if f != nil {
74+
if config.DynamicConfigName == "" {
75+
return "", nil, fmt.Errorf("dynamic config name is empty")
76+
}
77+
parts := strings.Split(data.Field, ".")
78+
tp := GetJSONBType(f.Type)
79+
var s, e string
80+
if tp == "date" {
81+
s = "CAST("
82+
e = " AS DATE)"
83+
tp = "text"
84+
}
85+
if len(parts) == 1 {
86+
data.Field = fmt.Sprintf("%s(%s->'%s')::%s%s", s, config.DynamicConfigName, parts[0], tp, e)
87+
} else if len(parts) == 2 {
88+
data.Field = fmt.Sprintf("%s(\"%s\".%s->'%s')::%s%s", s, parts[0], config.DynamicConfigName, parts[1], tp, e)
89+
}
90+
isDynamicField = true
91+
}
92+
}
93+
4894
if len(data.Includes) > 0 {
49-
return inSQL(data.Field, data.Includes)
95+
return inSQL(data.Field, data.Includes, ph)
5096
}
5197

5298
values := data.Condition.getValues()
5399
switch data.Condition.Rule {
54100
case "":
55101
return "", NoValues, nil
56102
case "equal":
57-
return fmt.Sprintf("%s = ?", data.Field), values, nil
103+
return fmt.Sprintf("%s = %s", data.Field, ph), values, nil
58104
case "notEqual":
59-
return fmt.Sprintf("%s <> ?", data.Field), values, nil
105+
return fmt.Sprintf("%s <> %s", data.Field, ph), values, nil
60106
case "contains":
61-
return fmt.Sprintf("INSTR(%s, ?) > 0", data.Field), values, nil
107+
switch DB {
108+
case DB_MYSQL:
109+
return fmt.Sprintf("INSTR(%s, ?) > 0", data.Field), values, nil
110+
case DB_POSTGRESQL:
111+
if isDynamicField {
112+
// Quotes (" ... ") are needed for correct work. Fields of type text in JSONB are wrapped by default
113+
return fmt.Sprintf("%s LIKE '\"%%' || $ || '%%\"'", data.Field), values, nil
114+
}
115+
return fmt.Sprintf("%s LIKE '%%' || $ || '%%'", data.Field), values, nil
116+
}
62117
case "notContains":
63-
return fmt.Sprintf("INSTR(%s, ?) = 0", data.Field), values, nil
118+
switch DB {
119+
case DB_MYSQL:
120+
return fmt.Sprintf("INSTR(%s, ?) = 0", data.Field), values, nil
121+
case DB_POSTGRESQL:
122+
if isDynamicField {
123+
return fmt.Sprintf("%s NOT LIKE '\"%%' || $ || '%%\"'", data.Field), values, nil
124+
}
125+
return fmt.Sprintf("%s NOT LIKE '%%' || $ || '%%'", data.Field), values, nil
126+
}
64127
case "lessOrEqual":
65-
return fmt.Sprintf("%s <= ?", data.Field), values, nil
128+
return fmt.Sprintf("%s <= %s", data.Field, ph), values, nil
66129
case "greaterOrEqual":
67-
return fmt.Sprintf("%s >= ?", data.Field), values, nil
130+
return fmt.Sprintf("%s >= %s", data.Field, ph), values, nil
68131
case "less":
69-
return fmt.Sprintf("%s < ?", data.Field), values, nil
132+
return fmt.Sprintf("%s < %s", data.Field, ph), values, nil
70133
case "notBetween":
71134
if len(values) != 2 {
72135
return "", nil, fmt.Errorf("wrong number of parameters for notBetween operation: %d", len(values))
73136
}
74137

75138
if values[0] == nil {
76-
return fmt.Sprintf("%s > ?", data.Field), values[1:], nil
139+
return fmt.Sprintf("%s > %s", data.Field, ph), values[1:], nil
77140
} else if values[1] == nil {
78-
return fmt.Sprintf("%s < ?", data.Field), values[:1], nil
141+
return fmt.Sprintf("%s < %s", data.Field, ph), values[:1], nil
79142
} else {
80-
return fmt.Sprintf("( %s < ? OR %s > ? )", data.Field, data.Field), values, nil
143+
return fmt.Sprintf("( %s < %s OR %s > %s )", data.Field, ph, data.Field, ph), values, nil
81144
}
82145
case "between":
83146
if len(values) != 2 {
84147
return "", nil, fmt.Errorf("wrong number of parameters for notBetween operation: %d", len(values))
85148
}
86149

87150
if values[0] == nil {
88-
return fmt.Sprintf("%s < ?", data.Field), values[1:], nil
151+
return fmt.Sprintf("%s < %s", data.Field, ph), values[1:], nil
89152
} else if values[1] == nil {
90-
return fmt.Sprintf("%s > ?", data.Field), values[:1], nil
153+
return fmt.Sprintf("%s > %s", data.Field, ph), values[:1], nil
91154
} else {
92-
return fmt.Sprintf("( %s > ? AND %s < ? )", data.Field, data.Field), values, nil
155+
return fmt.Sprintf("( %s > %s AND %s < %s )", data.Field, ph, data.Field, ph), values, nil
93156
}
94157
case "greater":
95-
return fmt.Sprintf("%s > ?", data.Field), values, nil
158+
return fmt.Sprintf("%s > %s", data.Field, ph), values, nil
96159
case "beginsWith":
97-
search := "CONCAT(?, '%')"
160+
var search string
161+
switch DB {
162+
case DB_MYSQL:
163+
search = "CONCAT(?, '%')"
164+
case DB_POSTGRESQL:
165+
if isDynamicField {
166+
search = "'\"' || $ || '%'"
167+
} else {
168+
search = " $ || '%'"
169+
}
170+
}
98171
return fmt.Sprintf("%s LIKE %s", data.Field, search), values, nil
99172
case "notBeginsWith":
100-
search := "CONCAT(?, '%')"
173+
var search string
174+
switch DB {
175+
case DB_MYSQL:
176+
search = "CONCAT(?, '%')"
177+
case DB_POSTGRESQL:
178+
if isDynamicField {
179+
search = "'\"' || $ || '%'"
180+
} else {
181+
search = " $ || '%'"
182+
}
183+
}
101184
return fmt.Sprintf("%s NOT LIKE %s", data.Field, search), values, nil
102185
case "endsWith":
103-
search := "CONCAT('%', ?)"
186+
var search string
187+
switch DB {
188+
case DB_MYSQL:
189+
search = "CONCAT('%', ?)"
190+
case DB_POSTGRESQL:
191+
if isDynamicField {
192+
search = "'%' || $ || '\"'"
193+
} else {
194+
search = "'%' || $ "
195+
}
196+
}
104197
return fmt.Sprintf("%s LIKE %s", data.Field, search), values, nil
105198
case "notEndsWith":
106-
search := "CONCAT('%', ?)"
199+
var search string
200+
switch DB {
201+
case DB_MYSQL:
202+
search = "CONCAT('%', ?)"
203+
case DB_POSTGRESQL:
204+
if isDynamicField {
205+
search = "'%' || $ || '\"'"
206+
} else {
207+
search = "'%' || $ "
208+
}
209+
}
107210
return fmt.Sprintf("%s NOT LIKE %s", data.Field, search), values, nil
108211
}
109212

@@ -141,5 +244,43 @@ func GetSQL(data Filter, config *SQLConfig) (string, []interface{}, error) {
141244
outStr = "( " + outStr + " )"
142245
}
143246

247+
// number all placeholders
248+
if DB == DB_POSTGRESQL {
249+
n := 1
250+
for strings.Contains(outStr, " $ ") {
251+
outStr = strings.Replace(outStr, " $ ", fmt.Sprintf("$%d", n), 1)
252+
n = n + 1
253+
}
254+
}
255+
144256
return outStr, values, nil
145257
}
258+
259+
func getPlaceholder() (string, error) {
260+
switch DB {
261+
case DB_MYSQL:
262+
return "?", nil
263+
case DB_POSTGRESQL:
264+
return " $ ", nil
265+
default:
266+
return "", fmt.Errorf("unknown database")
267+
}
268+
}
269+
270+
func getDynamicField(array []DynamicField, value string) *DynamicField {
271+
for _, v := range array {
272+
if v.Key == value {
273+
return &v
274+
}
275+
}
276+
return nil
277+
}
278+
279+
func GetJSONBType(t string) string {
280+
switch t {
281+
case "number":
282+
return "numeric"
283+
default:
284+
return t
285+
}
286+
}

sql_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,59 @@ var cases = [][]string{
130130
},
131131
}
132132

133+
var psqlCases = [][]string {
134+
[]string{
135+
`{ "glue":"and", "rules":[{ "field": "a", "condition":{ "type":"equal", "filter":1 }}]}`,
136+
"(cfg->'a')::text = $1",
137+
"1",
138+
},
139+
[]string{
140+
`{ "glue":"and", "rules":[{ "field": "b", "condition":{ "type":"notEqual", "filter":1 }}]}`,
141+
"(cfg->'b')::numeric <> $1",
142+
"1",
143+
},
144+
[]string{
145+
`{ "glue":"and", "rules":[{ "field": "b", "condition":{ "type":"less", "filter":1 }}]}`,
146+
"(cfg->'b')::numeric < $1",
147+
"1",
148+
},
149+
[]string{
150+
`{ "glue":"and", "rules":[{ "field": "b", "condition":{ "type":"lessOrEqual", "filter":1 }}]}`,
151+
"(cfg->'b')::numeric <= $1",
152+
"1",
153+
},
154+
[]string{
155+
`{ "glue":"and", "rules":[{ "field": "b", "condition":{ "type":"greater", "filter":1 }}]}`,
156+
"(cfg->'b')::numeric > $1",
157+
"1",
158+
},
159+
[]string{
160+
`{ "glue":"and", "rules":[{ "field": "b", "condition":{ "type":"greaterOrEqual", "filter":1 }}]}`,
161+
"(cfg->'b')::numeric >= $1",
162+
"1",
163+
},
164+
[]string{
165+
`{ "glue":"and", "rules":[{ "field": "a", "condition":{ "type":"contains", "filter":1 }}]}`,
166+
"(cfg->'a')::text LIKE '\"%' || $1 || '%\"'",
167+
"1",
168+
},
169+
[]string{
170+
`{ "glue":"and", "rules":[{ "field": "a", "condition":{ "type":"notContains", "filter":1 }}]}`,
171+
"(cfg->'a')::text NOT LIKE '\"%' || $1 || '%\"'",
172+
"1",
173+
},
174+
[]string{
175+
`{ "glue":"and", "rules":[{ "field": "c", "condition":{ "type":"equal", "filter":"2006/01/02" }}]}`,
176+
"CAST((cfg->'c')::text AS DATE) = $1",
177+
`2006/01/02`,
178+
},
179+
[]string{
180+
`{ "glue":"and", "rules":[{ "field": "c", "condition":{ "type":"notBetween", "filter":{ "start":"2006/01/02", "end":"2006/01/9" } }}]}`,
181+
"( CAST((cfg->'c')::text AS DATE) < $1 OR CAST((cfg->'c')::text AS DATE) > $2 )",
182+
`2006/01/02,2006/01/9`,
183+
},
184+
}
185+
133186
func anyToStringArray(some []interface{}) (string, error) {
134187
out := make([]string, 0, len(some))
135188
for _, x := range some {
@@ -182,6 +235,52 @@ func TestSQL(t *testing.T) {
182235
}
183236
}
184237

238+
func TestPSQL(t *testing.T) {
239+
DB = DB_POSTGRESQL
240+
queryConfig := SQLConfig{
241+
Whitelist: map[string]bool{
242+
"a": true,
243+
"b": true,
244+
"c": true,
245+
},
246+
DynamicFields: []DynamicField{
247+
{"a", "text"},
248+
{"b", "number"},
249+
{"c", "date"},
250+
},
251+
DynamicConfigName: "cfg",
252+
}
253+
for _, line := range psqlCases {
254+
format, err := FromJSON([]byte(line[0]))
255+
if err != nil {
256+
t.Errorf("can't parse json\nj: %s\n%f", line[0], err)
257+
continue
258+
}
259+
260+
sql, vals, err := GetSQL(format, &queryConfig)
261+
if err != nil {
262+
t.Errorf("can't generate sql\nj: %s\n%f", line[0], err)
263+
continue
264+
}
265+
if sql != line[1] {
266+
t.Errorf("wrong sql generated\nj: %s\ns: %s\nr: %s", line[0], line[1], sql)
267+
continue
268+
}
269+
270+
valsStr, err := anyToStringArray(vals)
271+
if err != nil {
272+
t.Errorf("can't convert parameters\nj: %s\n%f", line[0], err)
273+
continue
274+
}
275+
276+
if valsStr != line[2] {
277+
t.Errorf("wrong sql generated\nj: %s\ns: %s\nr: %s", line[0], line[2], valsStr)
278+
continue
279+
}
280+
}
281+
DB = DB_MYSQL
282+
}
283+
185284
func TestWhitelist(t *testing.T) {
186285
format, err := FromJSON([]byte(aAndB))
187286
if err != nil {

0 commit comments

Comments
 (0)