Skip to content

Commit a39aa2d

Browse files
committed
fix(codegen): preserve *assert.CollectT in Require
Update replace funcmap to replace "assert." with "require." while preserving "*assert.CollectT" references. This prevents unintended replacement of pointer references to CollectT, ensuring correct code generation. Signed-off-by: a2not <[email protected]>
1 parent 429ee0b commit a39aa2d

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

_codegen/main.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,7 @@ func parseTemplates() (*template.Template, *template.Template, error) {
106106
}
107107
funcTemplate = string(f)
108108
}
109-
tmpl, err := template.New("function").Funcs(template.FuncMap{
110-
"replace": strings.ReplaceAll,
111-
}).Parse(funcTemplate)
109+
tmpl, err := template.New("function").Parse(funcTemplate)
112110
if err != nil {
113111
return nil, nil, err
114112
}
@@ -298,6 +296,18 @@ func (f *testFunc) CommentWithoutT(receiver string) string {
298296
return strings.Replace(f.Comment(), search, replace, -1)
299297
}
300298

299+
func (f *testFunc) Replace(comment, search, replace string) string {
300+
// replace strings, but preserve "*assert.CollectT"
301+
const (
302+
assertCollectT = "*assert.CollectT"
303+
assertCollectTPlaceholder = "__COLLECT_T_PLACEHOLDER__"
304+
)
305+
306+
protected := strings.ReplaceAll(comment, assertCollectT, assertCollectTPlaceholder)
307+
result := strings.ReplaceAll(protected, search, replace)
308+
return strings.ReplaceAll(result, assertCollectTPlaceholder, assertCollectT)
309+
}
310+
301311
// Standard header https://go.dev/s/generatedcode.
302312
var headerTemplate = `// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
303313

require/require.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
429429
// time.Sleep(8*time.Second)
430430
// externalValue = true
431431
// }()
432-
// require.EventuallyWithT(t, func(c *require.CollectT) {
432+
// require.EventuallyWithT(t, func(c *assert.CollectT) {
433433
// // add assertions as needed; any assertion failure will fail the current tick
434434
// require.True(c, externalValue, "expected 'externalValue' to be true")
435435
// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")
@@ -457,7 +457,7 @@ func EventuallyWithT(t TestingT, condition func(collect *assert.CollectT), waitF
457457
// time.Sleep(8*time.Second)
458458
// externalValue = true
459459
// }()
460-
// require.EventuallyWithTf(t, func(c *require.CollectT, "error message %s", "formatted") {
460+
// require.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") {
461461
// // add assertions as needed; any assertion failure will fail the current tick
462462
// require.True(c, externalValue, "expected 'externalValue' to be true")
463463
// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false")

require/require.go.tmpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{{ replace .Comment "assert." "require."}}
1+
{{ .Replace .Comment "assert." "require."}}
22
func {{.DocInfo.Name}}(t TestingT, {{.Params}}) {
33
if h, ok := t.(tHelper); ok { h.Helper() }
44
if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return }

0 commit comments

Comments
 (0)