@@ -20,7 +20,8 @@ type QueryValue struct {
2020
2121 // Column is kept so late in the generation process around to differentiate
2222 // between mysql slices and pg arrays
23- Column * plugin.Column
23+ Column * plugin.Column
24+ QueryText string
2425}
2526
2627func (v QueryValue ) EmitStruct () bool {
@@ -85,6 +86,9 @@ func (v QueryValue) SlicePair() string {
8586
8687func (v QueryValue ) Type () string {
8788 if v .Typ != "" {
89+ if v .isUsedWithArrayComparison () {
90+ return strings .Trim (v .Typ , "[]" ) // Return single type if used in array comparison.
91+ }
8892 return v .Typ
8993 }
9094 if v .Struct != nil {
@@ -113,6 +117,9 @@ func (v QueryValue) UniqueFields() []Field {
113117 fields := make ([]Field , 0 , len (v .Struct .Fields ))
114118
115119 for _ , field := range v .Struct .Fields {
120+ if v .isUsedWithArrayComparison () {
121+ field .Type = strings .Trim (field .Type , "[]" )
122+ }
116123 if _ , found := seen [field .Name ]; found {
117124 continue
118125 }
@@ -129,14 +136,14 @@ func (v QueryValue) Params() string {
129136 }
130137 var out []string
131138 if v .Struct == nil {
132- if ! v .Column .IsSqlcSlice && strings .HasPrefix (v .Typ , "[]" ) && v .Typ != "[]byte" && ! v .SQLDriver .IsPGX () {
139+ if ! v .Column .IsSqlcSlice && strings .HasPrefix (v .Typ , "[]" ) && v .Typ != "[]byte" && ! v .SQLDriver .IsPGX () && ! v . isUsedWithArrayComparison () {
133140 out = append (out , "pq.Array(" + escape (v .Name )+ ")" )
134141 } else {
135142 out = append (out , escape (v .Name ))
136143 }
137144 } else {
138145 for _ , f := range v .Struct .Fields {
139- if ! f .HasSqlcSlice () && strings .HasPrefix (f .Type , "[]" ) && f .Type != "[]byte" && ! v .SQLDriver .IsPGX () {
146+ if ! f .HasSqlcSlice () && strings .HasPrefix (f .Type , "[]" ) && f .Type != "[]byte" && ! v .SQLDriver .IsPGX () && ! v . isUsedWithArrayComparison () {
140147 out = append (out , "pq.Array(" + escape (v .VariableForField (f ))+ ")" )
141148 } else {
142149 out = append (out , escape (v .VariableForField (f )))
@@ -254,6 +261,22 @@ func (v QueryValue) VariableForField(f Field) string {
254261 return v .Name + "." + f .Name
255262}
256263
264+ // isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query.
265+ func (v QueryValue ) isUsedWithArrayComparison () bool {
266+ if v .Struct != nil {
267+ for _ , f := range v .Struct .Fields {
268+ if strings .Contains (v .QueryText , fmt .Sprintf ("ANY(%s)" , f .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("SOME(%s)" , f .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("ALL(%s)" , f .DBName )) {
269+ return true
270+ }
271+ }
272+ } else {
273+ if strings .Contains (v .QueryText , fmt .Sprintf ("ANY(%s)" , v .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("SOME(%s)" , v .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("ALL(%s)" , v .DBName )) {
274+ return true
275+ }
276+ }
277+ return false
278+ }
279+
257280// A struct used to generate methods and fields on the Queries struct
258281type Query struct {
259282 Cmd string
0 commit comments