diff --git a/_codegen/main.go b/_codegen/main.go index e193a74a5..36c43acd1 100644 --- a/_codegen/main.go +++ b/_codegen/main.go @@ -107,7 +107,8 @@ func parseTemplates() (*template.Template, *template.Template, error) { funcTemplate = string(f) } tmpl, err := template.New("function").Funcs(template.FuncMap{ - "replace": strings.ReplaceAll, + "replace": strings.ReplaceAll, + "requireCommentParseIf": requireCommentParseIf, }).Parse(funcTemplate) if err != nil { return nil, nil, err @@ -298,6 +299,55 @@ func (f *testFunc) CommentWithoutT(receiver string) string { return strings.Replace(f.Comment(), search, replace, -1) } +// requireCommentParseIf rewrites invalid "if require..." examples +// in generated documentation for the require package. +// +// The assert package documentation often shows conditional usage like: +// +// // if assert.NoError(t, err) { +// // // continue with test +// // } +// +// However, require package methods do not return bool values; +// they call t.FailNow() on failure. This function transforms +// such conditional blocks by removing the "if require.Function() {" +// wrapper and adjusting indentation to show proper usage: +// +// // require.NoError(t, err) +// // continue with test +func requireCommentParseIf(s string) string { + lines := strings.Split(s, "\n") + out := make([]string, 0, len(lines)) + rePrefix := regexp.MustCompile(`//[[:blank:]]+`) + ifBlock := false + prePrefix := "//\t" + + for _, line := range lines { + commentPrefix := rePrefix.FindString(line) + comment := strings.TrimSpace(line[2:]) + + if ifBlock && strings.HasPrefix(comment, "}") { + ifBlock = false + continue + } + + if strings.HasPrefix(comment, "if require.") && strings.HasSuffix(comment, "{") { + ifBlock = true + comment = strings.TrimPrefix(comment, "if ") + comment = strings.TrimSpace(comment) + comment = strings.TrimSuffix(comment, "{") + } + + if ifBlock { + commentPrefix = prePrefix + } + + prePrefix = commentPrefix + out = append(out, commentPrefix+comment) + } + return strings.Join(out, "\n") +} + // Standard header https://go.dev/s/generatedcode. var headerTemplate = `// Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT. diff --git a/_codegen/main_test.go b/_codegen/main_test.go new file mode 100644 index 000000000..8a31b3c80 --- /dev/null +++ b/_codegen/main_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "testing" +) + +func TestRequireCommentParseIf(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "single line without if", + input: "// Simple comment line", + expected: "// Simple comment line", + }, + { + name: "simple if require block transformation", + input: "//\tif require.NotEmpty(t, obj) {\n//\t require.Equal(t, \"two\", obj[1])\n//\t}", + expected: "//\trequire.NotEmpty(t, obj) \n//\trequire.Equal(t, \"two\", obj[1])", + }, + { + name: "no if block - should remain unchanged", + input: "// Contains function\n//\trequire.Contains(t, \"Hello World\", \"World\")", + expected: "// Contains function\n//\trequire.Contains(t, \"Hello World\", \"World\")", + }, + { + name: "mixed content with if block", + input: "//\t actualObj, err := SomeFunction()\n//\tif require.NoError(t, err) {\n//\t\t do something\n//\t}", + expected: "//\t actualObj, err := SomeFunction()\n//\t require.NoError(t, err) \n//\t do something", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := requireCommentParseIf(tt.input) + if result != tt.expected { + t.Errorf("requireCommentParseIf() failed:\nInput: %q\nGot: %q\nWant: %q", tt.input, result, tt.expected) + } + }) + } +} diff --git a/require/require.go b/require/require.go index 2d02f9bce..ec8e8a537 100644 --- a/require/require.go +++ b/require/require.go @@ -1387,10 +1387,9 @@ func NoDirExistsf(t TestingT, path string, msg string, args ...interface{}) { // NoError asserts that a function returned no error (i.e. `nil`). // -// actualObj, err := SomeFunction() -// if require.NoError(t, err) { -// require.Equal(t, expectedObj, actualObj) -// } +// actualObj, err := SomeFunction() +// require.NoError(t, err) +// require.Equal(t, expectedObj, actualObj) func NoError(t TestingT, err error, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1403,10 +1402,9 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) { // NoErrorf asserts that a function returned no error (i.e. `nil`). // -// actualObj, err := SomeFunction() -// if require.NoErrorf(t, err, "error message %s", "formatted") { -// require.Equal(t, expectedObj, actualObj) -// } +// actualObj, err := SomeFunction() +// require.NoErrorf(t, err, "error message %s", "formatted") +// require.Equal(t, expectedObj, actualObj) func NoErrorf(t TestingT, err error, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1515,9 +1513,8 @@ func NotElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg str // NotEmpty asserts that the specified object is NOT [Empty]. // -// if require.NotEmpty(t, obj) { -// require.Equal(t, "two", obj[1]) -// } +// require.NotEmpty(t, obj) +// require.Equal(t, "two", obj[1]) func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1530,9 +1527,8 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { // NotEmptyf asserts that the specified object is NOT [Empty]. // -// if require.NotEmptyf(t, obj, "error message %s", "formatted") { -// require.Equal(t, "two", obj[1]) -// } +// require.NotEmptyf(t, obj, "error message %s", "formatted") +// require.Equal(t, "two", obj[1]) func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() diff --git a/require/require.go.tmpl b/require/require.go.tmpl index 8b3283685..86a7ac005 100644 --- a/require/require.go.tmpl +++ b/require/require.go.tmpl @@ -1,4 +1,4 @@ -{{ replace .Comment "assert." "require."}} +{{ replace .Comment "assert." "require." | requireCommentParseIf }} func {{.DocInfo.Name}}(t TestingT, {{.Params}}) { if h, ok := t.(tHelper); ok { h.Helper() } if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return }