Skip to content

Commit 32e6af6

Browse files
committed
[update] predicates
1 parent 19b5917 commit 32e6af6

File tree

2 files changed

+109
-12
lines changed

2 files changed

+109
-12
lines changed

sql.go

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,24 @@ type Filter struct {
2929
Rules []Filter `json:"rules"`
3030
}
3131

32+
func (f *Filter) getValues() []interface{} {
33+
valueMap, ok := f.Value.(map[string]interface{})
34+
if !ok {
35+
return []interface{}{f.Value}
36+
}
37+
38+
return []interface{}{valueMap["start"], valueMap["end"]}
39+
}
40+
3241
type CustomOperation func(string, string, []interface{}) (string, []interface{}, error)
42+
type CustomPredicate func(string, string, []interface{}) (string, []interface{}, error)
3343

3444
type CheckFunction = func(string) bool
3545
type SQLConfig struct {
3646
WhitelistFunc CheckFunction
3747
Whitelist map[string]bool
3848
Operations map[string]CustomOperation
49+
Predicates map[string]CustomPredicate
3950
}
4051

4152
func FromJSON(text []byte) (Filter, error) {
@@ -57,15 +68,6 @@ func inSQL(field string, data []interface{}, db DBDriver) (string, []interface{}
5768
return sql, data, nil
5869
}
5970

60-
func (f *Filter) getValues() []interface{} {
61-
valueMap, ok := f.Value.(map[string]interface{})
62-
if !ok {
63-
return []interface{}{f.Value}
64-
}
65-
66-
return []interface{}{valueMap["start"], valueMap["end"]}
67-
}
68-
6971
func GetSQL(data Filter, config *SQLConfig, dbArr ...DBDriver) (string, []interface{}, error) {
7072
var db DBDriver
7173
if len(dbArr) > 0 {
@@ -90,6 +92,17 @@ func GetSQL(data Filter, config *SQLConfig, dbArr ...DBDriver) (string, []interf
9092
}
9193

9294
values := data.getValues()
95+
96+
var err error
97+
if config != nil && config.Predicates != nil {
98+
if pr, prOk := config.Predicates[data.Predicate]; prOk {
99+
name, values, err = pr(name, data.Predicate, values)
100+
if err != nil {
101+
return "", NoValues, err
102+
}
103+
}
104+
}
105+
93106
switch data.Filter {
94107
case "":
95108
return "", NoValues, nil
@@ -144,8 +157,7 @@ func GetSQL(data Filter, config *SQLConfig, dbArr ...DBDriver) (string, []interf
144157
}
145158

146159
if config != nil && config.Operations != nil {
147-
op, opOk := config.Operations[data.Filter]
148-
if opOk {
160+
if op, opOk := config.Operations[data.Filter]; opOk {
149161
return op(name, data.Filter, values)
150162
}
151163
}

sql_test.go

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ var aAndB = `{ "glue":"and", "rules":[{ "field": "a", "filter":"less", "value":1
1111
var aOrB = `{ "glue":"or", "rules":[{ "field": "a", "filter":"less", "value":1}, { "field": "b", "filter":"greater", "value":"abc" }]}`
1212
var cOrC = `{ "glue":"or", "rules":[{ "field": "a", "filter":"is null" }, { "field": "b", "filter":"range100", "value":500 }]}`
1313
var JSONaAndB = `{ "glue":"and", "rules":[{ "field": "json:cfg.a", "filter":"less", "value":1}, { "field": "json:cfg.b", "filter":"greater", "value":"abc" }]}`
14+
var aPred = `{ "glue":"and", "rules":[{ "field": "a", "filter":"greater", "type": "number", "predicate": "month","value": 10 }, { "field": "a", "filter":"less", "type": "number", "predicate": "year","value": 2024 }]}`
1415

1516
var cases = [][]string{
1617
{`{}`, "", "", ""},
@@ -447,7 +448,7 @@ func TestWhitelistPG(t *testing.T) {
447448
func TestCustomOperation(t *testing.T) {
448449
format, err := FromJSON([]byte(cOrC))
449450
if err != nil {
450-
t.Errorf("can't parse json\nj: %s\n%f", aAndB, err)
451+
t.Errorf("can't parse json\nj: %s\n%f", cOrC, err)
451452
return
452453
}
453454

@@ -486,3 +487,87 @@ func TestCustomOperation(t *testing.T) {
486487
return
487488
}
488489
}
490+
491+
func TestCustomPredicate(t *testing.T) {
492+
format, err := FromJSON([]byte(aPred))
493+
if err != nil {
494+
t.Errorf("can't parse json\nj: %s\n%f", aPred, err)
495+
return
496+
}
497+
498+
sql, vals, err := GetSQL(format, &SQLConfig{
499+
Predicates: map[string]CustomPredicate{
500+
"month": func(n string, p string, values []interface{}) (string, []interface{}, error) {
501+
return fmt.Sprintf("month(%s)", n), values, nil
502+
},
503+
"year": func(n string, p string, values []interface{}) (string, []interface{}, error) {
504+
return fmt.Sprintf("year(%s)", n), values, nil
505+
},
506+
},
507+
})
508+
509+
if err != nil {
510+
t.Errorf("can't generate sql: %s\n%f", aPred, err)
511+
return
512+
}
513+
514+
check := "( month(a) > ? AND year(a) < ? )"
515+
if sql != check {
516+
t.Errorf("wrong sql generated\nj: %s\ns: %s\nr: %s", aPred, check, sql)
517+
return
518+
}
519+
520+
valsStr, err := anyToStringArray(vals)
521+
if err != nil {
522+
t.Errorf("can't convert parameters\nj: %s\n%f", aPred, err)
523+
return
524+
}
525+
526+
check = "10,2024"
527+
if valsStr != check {
528+
t.Errorf("wrong sql generated\nj: %s\ns: %s\nr: %s", aPred, check, valsStr)
529+
return
530+
}
531+
}
532+
533+
func TestCustomPredicatePG(t *testing.T) {
534+
format, err := FromJSON([]byte(aPred))
535+
if err != nil {
536+
t.Errorf("can't parse json\nj: %s\n%f", aPred, err)
537+
return
538+
}
539+
540+
sql, vals, err := GetSQL(format, &SQLConfig{
541+
Predicates: map[string]CustomPredicate{
542+
"month": func(n string, p string, values []interface{}) (string, []interface{}, error) {
543+
return fmt.Sprintf("date_part('month', %s)", n), values, nil
544+
},
545+
"year": func(n string, p string, values []interface{}) (string, []interface{}, error) {
546+
return fmt.Sprintf("date_part('year', %s)", n), values, nil
547+
},
548+
},
549+
}, &PostgreSQL{})
550+
551+
if err != nil {
552+
t.Errorf("can't generate sql: %s\n%f", aPred, err)
553+
return
554+
}
555+
556+
check := "( date_part('month', a) > $1 AND date_part('year', a) < $2 )"
557+
if sql != check {
558+
t.Errorf("wrong sql generated\nj: %s\ns: %s\nr: %s", aPred, check, sql)
559+
return
560+
}
561+
562+
valsStr, err := anyToStringArray(vals)
563+
if err != nil {
564+
t.Errorf("can't convert parameters\nj: %s\n%f", aPred, err)
565+
return
566+
}
567+
568+
check = "10,2024"
569+
if valsStr != check {
570+
t.Errorf("wrong sql generated\nj: %s\ns: %s\nr: %s", aPred, check, valsStr)
571+
return
572+
}
573+
}

0 commit comments

Comments
 (0)