Skip to content

Commit 9a6447a

Browse files
authored
feat: goctl model Add a new method hasField (#5484)
1 parent 004995f commit 9a6447a

File tree

8 files changed

+193
-11
lines changed

8 files changed

+193
-11
lines changed

tools/goctl/model/sql/gen/findone.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ func genFindOne(table Table, withCache, postgreSql bool) (string, string, error)
1616

1717
output, err := util.With("findOne").
1818
Parse(text).
19+
AddFunc("hasField", hasField(table)).
1920
Execute(map[string]any{
2021
"withCache": withCache,
2122
"upperStartCamelObject": camel,

tools/goctl/model/sql/gen/findonebyfield.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e
2222
return nil, err
2323
}
2424

25-
t := util.With("findOneByField").Parse(text)
25+
t := util.With("findOneByField").Parse(text).AddFunc("hasField", hasField(table))
2626
var list []string
2727
camelTableName := table.Name.ToCamel()
2828
for _, key := range table.UniqueCacheKey {
@@ -54,7 +54,7 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e
5454
return nil, err
5555
}
5656

57-
t = util.With("findOneByFieldMethod").Parse(text)
57+
t = util.With("findOneByFieldMethod").Parse(text).AddFunc("hasField", hasField(table))
5858
var listMethod []string
5959
for _, key := range table.UniqueCacheKey {
6060
var inJoin, paramJoin Join
@@ -88,7 +88,7 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e
8888
return nil, err
8989
}
9090

91-
out, err := util.With("findOneByFieldExtraMethod").Parse(text).Execute(map[string]any{
91+
out, err := util.With("findOneByFieldExtraMethod").AddFunc("hasField", hasField(table)).Parse(text).Execute(map[string]any{
9292
"upperStartCamelObject": camelTableName,
9393
"primaryKeyLeft": table.PrimaryCacheKey.VarLeft,
9494
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),

tools/goctl/model/sql/gen/gen.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ func (g *defaultGenerator) genModelCustom(in parser.Table, withCache bool) (stri
360360

361361
t := util.With("model-custom").
362362
Parse(text).
363+
AddFunc("hasField", hasField(Table{Table: in})).
363364
GoFmt(true)
364365
output, err := t.Execute(map[string]any{
365366
"pkg": g.pkg,
@@ -381,6 +382,7 @@ func (g *defaultGenerator) executeModel(table Table, code *code) (*bytes.Buffer,
381382
}
382383
t := util.With("model").
383384
Parse(text).
385+
AddFunc("hasField", hasField(table)).
384386
GoFmt(true)
385387
output, err := t.Execute(map[string]any{
386388
"pkg": g.pkg,

tools/goctl/model/sql/gen/imports.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func genImports(table Table, withCache, timeImport bool) (string, error) {
2828
return "", err
2929
}
3030

31-
buffer, err := util.With("import").Parse(text).Execute(map[string]any{
31+
buffer, err := util.With("import").Parse(text).AddFunc("hasField", hasField(table)).Execute(map[string]any{
3232
"time": timeImport,
3333
"containsPQ": table.ContainsPQ,
3434
"data": table,

tools/goctl/model/sql/gen/template.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,16 @@ func Update() error {
9494

9595
return pathx.InitTemplates(category, templates)
9696
}
97+
98+
// hasField returns a function that checks if a field exists in the table.
99+
// It uses a pre-built map for O(1) lookup performance.
100+
func hasField(table Table) func(string) bool {
101+
fieldSet := make(map[string]struct{}, len(table.Fields))
102+
for _, field := range table.Fields {
103+
fieldSet[field.NameOriginal] = struct{}{}
104+
}
105+
return func(f string) bool {
106+
_, ok := fieldSet[f]
107+
return ok
108+
}
109+
}

tools/goctl/model/sql/gen/template_test.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ import (
66
"testing"
77

88
"github.com/stretchr/testify/assert"
9+
"github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
910
"github.com/zeromicro/go-zero/tools/goctl/model/sql/template"
1011
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
12+
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
1113
)
1214

1315
func TestGenTemplates(t *testing.T) {
@@ -91,3 +93,151 @@ func TestUpdate(t *testing.T) {
9193
assert.Nil(t, err)
9294
assert.Equal(t, template.New, string(data))
9395
}
96+
97+
func TestHasField(t *testing.T) {
98+
tests := []struct {
99+
name string
100+
table Table
101+
fieldName string
102+
wantResult bool
103+
}{
104+
{
105+
name: "field exists",
106+
table: Table{
107+
Table: parser.Table{
108+
Fields: []*parser.Field{
109+
{NameOriginal: "id"},
110+
{NameOriginal: "name"},
111+
{NameOriginal: "created_at"},
112+
},
113+
},
114+
},
115+
fieldName: "name",
116+
wantResult: true,
117+
},
118+
{
119+
name: "field does not exist",
120+
table: Table{
121+
Table: parser.Table{
122+
Fields: []*parser.Field{
123+
{NameOriginal: "id"},
124+
{NameOriginal: "name"},
125+
},
126+
},
127+
},
128+
fieldName: "email",
129+
wantResult: false,
130+
},
131+
{
132+
name: "empty table",
133+
table: Table{
134+
Table: parser.Table{
135+
Fields: []*parser.Field{},
136+
},
137+
},
138+
fieldName: "id",
139+
wantResult: false,
140+
},
141+
{
142+
name: "case sensitive match",
143+
table: Table{
144+
Table: parser.Table{
145+
Fields: []*parser.Field{
146+
{NameOriginal: "ID"},
147+
{NameOriginal: "Name"},
148+
},
149+
},
150+
},
151+
fieldName: "id",
152+
wantResult: false,
153+
},
154+
{
155+
name: "exact match required",
156+
table: Table{
157+
Table: parser.Table{
158+
Fields: []*parser.Field{
159+
{NameOriginal: "user_name"},
160+
},
161+
},
162+
},
163+
fieldName: "user_name",
164+
wantResult: true,
165+
},
166+
{
167+
name: "partial match should fail",
168+
table: Table{
169+
Table: parser.Table{
170+
Fields: []*parser.Field{
171+
{NameOriginal: "user_name"},
172+
},
173+
},
174+
},
175+
fieldName: "user",
176+
wantResult: false,
177+
},
178+
}
179+
180+
for _, tt := range tests {
181+
t.Run(tt.name, func(t *testing.T) {
182+
fn := hasField(tt.table)
183+
result := fn(tt.fieldName)
184+
assert.Equal(t, tt.wantResult, result)
185+
})
186+
}
187+
}
188+
189+
func TestHasFieldWithRealTable(t *testing.T) {
190+
// Create a realistic table structure
191+
table := Table{
192+
Table: parser.Table{
193+
Name: stringx.From("users"),
194+
Fields: []*parser.Field{
195+
{NameOriginal: "id", DataType: "int64"},
196+
{NameOriginal: "username", DataType: "string"},
197+
{NameOriginal: "email", DataType: "string"},
198+
{NameOriginal: "password", DataType: "string"},
199+
{NameOriginal: "created_at", DataType: "time.Time"},
200+
{NameOriginal: "updated_at", DataType: "time.Time"},
201+
},
202+
},
203+
}
204+
205+
fn := hasField(table)
206+
207+
// Test all existing fields
208+
assert.True(t, fn("id"))
209+
assert.True(t, fn("username"))
210+
assert.True(t, fn("email"))
211+
assert.True(t, fn("password"))
212+
assert.True(t, fn("created_at"))
213+
assert.True(t, fn("updated_at"))
214+
215+
// Test non-existing fields
216+
assert.False(t, fn("deleted_at"))
217+
assert.False(t, fn("ID"))
218+
assert.False(t, fn("Username"))
219+
assert.False(t, fn(""))
220+
}
221+
222+
func TestHasFieldPerformance(t *testing.T) {
223+
// Create a table with many fields to test performance optimization
224+
var fields []*parser.Field
225+
for i := 0; i < 1000; i++ {
226+
fields = append(fields, &parser.Field{
227+
NameOriginal: "field_" + string(rune('0'+i%10)) + string(rune('a'+i%26)),
228+
})
229+
}
230+
231+
table := Table{
232+
Table: parser.Table{
233+
Fields: fields,
234+
},
235+
}
236+
237+
fn := hasField(table)
238+
239+
// Verify the function works correctly
240+
assert.True(t, fn(fields[0].NameOriginal))
241+
assert.True(t, fn(fields[999].NameOriginal))
242+
assert.False(t, fn("non_existent_field"))
243+
}

tools/goctl/model/sql/gen/update.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func genUpdate(table Table, withCache, postgreSql bool) (
6161
return "", "", err
6262
}
6363

64-
output, err := util.With("update").Parse(text).Execute(
64+
output, err := util.With("update").Parse(text).AddFunc("hasField", hasField(table)).Execute(
6565
map[string]any{
6666
"withCache": withCache,
6767
"containsIndexCache": table.ContainsUniqueCacheKey,
@@ -94,7 +94,7 @@ func genUpdate(table Table, withCache, postgreSql bool) (
9494
return "", "", err
9595
}
9696

97-
updateMethodOutput, err := util.With("updateMethod").Parse(text).Execute(
97+
updateMethodOutput, err := util.With("updateMethod").Parse(text).AddFunc("hasField", hasField(table)).Execute(
9898
map[string]any{
9999
"upperStartCamelObject": camelTableName,
100100
"data": table,

tools/goctl/util/templatex.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ const regularPerm = 0o666
1515

1616
// DefaultTemplate is a tool to provides the text/template operations
1717
type DefaultTemplate struct {
18-
name string
19-
text string
20-
goFmt bool
18+
name string
19+
text string
20+
goFmt bool
21+
funcMap template.FuncMap
2122
}
2223

2324
// With returns an instance of DefaultTemplate
2425
func With(name string) *DefaultTemplate {
2526
return &DefaultTemplate{
26-
name: name,
27+
name: name,
28+
funcMap: make(template.FuncMap),
2729
}
2830
}
2931

@@ -55,7 +57,11 @@ func (t *DefaultTemplate) SaveTo(data any, path string, forceUpdate bool) error
5557

5658
// Execute returns the codes after the template executed
5759
func (t *DefaultTemplate) Execute(data any) (*bytes.Buffer, error) {
58-
tem, err := template.New(t.name).Parse(t.text)
60+
tmp := template.New(t.name)
61+
if len(t.funcMap) > 0 {
62+
tmp.Funcs(t.funcMap)
63+
}
64+
tem, err := tmp.Parse(t.text)
5965
if err != nil {
6066
return nil, errorx.Wrap(err, "template parse error:", t.text)
6167
}
@@ -79,6 +85,16 @@ func (t *DefaultTemplate) Execute(data any) (*bytes.Buffer, error) {
7985
return buf, nil
8086
}
8187

88+
// AddFunc adds a template function. It returns the template instance for chaining.
89+
// If funcName is empty or function is nil, it returns the template without modification.
90+
func (t *DefaultTemplate) AddFunc(funcName string, function any) *DefaultTemplate {
91+
if funcName == "" || function == nil {
92+
return t
93+
}
94+
t.funcMap[funcName] = function
95+
return t
96+
}
97+
8298
// IsTemplateVariable returns true if the text is a template variable.
8399
// The text must start with a dot and be a valid template.
84100
func IsTemplateVariable(text string) bool {

0 commit comments

Comments
 (0)