Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ type QueryValue struct {

// Column is kept so late in the generation process around to differentiate
// between mysql slices and pg arrays
Column *plugin.Column
Column *plugin.Column
QueryText string
}

func (v QueryValue) EmitStruct() bool {
Expand Down Expand Up @@ -85,6 +86,9 @@ func (v QueryValue) SlicePair() string {

func (v QueryValue) Type() string {
if v.Typ != "" {
if v.isUsedWithArrayComparison() {
return strings.Trim(v.Typ, "[]") // Return single type if used in array comparison.
}
return v.Typ
}
if v.Struct != nil {
Expand Down Expand Up @@ -113,6 +117,9 @@ func (v QueryValue) UniqueFields() []Field {
fields := make([]Field, 0, len(v.Struct.Fields))

for _, field := range v.Struct.Fields {
if v.isUsedWithArrayComparison() {
field.Type = strings.Trim(field.Type, "[]")
}
if _, found := seen[field.Name]; found {
continue
}
Expand All @@ -129,14 +136,14 @@ func (v QueryValue) Params() string {
}
var out []string
if v.Struct == nil {
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() {
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() {
out = append(out, "pq.Array("+escape(v.Name)+")")
} else {
out = append(out, escape(v.Name))
}
} else {
for _, f := range v.Struct.Fields {
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() && !v.isUsedWithArrayComparison() {
out = append(out, "pq.Array("+escape(v.VariableForField(f))+")")
} else {
out = append(out, escape(v.VariableForField(f)))
Expand Down Expand Up @@ -254,6 +261,22 @@ func (v QueryValue) VariableForField(f Field) string {
return v.Name + "." + f.Name
}

// isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query.
func (v QueryValue) isUsedWithArrayComparison() bool {
if v.Struct != nil {
for _, f := range v.Struct.Fields {
if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", f.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", f.DBName)) {
return true
}
}
} else {
if strings.Contains(v.QueryText, fmt.Sprintf("ANY(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("SOME(%s)", v.DBName)) || strings.Contains(v.QueryText, fmt.Sprintf("ALL(%s)", v.DBName)) {
return true
}
}
return false
}

// A struct used to generate methods and fields on the Queries struct
type Query struct {
Cmd string
Expand Down
2 changes: 2 additions & 0 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
Typ: goType(req, options, p.Column),
SQLDriver: sqlpkg,
Column: p.Column,
QueryText: gq.SQL,
}
} else if len(query.Params) >= 1 {
var cols []goColumn
Expand All @@ -257,6 +258,7 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
Struct: s,
SQLDriver: sqlpkg,
EmitPointer: options.EmitParamsStructPointers,
QueryText: gq.SQL,
}

// if query params is 2, and query params limit is 4 AND this is a copyfrom, we still want to emit the query's model
Expand Down
Loading