Skip to content

Commit 2f14580

Browse files
committed
fix(codegen/golang) type inference in array comparisons
1 parent a4fa8b9 commit 2f14580

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

internal/codegen/golang/query.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ type QueryValue struct {
1919

2020
// Column is kept so late in the generation process around to differentiate
2121
// between mysql slices and pg arrays
22-
Column *plugin.Column
22+
Column *plugin.Column
23+
QueryText string
2324
}
2425

2526
func (v QueryValue) EmitStruct() bool {
@@ -84,6 +85,9 @@ func (v QueryValue) SlicePair() string {
8485

8586
func (v QueryValue) Type() string {
8687
if v.Typ != "" {
88+
if v.isUsedWithArrayComparison() {
89+
return strings.Trim(v.Typ, "[]") // Return single type if used in array comparison.
90+
}
8791
return v.Typ
8892
}
8993
if v.Struct != nil {
@@ -112,6 +116,9 @@ func (v QueryValue) UniqueFields() []Field {
112116
fields := make([]Field, 0, len(v.Struct.Fields))
113117

114118
for _, field := range v.Struct.Fields {
119+
if v.isUsedWithArrayComparison() {
120+
field.Type = strings.Trim(field.Type, "[]")
121+
}
115122
if _, found := seen[field.Name]; found {
116123
continue
117124
}
@@ -128,14 +135,14 @@ func (v QueryValue) Params() string {
128135
}
129136
var out []string
130137
if v.Struct == nil {
131-
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() {
138+
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() {
132139
out = append(out, "pq.Array("+escape(v.Name)+")")
133140
} else {
134141
out = append(out, escape(v.Name))
135142
}
136143
} else {
137144
for _, f := range v.Struct.Fields {
138-
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
145+
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() {
139146
out = append(out, "pq.Array("+escape(v.VariableForField(f))+")")
140147
} else {
141148
out = append(out, escape(v.VariableForField(f)))
@@ -253,6 +260,22 @@ func (v QueryValue) VariableForField(f Field) string {
253260
return v.Name + "." + f.Name
254261
}
255262

263+
// isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query.
264+
func (v QueryValue) isUsedWithArrayComparison() bool {
265+
if v.Struct != nil {
266+
for _, f := range v.Struct.Fields {
267+
if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", f.DBName)) {
268+
return true
269+
}
270+
}
271+
} else {
272+
if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", v.DBName)) {
273+
return true
274+
}
275+
}
276+
return false
277+
}
278+
256279
// A struct used to generate methods and fields on the Queries struct
257280
type Query struct {
258281
Cmd string

internal/codegen/golang/result.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
221221
Typ: goType(req, options, p.Column),
222222
SQLDriver: sqlpkg,
223223
Column: p.Column,
224+
QueryText: gq.SQL,
224225
}
225226
} else if len(query.Params) >= 1 {
226227
var cols []goColumn
@@ -240,6 +241,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
240241
Struct: s,
241242
SQLDriver: sqlpkg,
242243
EmitPointer: options.EmitParamsStructPointers,
244+
QueryText: gq.SQL,
243245
}
244246

245247
if len(query.Params) <= qpl {

0 commit comments

Comments
 (0)