diff --git a/gen/gen.go b/gen/gen.go index 915ed1dbe..94ed209ad 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -183,7 +183,7 @@ func (g *Gen) Build(config *Config) error { config.RightTemplateDelim = "}}" } - var overrides map[string]string + var overrides map[string]swag.Override if config.OverridesFile != "" { overridesFile, err := open(config.OverridesFile) @@ -379,8 +379,8 @@ func (g *Gen) formatSource(src []byte) []byte { } // Read and parse the overrides file. -func parseOverrides(r io.Reader) (map[string]string, error) { - overrides := make(map[string]string) +func parseOverrides(r io.Reader) (map[string]swag.Override, error) { + overrides := make(map[string]swag.Override) scanner := bufio.NewScanner(r) for scanner.Scan() { @@ -393,24 +393,31 @@ func parseOverrides(r io.Reader) (map[string]string, error) { parts := strings.Fields(line) - switch len(parts) { - case 0: - // only whitespace + if len(parts) == 0 { continue - case 2: - // either a skip or malformed - if parts[0] != "skip" { + } + + switch parts[0] { + case "skip": + if len(parts) != 2 { return nil, fmt.Errorf("could not parse override: '%s'", line) } - - overrides[parts[1]] = "" - case 3: - // either a replace or malformed - if parts[0] != "replace" { + overrides[parts[1]] = swag.Override{} + case "replace": + if len(parts) < 3 { return nil, fmt.Errorf("could not parse override: '%s'", line) } - overrides[parts[1]] = parts[2] + attrs := make(map[string]string) + for _, attr := range parts[3:] { + kv := strings.SplitN(attr, ":", 2) + if len(kv) != 2 { + return nil, fmt.Errorf("malformed attribute '%s' in override: '%s'", attr, line) + } + attrs[kv[0]] = kv[1] + } + + overrides[parts[1]] = swag.Override{Type: parts[2], Attrs: attrs} default: return nil, fmt.Errorf("could not parse override: '%s'", line) } diff --git a/gen/gen_test.go b/gen/gen_test.go index 97c0cdef4..1831685a7 100644 --- a/gen/gen_test.go +++ b/gen/gen_test.go @@ -684,43 +684,43 @@ func TestGen_parseOverrides(t *testing.T) { testCases := []struct { Name string Data string - Expected map[string]string + Expected map[string]swag.Override ExpectedError error }{ { Name: "replace", Data: `replace github.com/foo/bar baz`, - Expected: map[string]string{ - "github.com/foo/bar": "baz", + Expected: map[string]swag.Override{ + "github.com/foo/bar": {Type: "baz", Attrs: map[string]string{}}, }, }, { Name: "skip", Data: `skip github.com/foo/bar`, - Expected: map[string]string{ - "github.com/foo/bar": "", + Expected: map[string]swag.Override{ + "github.com/foo/bar": {}, }, }, { Name: "generic-simple", Data: `replace types.Field[string] string`, - Expected: map[string]string{ - "types.Field[string]": "string", + Expected: map[string]swag.Override{ + "types.Field[string]": {Type: "string", Attrs: map[string]string{}}, }, }, { Name: "generic-double", Data: `replace types.Field[string,string] string`, - Expected: map[string]string{ - "types.Field[string,string]": "string", + Expected: map[string]swag.Override{ + "types.Field[string,string]": {Type: "string", Attrs: map[string]string{}}, }, }, { Name: "comment", Data: `// this is a comment replace foo bar`, - Expected: map[string]string{ - "foo": "bar", + Expected: map[string]swag.Override{ + "foo": {Type: "bar", Attrs: map[string]string{}}, }, }, { @@ -728,8 +728,8 @@ func TestGen_parseOverrides(t *testing.T) { Data: ` replace foo bar`, - Expected: map[string]string{ - "foo": "bar", + Expected: map[string]swag.Override{ + "foo": {Type: "bar", Attrs: map[string]string{}}, }, }, { @@ -737,6 +737,32 @@ func TestGen_parseOverrides(t *testing.T) { Data: `foo`, ExpectedError: fmt.Errorf("could not parse override: 'foo'"), }, + { + Name: "replace with attrs", + Data: `replace pkg.Optional[string] string optional:true nullable:true`, + Expected: map[string]swag.Override{ + "pkg.Optional[string]": {Type: "string", Attrs: map[string]string{"optional": "true", "nullable": "true"}}, + }, + }, + { + Name: "replace with format attr", + Data: `replace pkg.Optional[time.Time] string optional:true nullable:true format:date-time`, + Expected: map[string]swag.Override{ + "pkg.Optional[time.Time]": {Type: "string", Attrs: map[string]string{"optional": "true", "nullable": "true", "format": "date-time"}}, + }, + }, + { + Name: "malformed attr", + Data: `replace pkg.Foo string badattr`, + ExpectedError: fmt.Errorf("malformed attribute 'badattr' in override: 'replace pkg.Foo string badattr'"), + }, + { + Name: "placeholder in key", + Data: `replace pkg.Wrapper[$T] $T nullable:true`, + Expected: map[string]swag.Override{ + "pkg.Wrapper[$T]": {Type: "$T", Attrs: map[string]string{"nullable": "true"}}, + }, + }, } for _, tc := range testCases { diff --git a/generics_test.go b/generics_test.go index 74dad3f9a..115f36990 100644 --- a/generics_test.go +++ b/generics_test.go @@ -27,10 +27,10 @@ func TestParseGenericsBasic(t *testing.T) { assert.NoError(t, err) p := New() - p.Overrides = map[string]string{ - "types.Field[string]": "string", - "types.DoubleField[string,string]": "[]string", - "types.TrippleField[string,string]": "[][]string", + p.Overrides = map[string]Override{ + "types.Field[string]": {Type: "string"}, + "types.DoubleField[string,string]": {Type: "[]string"}, + "types.TrippleField[string,string]": {Type: "[][]string"}, } err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) diff --git a/parser.go b/parser.go index 1d19f27bd..3a5585116 100644 --- a/parser.go +++ b/parser.go @@ -169,7 +169,7 @@ type Parser struct { fieldParserFactory FieldParserFactory // Overrides allows global replacements of types. A blank replacement will be skipped. - Overrides map[string]string + Overrides map[string]Override // parseGoList whether swag use go list to parse dependency parseGoList bool @@ -193,6 +193,23 @@ type Parser struct { UseStructName bool } +// Override represents a type override with optional attributes. +type Override struct { + // Type is the replacement type (e.g. "string", "number"). Empty means skip. + Type string + // Attrs holds key:value attributes (e.g. optional:true, nullable:true, format:date-time). + Attrs map[string]string +} + +// IsOptional returns true if the override marks the field as optional. +func (o Override) IsOptional() bool { return o.Attrs["optional"] == "true" } + +// IsNullable returns true if the override marks the field as nullable. +func (o Override) IsNullable() bool { return o.Attrs["nullable"] == "true" } + +// Format returns the format attribute value, if any. +func (o Override) Format() string { return o.Attrs["format"] } + // FieldParserFactory create FieldParser. type FieldParserFactory func(ps *Parser, field *ast.Field) FieldParser @@ -248,7 +265,7 @@ func New(options ...func(*Parser)) *Parser { excludes: make(map[string]struct{}), tags: make(map[string]struct{}), fieldParserFactory: newTagBaseFieldParser, - Overrides: make(map[string]string), + Overrides: make(map[string]Override), } for _, option := range options { @@ -361,7 +378,7 @@ func SetFieldParserFactory(factory FieldParserFactory) func(parser *Parser) { } // SetOverrides allows the use of user-defined global type overrides. -func SetOverrides(overrides map[string]string) func(parser *Parser) { +func SetOverrides(overrides map[string]Override) func(parser *Parser) { return func(p *Parser) { for k, v := range overrides { p.Overrides[k] = v @@ -369,6 +386,49 @@ func SetOverrides(overrides map[string]string) func(parser *Parser) { } } +// matchOverride tries exact match first, then placeholder patterns containing $T. +func (parser *Parser) matchOverride(typeName string) *Override { + // 1. Exact match (highest priority) + if o, ok := parser.Overrides[typeName]; ok { + return &o + } + + // 2. Placeholder match — find keys containing $T + for key, o := range parser.Overrides { + if !strings.Contains(key, "$T") { + continue + } + parts := strings.SplitN(key, "$T", 2) + prefix, suffix := parts[0], parts[1] + if strings.HasPrefix(typeName, prefix) && strings.HasSuffix(typeName, suffix) { + captured := typeName[len(prefix) : len(typeName)-len(suffix)] + if captured == "" { + continue + } + resolved := Override{ + Type: strings.ReplaceAll(o.Type, "$T", captured), + Attrs: o.Attrs, + } + return &resolved + } + } + + return nil +} + +// getOverrideForType resolves an override for a type by trying the short name first, then the full path. +func (parser *Parser) getOverrideForType(typeName string, file *ast.File) *Override { + if o := parser.matchOverride(typeName); o != nil { + return o + } + if typeSpecDef := parser.packages.FindTypeSpec(typeName, file); typeSpecDef != nil { + if o := parser.matchOverride(typeSpecDef.FullPath()); o != nil { + return o + } + } + return nil +} + // SetCollectionFormat set default collection format func SetCollectionFormat(collectionFormat string) func(*Parser) { return func(p *Parser) { @@ -1245,10 +1305,29 @@ func convertFromSpecificToPrimitive(typeName string) (string, error) { return typeName, ErrFailedConvertPrimitiveType } +// applyOverrideAttrs applies nullable and format attributes from an override to a schema. +func applyOverrideAttrs(schema *spec.Schema, override *Override) { + if override.IsNullable() { + schema.AddExtension("x-nullable", true) + } + if f := override.Format(); f != "" { + schema.Format = f + } +} + func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (*spec.Schema, error) { - if override, ok := parser.Overrides[typeName]; ok { - parser.debug.Printf("Override detected for %s: using %s instead", typeName, override) - return parseObjectSchema(parser, override, file) + if override := parser.matchOverride(typeName); override != nil { + if override.Type == "" { + parser.debug.Printf("Override detected for %s: ignoring", typeName) + return nil, ErrSkippedField + } + parser.debug.Printf("Override detected for %s: using %s instead", typeName, override.Type) + schema, err := parseObjectSchema(parser, override.Type, file) + if err != nil { + return nil, err + } + applyOverrideAttrs(schema, override) + return schema, nil } if IsInterfaceLike(typeName) { @@ -1268,24 +1347,27 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( return nil, fmt.Errorf("cannot find type definition: %s", typeName) } - if override, ok := parser.Overrides[typeSpecDef.FullPath()]; ok { - if override == "" { + if override := parser.matchOverride(typeSpecDef.FullPath()); override != nil { + if override.Type == "" { parser.debug.Printf("Override detected for %s: ignoring", typeSpecDef.FullPath()) - return nil, ErrSkippedField } - parser.debug.Printf("Override detected for %s: using %s instead", typeSpecDef.FullPath(), override) + parser.debug.Printf("Override detected for %s: using %s instead", typeSpecDef.FullPath(), override.Type) - separator := strings.LastIndex(override, ".") + separator := strings.LastIndex(override.Type, ".") if separator == -1 { // treat as a swaggertype tag - parts := strings.Split(override, ",") - - return BuildCustomSchema(parts) + parts := strings.Split(override.Type, ",") + schema, err := BuildCustomSchema(parts) + if err != nil { + return nil, err + } + applyOverrideAttrs(schema, override) + return schema, nil } - typeSpecDef = parser.packages.findTypeSpec(override[0:separator], override[separator+1:]) + typeSpecDef = parser.packages.findTypeSpec(override.Type[0:separator], override.Type[separator+1:]) } parser.packages.CheckTypeSpec(typeSpecDef) @@ -1698,6 +1780,13 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st tagRequired = append(tagRequired, fieldNames...) } + // Check if override marks this field as optional + if typeName, err := getFieldType(file, field.Type, nil); err == nil { + if override := parser.getOverrideForType(typeName, file); override != nil && override.IsOptional() { + tagRequired = nil + } + } + if formName := ps.FormName(); len(formName) > 0 { schema.AddExtension("formData", formName) } diff --git a/parser_test.go b/parser_test.go index 3aa540bb6..776dd976d 100644 --- a/parser_test.go +++ b/parser_test.go @@ -69,8 +69,8 @@ func TestNew(t *testing.T) { func TestSetOverrides(t *testing.T) { t.Parallel() - overrides := map[string]string{ - "foo": "bar", + overrides := map[string]Override{ + "foo": {Type: "bar"}, } p := New(SetOverrides(overrides)) @@ -80,8 +80,8 @@ func TestSetOverrides(t *testing.T) { func TestOverrides_getTypeSchema(t *testing.T) { t.Parallel() - overrides := map[string]string{ - "sql.NullString": "string", + overrides := map[string]Override{ + "sql.NullString": {Type: "string"}, } p := New(SetOverrides(overrides)) @@ -105,6 +105,130 @@ func TestOverrides_getTypeSchema(t *testing.T) { }) } +func TestOverrides_nullable(t *testing.T) { + t.Parallel() + + overrides := map[string]Override{ + "sql.NullString": {Type: "string", Attrs: map[string]string{"nullable": "true"}}, + } + + p := New(SetOverrides(overrides)) + + s, err := p.getTypeSchema("sql.NullString", nil, false) + if assert.NoError(t, err) { + assert.Truef(t, s.Type.Contains("string"), "type sql.NullString should be overridden by string") + assert.Equal(t, true, s.Extensions["x-nullable"]) + } +} + +func TestOverrides_format(t *testing.T) { + t.Parallel() + + overrides := map[string]Override{ + "sql.NullTime": {Type: "string", Attrs: map[string]string{"format": "date-time"}}, + } + + p := New(SetOverrides(overrides)) + + s, err := p.getTypeSchema("sql.NullTime", nil, false) + if assert.NoError(t, err) { + assert.Truef(t, s.Type.Contains("string"), "type sql.NullTime should be overridden by string") + assert.Equal(t, "date-time", s.Format) + } +} + +func TestOverrides_placeholder(t *testing.T) { + t.Parallel() + + overrides := map[string]Override{ + "pkg.Wrapper[$T]": {Type: "$T", Attrs: map[string]string{"nullable": "true"}}, + } + + p := New(SetOverrides(overrides)) + + t.Run("Placeholder resolves string", func(t *testing.T) { + t.Parallel() + + s, err := p.getTypeSchema("pkg.Wrapper[string]", nil, false) + if assert.NoError(t, err) { + assert.Truef(t, s.Type.Contains("string"), "pkg.Wrapper[string] should resolve to string") + assert.Equal(t, true, s.Extensions["x-nullable"]) + } + }) + + t.Run("Placeholder resolves number", func(t *testing.T) { + t.Parallel() + + s, err := p.getTypeSchema("pkg.Wrapper[number]", nil, false) + if assert.NoError(t, err) { + assert.Truef(t, s.Type.Contains("number"), "pkg.Wrapper[number] should resolve to number") + assert.Equal(t, true, s.Extensions["x-nullable"]) + } + }) + + t.Run("No match for different type", func(t *testing.T) { + t.Parallel() + + _, err := p.getTypeSchema("other.Type", nil, false) + assert.Error(t, err) + }) +} + +func TestOverrides_matchOverride(t *testing.T) { + t.Parallel() + + t.Run("Exact match takes priority over placeholder", func(t *testing.T) { + t.Parallel() + + p := New(SetOverrides(map[string]Override{ + "pkg.Wrapper[string]": {Type: "string"}, + "pkg.Wrapper[$T]": {Type: "$T", Attrs: map[string]string{"nullable": "true"}}, + })) + + o := p.matchOverride("pkg.Wrapper[string]") + if assert.NotNil(t, o) { + assert.Equal(t, "string", o.Type) + assert.False(t, o.IsNullable(), "exact match should not have nullable attr") + } + }) + + t.Run("Placeholder captures type parameter", func(t *testing.T) { + t.Parallel() + + p := New(SetOverrides(map[string]Override{ + "pkg.Wrapper[$T]": {Type: "$T", Attrs: map[string]string{"nullable": "true"}}, + })) + + o := p.matchOverride("pkg.Wrapper[custom.Type]") + if assert.NotNil(t, o) { + assert.Equal(t, "custom.Type", o.Type) + assert.True(t, o.IsNullable()) + } + }) + + t.Run("No match returns nil", func(t *testing.T) { + t.Parallel() + + p := New(SetOverrides(map[string]Override{ + "pkg.Wrapper[$T]": {Type: "$T"}, + })) + + o := p.matchOverride("other.Type") + assert.Nil(t, o) + }) + + t.Run("Empty capture is skipped", func(t *testing.T) { + t.Parallel() + + p := New(SetOverrides(map[string]Override{ + "pkg.Wrapper[$T]": {Type: "$T"}, + })) + + o := p.matchOverride("pkg.Wrapper[]") + assert.Nil(t, o) + }) +} + func TestParser_ParseDefinition(t *testing.T) { p := New() @@ -2183,10 +2307,10 @@ func TestParseTypeOverrides(t *testing.T) { t.Parallel() searchDir := "testdata/global_override" - p := New(SetOverrides(map[string]string{ - "github.com/swaggo/swag/testdata/global_override/types.Application": "string", - "github.com/swaggo/swag/testdata/global_override/types.Application2": "github.com/swaggo/swag/testdata/global_override/othertypes.Application", - "github.com/swaggo/swag/testdata/global_override/types.ShouldSkip": "", + p := New(SetOverrides(map[string]Override{ + "github.com/swaggo/swag/testdata/global_override/types.Application": {Type: "string"}, + "github.com/swaggo/swag/testdata/global_override/types.Application2": {Type: "github.com/swaggo/swag/testdata/global_override/othertypes.Application"}, + "github.com/swaggo/swag/testdata/global_override/types.ShouldSkip": {}, })) err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) assert.NoError(t, err)