diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index c4aac84dd6..280e323814 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -14,7 +14,7 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, o if oride.GoType.StructTags == nil { continue } - if override.MatchesColumn(col) { + if override.MatchesColumn(col, req.Settings.Engine) { for k, v := range oride.GoType.StructTags { tags[k] = v } @@ -76,7 +76,8 @@ func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin if oride.GoType.TypeName == "" { continue } - if override.MatchesColumn(col) { + + if override.MatchesColumn(col, req.Settings.Engine) { return oride.GoType.TypeName } } diff --git a/internal/codegen/golang/opts/override.go b/internal/codegen/golang/opts/override.go index 6916c0c7f3..8a0632c135 100644 --- a/internal/codegen/golang/opts/override.go +++ b/internal/codegen/golang/opts/override.go @@ -77,10 +77,37 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool { return true } -func (o *Override) MatchesColumn(col *plugin.Column) bool { - columnType := sdk.DataType(col.Type) +func typesMatches(dbType string, colType *plugin.Identifier, isPostgresql bool) bool { + if dbType == "" { + return false + } + columnType := sdk.DataType(colType) + if dbType == columnType { + return true + } + // For example, in PostgreSQL, built-in types are in the 'pg_catalog' schema. + // colType Identifier might show them as: + // - Schema: "pg_catalog", Name: "json" + // - Or Name: "pg_catalog.json" + // - Or just Name: "json" + // This checks both to match types. + if isPostgresql { + if strings.TrimPrefix(dbType, "pg_catalog.") == columnType { + return true + } + if colType.Schema == "pg_catalog" && colType.Name == dbType { + return true + } + + return strings.TrimPrefix(colType.Name, "pg_catalog.") == dbType + } + + return false +} + +func (o *Override) MatchesColumn(col *plugin.Column, engine string) bool { notNull := col.NotNull || col.IsArray - return o.DBType != "" && o.DBType == columnType && o.Nullable != notNull && o.Unsigned == col.Unsigned + return typesMatches(o.DBType, col.Type, engine == "postgresql") && o.Nullable != notNull && o.Unsigned == col.Unsigned } func (o *Override) parse(req *plugin.GenerateRequest) (err error) { diff --git a/internal/codegen/golang/opts/override_test.go b/internal/codegen/golang/opts/override_test.go index 8405666f36..b6b64836d5 100644 --- a/internal/codegen/golang/opts/override_test.go +++ b/internal/codegen/golang/opts/override_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/sqlc-dev/sqlc/internal/plugin" ) func TestTypeOverrides(t *testing.T) { @@ -115,3 +116,88 @@ func FuzzOverride(f *testing.F) { o.parse(nil) }) } + +func TestOverride_MatchesColumn(t *testing.T) { + t.Parallel() + type testCase struct { + specName string + override Override + Column *plugin.Column + engine string + expected bool + } + + testCases := []*testCase{ + { + specName: "matches with pg_catalog in schema and name", + override: Override{ + DBType: "json", + Nullable: false, + }, + Column: &plugin.Column{ + Name: "data", + Type: &plugin.Identifier{ + Schema: "pg_catalog", + Name: "json", + }, + NotNull: true, + IsArray: false, + }, + engine: "postgresql", + expected: true, + }, + { + specName: "matches only with name", + override: Override{ + DBType: "json", + Nullable: false, + }, + Column: &plugin.Column{ + Name: "data", + Type: &plugin.Identifier{ + Name: "json", + }, + NotNull: true, + IsArray: false, + }, + engine: "postgresql", + expected: true, + }, + { + specName: "matches with pg_catalog in name", + override: Override{ + DBType: "json", + Nullable: false, + }, + Column: &plugin.Column{ + Name: "data", + Type: &plugin.Identifier{ + Name: "pg_catalog.json", + }, + NotNull: true, + IsArray: false, + }, + engine: "postgresql", + expected: true, + }, + } + + for _, test := range testCases { + tt := *test + t.Run(tt.specName, func(t *testing.T) { + result := tt.override.MatchesColumn(tt.Column, tt.engine) + if result != tt.expected { + t.Errorf("mismatch; got %v; want %v", result, tt.expected) + } + if tt.engine == "postgresql" && tt.expected == true { + tt.override.DBType = "pg_catalog." + tt.override.DBType + result = tt.override.MatchesColumn(test.Column, tt.engine) + if !result { + t.Errorf("mismatch; got %v; want %v", result, tt.expected) + } + } + + }) + + } +}