Skip to content

Commit 5140cc5

Browse files
cmkqwertykyleconroy
authored andcommitted
fix(codegen/golang) type inference in array comparisons
1 parent 0b952b4 commit 5140cc5

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

internal/codegen/golang/query.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ type QueryValue struct {
2020

2121
// Column is kept so late in the generation process around to differentiate
2222
// between mysql slices and pg arrays
23-
Column *plugin.Column
23+
Column *plugin.Column
24+
QueryText string
2425
}
2526

2627
func (v QueryValue) EmitStruct() bool {
@@ -85,6 +86,9 @@ func (v QueryValue) SlicePair() string {
8586

8687
func (v QueryValue) Type() string {
8788
if v.Typ != "" {
89+
if v.isUsedWithArrayComparison() {
90+
return strings.Trim(v.Typ, "[]") // Return single type if used in array comparison.
91+
}
8892
return v.Typ
8993
}
9094
if v.Struct != nil {
@@ -113,6 +117,9 @@ func (v QueryValue) UniqueFields() []Field {
113117
fields := make([]Field, 0, len(v.Struct.Fields))
114118

115119
for _, field := range v.Struct.Fields {
120+
if v.isUsedWithArrayComparison() {
121+
field.Type = strings.Trim(field.Type, "[]")
122+
}
116123
if _, found := seen[field.Name]; found {
117124
continue
118125
}
@@ -129,14 +136,14 @@ func (v QueryValue) Params() string {
129136
}
130137
var out []string
131138
if v.Struct == nil {
132-
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() {
139+
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() {
133140
out = append(out, "pq.Array("+escape(v.Name)+")")
134141
} else {
135142
out = append(out, escape(v.Name))
136143
}
137144
} else {
138145
for _, f := range v.Struct.Fields {
139-
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
146+
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() {
140147
out = append(out, "pq.Array("+escape(v.VariableForField(f))+")")
141148
} else {
142149
out = append(out, escape(v.VariableForField(f)))
@@ -254,6 +261,22 @@ func (v QueryValue) VariableForField(f Field) string {
254261
return v.Name + "." + f.Name
255262
}
256263

264+
// isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query.
265+
func (v QueryValue) isUsedWithArrayComparison() bool {
266+
if v.Struct != nil {
267+
for _, f := range v.Struct.Fields {
268+
if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", f.DBName)) {
269+
return true
270+
}
271+
}
272+
} else {
273+
if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", v.DBName)) {
274+
return true
275+
}
276+
}
277+
return false
278+
}
279+
257280
// A struct used to generate methods and fields on the Queries struct
258281
type Query struct {
259282
Cmd string

internal/codegen/golang/result.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
238238
Typ: goType(req, options, p.Column),
239239
SQLDriver: sqlpkg,
240240
Column: p.Column,
241+
QueryText: gq.SQL,
241242
}
242243
} else if len(query.Params) >= 1 {
243244
var cols []goColumn
@@ -257,6 +258,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
257258
Struct: s,
258259
SQLDriver: sqlpkg,
259260
EmitPointer: options.EmitParamsStructPointers,
261+
QueryText: gq.SQL,
260262
}
261263

262264
// if query params is 2, and query params limit is 4 AND this is a copyfrom, we still want to emit the query's model

0 commit comments

Comments
 (0)