Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, o
if oride.GoType.StructTags == nil {
continue
}
if override.MatchesColumn(col) {
if override.MatchesColumn(col, req.Settings.Engine) {
for k, v := range oride.GoType.StructTags {
tags[k] = v
}
Expand Down Expand Up @@ -76,7 +76,8 @@ func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin
if oride.GoType.TypeName == "" {
continue
}
if override.MatchesColumn(col) {

if override.MatchesColumn(col, req.Settings.Engine) {
return oride.GoType.TypeName
}
}
Expand Down
33 changes: 30 additions & 3 deletions internal/codegen/golang/opts/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,37 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool {
return true
}

func (o *Override) MatchesColumn(col *plugin.Column) bool {
columnType := sdk.DataType(col.Type)
func typesMatches(dbType string, colType *plugin.Identifier, isPostgresql bool) bool {
if dbType == "" {
return false
}
columnType := sdk.DataType(colType)
if dbType == columnType {
return true
}
// For example, in PostgreSQL, built-in types are in the 'pg_catalog' schema.
// colType Identifier might show them as:
// - Schema: "pg_catalog", Name: "json"
// - Or Name: "pg_catalog.json"
// - Or just Name: "json"
// This checks both to match types.
if isPostgresql {
if strings.TrimPrefix(dbType, "pg_catalog.") == columnType {
return true
}
if colType.Schema == "pg_catalog" && colType.Name == dbType {
return true
}

return strings.TrimPrefix(colType.Name, "pg_catalog.") == dbType
}

return false
}

func (o *Override) MatchesColumn(col *plugin.Column, engine string) bool {
notNull := col.NotNull || col.IsArray
return o.DBType != "" && o.DBType == columnType && o.Nullable != notNull && o.Unsigned == col.Unsigned
return typesMatches(o.DBType, col.Type, engine == "postgresql") && o.Nullable != notNull && o.Unsigned == col.Unsigned
}

func (o *Override) parse(req *plugin.GenerateRequest) (err error) {
Expand Down
86 changes: 86 additions & 0 deletions internal/codegen/golang/opts/override_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func TestTypeOverrides(t *testing.T) {
Expand Down Expand Up @@ -115,3 +116,88 @@ func FuzzOverride(f *testing.F) {
o.parse(nil)
})
}

func TestOverride_MatchesColumn(t *testing.T) {
t.Parallel()
type testCase struct {
specName string
override Override
Column *plugin.Column
engine string
expected bool
}

testCases := []*testCase{
{
specName: "matches with pg_catalog in schema and name",
override: Override{
DBType: "json",
Nullable: false,
},
Column: &plugin.Column{
Name: "data",
Type: &plugin.Identifier{
Schema: "pg_catalog",
Name: "json",
},
NotNull: true,
IsArray: false,
},
engine: "postgresql",
expected: true,
},
{
specName: "matches only with name",
override: Override{
DBType: "json",
Nullable: false,
},
Column: &plugin.Column{
Name: "data",
Type: &plugin.Identifier{
Name: "json",
},
NotNull: true,
IsArray: false,
},
engine: "postgresql",
expected: true,
},
{
specName: "matches with pg_catalog in name",
override: Override{
DBType: "json",
Nullable: false,
},
Column: &plugin.Column{
Name: "data",
Type: &plugin.Identifier{
Name: "pg_catalog.json",
},
NotNull: true,
IsArray: false,
},
engine: "postgresql",
expected: true,
},
}

for _, test := range testCases {
tt := *test
t.Run(tt.specName, func(t *testing.T) {
result := tt.override.MatchesColumn(tt.Column, tt.engine)
if result != tt.expected {
t.Errorf("mismatch; got %v; want %v", result, tt.expected)
}
if tt.engine == "postgresql" && tt.expected == true {
tt.override.DBType = "pg_catalog." + tt.override.DBType
result = tt.override.MatchesColumn(test.Column, tt.engine)
if !result {
t.Errorf("mismatch; got %v; want %v", result, tt.expected)
}
}

})

}
}
Loading