diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 3b4fb2fa1a..93c9fa3a9e 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -20,7 +20,8 @@ type QueryValue struct { // Column is kept so late in the generation process around to differentiate // between mysql slices and pg arrays - Column *plugin.Column + Column *plugin.Column + QueryText string } func (v QueryValue) EmitStruct() bool { @@ -85,6 +86,9 @@ func (v QueryValue) SlicePair() string { func (v QueryValue) Type() string { if v.Typ != "" { + if v.isUsedWithArrayComparison() { + return strings.Trim(v.Typ, "[]") // Return single type if used in array comparison. + } return v.Typ } if v.Struct != nil { @@ -113,6 +117,9 @@ func (v QueryValue) UniqueFields() []Field { fields := make([]Field, 0, len(v.Struct.Fields)) for _, field := range v.Struct.Fields { + if v.isUsedWithArrayComparison() { + field.Type = strings.Trim(field.Type, "[]") + } if _, found := seen[field.Name]; found { continue } @@ -129,14 +136,14 @@ func (v QueryValue) Params() string { } var out []string if v.Struct == nil { - if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { + if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() { out = append(out, "pq.Array("+escape(v.Name)+")") } else { out = append(out, escape(v.Name)) } } else { for _, f := range v.Struct.Fields { - if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() { + if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() { out = append(out, "pq.Array("+escape(v.VariableForField(f))+")") } else { out = append(out, escape(v.VariableForField(f))) @@ -254,6 +261,22 @@ func (v QueryValue) VariableForField(f Field) string { return v.Name + "." + f.Name } +// isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query. +func (v QueryValue) isUsedWithArrayComparison() bool { + if v.Struct != nil { + for _, f := range v.Struct.Fields { + if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", f.DBName)) { + return true + } + } + } else { + if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", v.DBName)) { + return true + } + } + return false +} + // A struct used to generate methods and fields on the Queries struct type Query struct { Cmd string diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 515d0a654f..9ae5bc4da6 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -238,6 +238,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] Typ: goType(req, options, p.Column), SQLDriver: sqlpkg, Column: p.Column, + QueryText: gq.SQL, } } else if len(query.Params) >= 1 { var cols []goColumn @@ -257,6 +258,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] Struct: s, SQLDriver: sqlpkg, EmitPointer: options.EmitParamsStructPointers, + QueryText: gq.SQL, } // if query params is 2, and query params limit is 4 AND this is a copyfrom, we still want to emit the query's model