@@ -19,7 +19,8 @@ type QueryValue struct {
19
19
20
20
// Column is kept so late in the generation process around to differentiate
21
21
// between mysql slices and pg arrays
22
- Column * plugin.Column
22
+ Column * plugin.Column
23
+ QueryText string
23
24
}
24
25
25
26
func (v QueryValue ) EmitStruct () bool {
@@ -84,6 +85,9 @@ func (v QueryValue) SlicePair() string {
84
85
85
86
func (v QueryValue ) Type () string {
86
87
if v .Typ != "" {
88
+ if v .isUsedWithArrayComparison () {
89
+ return strings .Trim (v .Typ , "[]" ) // Return single type if used in array comparison.
90
+ }
87
91
return v .Typ
88
92
}
89
93
if v .Struct != nil {
@@ -112,6 +116,9 @@ func (v QueryValue) UniqueFields() []Field {
112
116
fields := make ([]Field , 0 , len (v .Struct .Fields ))
113
117
114
118
for _ , field := range v .Struct .Fields {
119
+ if v .isUsedWithArrayComparison () {
120
+ field .Type = strings .Trim (field .Type , "[]" )
121
+ }
115
122
if _ , found := seen [field .Name ]; found {
116
123
continue
117
124
}
@@ -128,14 +135,14 @@ func (v QueryValue) Params() string {
128
135
}
129
136
var out []string
130
137
if v .Struct == nil {
131
- if ! v .Column .IsSqlcSlice && strings .HasPrefix (v .Typ , "[]" ) && v .Typ != "[]byte" && ! v .SQLDriver .IsPGX () {
138
+ if ! v .Column .IsSqlcSlice && strings .HasPrefix (v .Typ , "[]" ) && v .Typ != "[]byte" && ! v .SQLDriver .IsPGX () && ! v . isUsedWithArrayComparison () {
132
139
out = append (out , "pq.Array(" + escape (v .Name )+ ")" )
133
140
} else {
134
141
out = append (out , escape (v .Name ))
135
142
}
136
143
} else {
137
144
for _ , f := range v .Struct .Fields {
138
- if ! f .HasSqlcSlice () && strings .HasPrefix (f .Type , "[]" ) && f .Type != "[]byte" && ! v .SQLDriver .IsPGX () {
145
+ if ! f .HasSqlcSlice () && strings .HasPrefix (f .Type , "[]" ) && f .Type != "[]byte" && ! v .SQLDriver .IsPGX () && ! v . isUsedWithArrayComparison () {
139
146
out = append (out , "pq.Array(" + escape (v .VariableForField (f ))+ ")" )
140
147
} else {
141
148
out = append (out , escape (v .VariableForField (f )))
@@ -253,6 +260,22 @@ func (v QueryValue) VariableForField(f Field) string {
253
260
return v .Name + "." + f .Name
254
261
}
255
262
263
+ // isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query.
264
+ func (v QueryValue ) isUsedWithArrayComparison () bool {
265
+ if v .Struct != nil {
266
+ for _ , f := range v .Struct .Fields {
267
+ if strings .Contains (v .QueryText , fmt .Sprintf ("ANY(%s)" , f .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("SOME(%s)" , f .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("ALL(%s)" , f .DBName )) {
268
+ return true
269
+ }
270
+ }
271
+ } else {
272
+ if strings .Contains (v .QueryText , fmt .Sprintf ("ANY(%s)" , v .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("SOME(%s)" , v .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("ALL(%s)" , v .DBName )) {
273
+ return true
274
+ }
275
+ }
276
+ return false
277
+ }
278
+
256
279
// A struct used to generate methods and fields on the Queries struct
257
280
type Query struct {
258
281
Cmd string
0 commit comments