@@ -20,7 +20,8 @@ type QueryValue struct {
20
20
21
21
// Column is kept so late in the generation process around to differentiate
22
22
// between mysql slices and pg arrays
23
- Column * plugin.Column
23
+ Column * plugin.Column
24
+ QueryText string
24
25
}
25
26
26
27
func (v QueryValue ) EmitStruct () bool {
@@ -85,6 +86,9 @@ func (v QueryValue) SlicePair() string {
85
86
86
87
func (v QueryValue ) Type () string {
87
88
if v .Typ != "" {
89
+ if v .isUsedWithArrayComparison () {
90
+ return strings .Trim (v .Typ , "[]" ) // Return single type if used in array comparison.
91
+ }
88
92
return v .Typ
89
93
}
90
94
if v .Struct != nil {
@@ -113,6 +117,9 @@ func (v QueryValue) UniqueFields() []Field {
113
117
fields := make ([]Field , 0 , len (v .Struct .Fields ))
114
118
115
119
for _ , field := range v .Struct .Fields {
120
+ if v .isUsedWithArrayComparison () {
121
+ field .Type = strings .Trim (field .Type , "[]" )
122
+ }
116
123
if _ , found := seen [field .Name ]; found {
117
124
continue
118
125
}
@@ -129,14 +136,14 @@ func (v QueryValue) Params() string {
129
136
}
130
137
var out []string
131
138
if v .Struct == nil {
132
- if ! v .Column .IsSqlcSlice && strings .HasPrefix (v .Typ , "[]" ) && v .Typ != "[]byte" && ! v .SQLDriver .IsPGX () {
139
+ if ! v .Column .IsSqlcSlice && strings .HasPrefix (v .Typ , "[]" ) && v .Typ != "[]byte" && ! v .SQLDriver .IsPGX () && ! v . isUsedWithArrayComparison () {
133
140
out = append (out , "pq.Array(" + escape (v .Name )+ ")" )
134
141
} else {
135
142
out = append (out , escape (v .Name ))
136
143
}
137
144
} else {
138
145
for _ , f := range v .Struct .Fields {
139
- if ! f .HasSqlcSlice () && strings .HasPrefix (f .Type , "[]" ) && f .Type != "[]byte" && ! v .SQLDriver .IsPGX () {
146
+ if ! f .HasSqlcSlice () && strings .HasPrefix (f .Type , "[]" ) && f .Type != "[]byte" && ! v .SQLDriver .IsPGX () && ! v . isUsedWithArrayComparison () {
140
147
out = append (out , "pq.Array(" + escape (v .VariableForField (f ))+ ")" )
141
148
} else {
142
149
out = append (out , escape (v .VariableForField (f )))
@@ -254,6 +261,22 @@ func (v QueryValue) VariableForField(f Field) string {
254
261
return v .Name + "." + f .Name
255
262
}
256
263
264
+ // isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query.
265
+ func (v QueryValue ) isUsedWithArrayComparison () bool {
266
+ if v .Struct != nil {
267
+ for _ , f := range v .Struct .Fields {
268
+ if strings .Contains (v .QueryText , fmt .Sprintf ("ANY(%s)" , f .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("SOME(%s)" , f .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("ALL(%s)" , f .DBName )) {
269
+ return true
270
+ }
271
+ }
272
+ } else {
273
+ if strings .Contains (v .QueryText , fmt .Sprintf ("ANY(%s)" , v .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("SOME(%s)" , v .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("ALL(%s)" , v .DBName )) {
274
+ return true
275
+ }
276
+ }
277
+ return false
278
+ }
279
+
257
280
// A struct used to generate methods and fields on the Queries struct
258
281
type Query struct {
259
282
Cmd string
0 commit comments