Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (g *Gen) Build(config *Config) error {
config.RightTemplateDelim = "}}"
}

var overrides map[string]string
var overrides map[string]swag.Override

if config.OverridesFile != "" {
overridesFile, err := open(config.OverridesFile)
Expand Down Expand Up @@ -379,8 +379,8 @@ func (g *Gen) formatSource(src []byte) []byte {
}

// Read and parse the overrides file.
func parseOverrides(r io.Reader) (map[string]string, error) {
overrides := make(map[string]string)
func parseOverrides(r io.Reader) (map[string]swag.Override, error) {
overrides := make(map[string]swag.Override)
scanner := bufio.NewScanner(r)

for scanner.Scan() {
Expand All @@ -393,24 +393,31 @@ func parseOverrides(r io.Reader) (map[string]string, error) {

parts := strings.Fields(line)

switch len(parts) {
case 0:
// only whitespace
if len(parts) == 0 {
continue
case 2:
// either a skip or malformed
if parts[0] != "skip" {
}

switch parts[0] {
case "skip":
if len(parts) != 2 {
return nil, fmt.Errorf("could not parse override: '%s'", line)
}

overrides[parts[1]] = ""
case 3:
// either a replace or malformed
if parts[0] != "replace" {
overrides[parts[1]] = swag.Override{}
case "replace":
if len(parts) < 3 {
return nil, fmt.Errorf("could not parse override: '%s'", line)
}

overrides[parts[1]] = parts[2]
attrs := make(map[string]string)
for _, attr := range parts[3:] {
kv := strings.SplitN(attr, ":", 2)
if len(kv) != 2 {
return nil, fmt.Errorf("malformed attribute '%s' in override: '%s'", attr, line)
}
attrs[kv[0]] = kv[1]
}

overrides[parts[1]] = swag.Override{Type: parts[2], Attrs: attrs}
default:
return nil, fmt.Errorf("could not parse override: '%s'", line)
}
Expand Down
52 changes: 39 additions & 13 deletions gen/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,59 +684,85 @@ func TestGen_parseOverrides(t *testing.T) {
testCases := []struct {
Name string
Data string
Expected map[string]string
Expected map[string]swag.Override
ExpectedError error
}{
{
Name: "replace",
Data: `replace github.com/foo/bar baz`,
Expected: map[string]string{
"github.com/foo/bar": "baz",
Expected: map[string]swag.Override{
"github.com/foo/bar": {Type: "baz", Attrs: map[string]string{}},
},
},
{
Name: "skip",
Data: `skip github.com/foo/bar`,
Expected: map[string]string{
"github.com/foo/bar": "",
Expected: map[string]swag.Override{
"github.com/foo/bar": {},
},
},
{
Name: "generic-simple",
Data: `replace types.Field[string] string`,
Expected: map[string]string{
"types.Field[string]": "string",
Expected: map[string]swag.Override{
"types.Field[string]": {Type: "string", Attrs: map[string]string{}},
},
},
{
Name: "generic-double",
Data: `replace types.Field[string,string] string`,
Expected: map[string]string{
"types.Field[string,string]": "string",
Expected: map[string]swag.Override{
"types.Field[string,string]": {Type: "string", Attrs: map[string]string{}},
},
},
{
Name: "comment",
Data: `// this is a comment
replace foo bar`,
Expected: map[string]string{
"foo": "bar",
Expected: map[string]swag.Override{
"foo": {Type: "bar", Attrs: map[string]string{}},
},
},
{
Name: "ignore whitespace",
Data: `

replace foo bar`,
Expected: map[string]string{
"foo": "bar",
Expected: map[string]swag.Override{
"foo": {Type: "bar", Attrs: map[string]string{}},
},
},
{
Name: "unknown directive",
Data: `foo`,
ExpectedError: fmt.Errorf("could not parse override: 'foo'"),
},
{
Name: "replace with attrs",
Data: `replace pkg.Optional[string] string optional:true nullable:true`,
Expected: map[string]swag.Override{
"pkg.Optional[string]": {Type: "string", Attrs: map[string]string{"optional": "true", "nullable": "true"}},
},
},
{
Name: "replace with format attr",
Data: `replace pkg.Optional[time.Time] string optional:true nullable:true format:date-time`,
Expected: map[string]swag.Override{
"pkg.Optional[time.Time]": {Type: "string", Attrs: map[string]string{"optional": "true", "nullable": "true", "format": "date-time"}},
},
},
{
Name: "malformed attr",
Data: `replace pkg.Foo string badattr`,
ExpectedError: fmt.Errorf("malformed attribute 'badattr' in override: 'replace pkg.Foo string badattr'"),
},
{
Name: "placeholder in key",
Data: `replace pkg.Wrapper[$T] $T nullable:true`,
Expected: map[string]swag.Override{
"pkg.Wrapper[$T]": {Type: "$T", Attrs: map[string]string{"nullable": "true"}},
},
},
}

for _, tc := range testCases {
Expand Down
8 changes: 4 additions & 4 deletions generics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ func TestParseGenericsBasic(t *testing.T) {
assert.NoError(t, err)

p := New()
p.Overrides = map[string]string{
"types.Field[string]": "string",
"types.DoubleField[string,string]": "[]string",
"types.TrippleField[string,string]": "[][]string",
p.Overrides = map[string]Override{
"types.Field[string]": {Type: "string"},
"types.DoubleField[string,string]": {Type: "[]string"},
"types.TrippleField[string,string]": {Type: "[][]string"},
}

err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth)
Expand Down
119 changes: 104 additions & 15 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ type Parser struct {
fieldParserFactory FieldParserFactory

// Overrides allows global replacements of types. A blank replacement will be skipped.
Overrides map[string]string
Overrides map[string]Override

// parseGoList whether swag use go list to parse dependency
parseGoList bool
Expand All @@ -193,6 +193,23 @@ type Parser struct {
UseStructName bool
}

// Override represents a type override with optional attributes.
type Override struct {
// Type is the replacement type (e.g. "string", "number"). Empty means skip.
Type string
// Attrs holds key:value attributes (e.g. optional:true, nullable:true, format:date-time).
Attrs map[string]string
}

// IsOptional returns true if the override marks the field as optional.
func (o Override) IsOptional() bool { return o.Attrs["optional"] == "true" }

// IsNullable returns true if the override marks the field as nullable.
func (o Override) IsNullable() bool { return o.Attrs["nullable"] == "true" }

// Format returns the format attribute value, if any.
func (o Override) Format() string { return o.Attrs["format"] }

// FieldParserFactory create FieldParser.
type FieldParserFactory func(ps *Parser, field *ast.Field) FieldParser

Expand Down Expand Up @@ -248,7 +265,7 @@ func New(options ...func(*Parser)) *Parser {
excludes: make(map[string]struct{}),
tags: make(map[string]struct{}),
fieldParserFactory: newTagBaseFieldParser,
Overrides: make(map[string]string),
Overrides: make(map[string]Override),
}

for _, option := range options {
Expand Down Expand Up @@ -361,14 +378,57 @@ func SetFieldParserFactory(factory FieldParserFactory) func(parser *Parser) {
}

// SetOverrides allows the use of user-defined global type overrides.
func SetOverrides(overrides map[string]string) func(parser *Parser) {
func SetOverrides(overrides map[string]Override) func(parser *Parser) {
return func(p *Parser) {
for k, v := range overrides {
p.Overrides[k] = v
}
}
}

// matchOverride tries exact match first, then placeholder patterns containing $T.
func (parser *Parser) matchOverride(typeName string) *Override {
// 1. Exact match (highest priority)
if o, ok := parser.Overrides[typeName]; ok {
return &o
}

// 2. Placeholder match — find keys containing $T
for key, o := range parser.Overrides {
if !strings.Contains(key, "$T") {
continue
}
parts := strings.SplitN(key, "$T", 2)
prefix, suffix := parts[0], parts[1]
if strings.HasPrefix(typeName, prefix) && strings.HasSuffix(typeName, suffix) {
captured := typeName[len(prefix) : len(typeName)-len(suffix)]
if captured == "" {
continue
}
resolved := Override{
Type: strings.ReplaceAll(o.Type, "$T", captured),
Attrs: o.Attrs,
}
return &resolved
}
}

return nil
}

// getOverrideForType resolves an override for a type by trying the short name first, then the full path.
func (parser *Parser) getOverrideForType(typeName string, file *ast.File) *Override {
if o := parser.matchOverride(typeName); o != nil {
return o
}
if typeSpecDef := parser.packages.FindTypeSpec(typeName, file); typeSpecDef != nil {
if o := parser.matchOverride(typeSpecDef.FullPath()); o != nil {
return o
}
}
return nil
}

// SetCollectionFormat set default collection format
func SetCollectionFormat(collectionFormat string) func(*Parser) {
return func(p *Parser) {
Expand Down Expand Up @@ -1245,10 +1305,29 @@ func convertFromSpecificToPrimitive(typeName string) (string, error) {
return typeName, ErrFailedConvertPrimitiveType
}

// applyOverrideAttrs applies nullable and format attributes from an override to a schema.
func applyOverrideAttrs(schema *spec.Schema, override *Override) {
if override.IsNullable() {
schema.AddExtension("x-nullable", true)
}
if f := override.Format(); f != "" {
schema.Format = f
}
}

func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (*spec.Schema, error) {
if override, ok := parser.Overrides[typeName]; ok {
parser.debug.Printf("Override detected for %s: using %s instead", typeName, override)
return parseObjectSchema(parser, override, file)
if override := parser.matchOverride(typeName); override != nil {
if override.Type == "" {
parser.debug.Printf("Override detected for %s: ignoring", typeName)
return nil, ErrSkippedField
}
parser.debug.Printf("Override detected for %s: using %s instead", typeName, override.Type)
schema, err := parseObjectSchema(parser, override.Type, file)
if err != nil {
return nil, err
}
applyOverrideAttrs(schema, override)
return schema, nil
}

if IsInterfaceLike(typeName) {
Expand All @@ -1268,24 +1347,27 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (
return nil, fmt.Errorf("cannot find type definition: %s", typeName)
}

if override, ok := parser.Overrides[typeSpecDef.FullPath()]; ok {
if override == "" {
if override := parser.matchOverride(typeSpecDef.FullPath()); override != nil {
if override.Type == "" {
parser.debug.Printf("Override detected for %s: ignoring", typeSpecDef.FullPath())

return nil, ErrSkippedField
}

parser.debug.Printf("Override detected for %s: using %s instead", typeSpecDef.FullPath(), override)
parser.debug.Printf("Override detected for %s: using %s instead", typeSpecDef.FullPath(), override.Type)

separator := strings.LastIndex(override, ".")
separator := strings.LastIndex(override.Type, ".")
if separator == -1 {
// treat as a swaggertype tag
parts := strings.Split(override, ",")

return BuildCustomSchema(parts)
parts := strings.Split(override.Type, ",")
schema, err := BuildCustomSchema(parts)
if err != nil {
return nil, err
}
applyOverrideAttrs(schema, override)
return schema, nil
}

typeSpecDef = parser.packages.findTypeSpec(override[0:separator], override[separator+1:])
typeSpecDef = parser.packages.findTypeSpec(override.Type[0:separator], override.Type[separator+1:])
}

parser.packages.CheckTypeSpec(typeSpecDef)
Expand Down Expand Up @@ -1698,6 +1780,13 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st
tagRequired = append(tagRequired, fieldNames...)
}

// Check if override marks this field as optional
if typeName, err := getFieldType(file, field.Type, nil); err == nil {
if override := parser.getOverrideForType(typeName, file); override != nil && override.IsOptional() {
tagRequired = nil
}
}

if formName := ps.FormName(); len(formName) > 0 {
schema.AddExtension("formData", formName)
}
Expand Down
Loading