Skip to content

Commit 64c6de1

Browse files
committed
Improve code quality and test assertions
1 parent 2d9413f commit 64c6de1

32 files changed

+418
-430
lines changed

cmd/build_test.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -345,15 +345,6 @@ func TestRewriteMarkdownLinksInHTML_ValidTargets(t *testing.T) {
345345
}
346346
}
347347

348-
// TestBuildVariables_Initialization tests that build variables are initialized
349-
func TestBuildVariables_Initialization(t *testing.T) {
350-
_ = buildFormat
351-
_ = buildBranch
352-
_ = buildSubDir
353-
_ = buildOutput
354-
_ = buildSummary
355-
}
356-
357348
// TestBuildCommand_ExamplesInHelp tests that build help contains examples
358349
func TestBuildCommand_ExamplesInHelp(t *testing.T) {
359350
if !strings.Contains(buildCmd.Long, "mdpress build") {

cmd/root_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ func TestCacheDirFlagParsing(t *testing.T) {
6464

6565
for i, tt := range tests {
6666
t.Run(tt.name, func(t *testing.T) {
67-
// Reset environment before each test
68-
os.Unsetenv("MDPRESS_CACHE_DIR")
69-
os.Unsetenv("MDPRESS_DISABLE_CACHE")
67+
// Reset environment before each test (t.Setenv auto-restores)
68+
t.Setenv("MDPRESS_CACHE_DIR", "")
69+
t.Setenv("MDPRESS_DISABLE_CACHE", "")
7070

7171
// Create a test command to verify flag parsing
7272
testCmd := &cobra.Command{Use: "test", RunE: func(cmd *cobra.Command, args []string) error {
@@ -107,8 +107,8 @@ func TestCacheDirFlagParsing(t *testing.T) {
107107

108108
// TestNoCacheDisablesBothFlags tests that --no-cache properly disables caching
109109
func TestNoCacheDisablesFlagCaching(t *testing.T) {
110-
os.Unsetenv("MDPRESS_CACHE_DIR")
111-
os.Unsetenv("MDPRESS_DISABLE_CACHE")
110+
t.Setenv("MDPRESS_CACHE_DIR", "")
111+
t.Setenv("MDPRESS_DISABLE_CACHE", "")
112112

113113
// Simulate flag parsing and configuration
114114
noCache = true
@@ -123,8 +123,8 @@ func TestNoCacheDisablesFlagCaching(t *testing.T) {
123123

124124
// TestCacheDirOverrideWorks tests that --cache-dir properly overrides default
125125
func TestCacheDirOverrideWorks(t *testing.T) {
126-
os.Unsetenv("MDPRESS_CACHE_DIR")
127-
os.Unsetenv("MDPRESS_DISABLE_CACHE")
126+
t.Setenv("MDPRESS_CACHE_DIR", "")
127+
t.Setenv("MDPRESS_DISABLE_CACHE", "")
128128

129129
customPath := "/custom/cache/path"
130130
cacheDir = customPath
@@ -158,8 +158,8 @@ func TestFlagDefaults(t *testing.T) {
158158

159159
// TestConfigureRuntimeCacheEnvDoesNotPanicOnEmpty tests robustness
160160
func TestConfigureRuntimeCacheEnvRobustness(t *testing.T) {
161-
os.Unsetenv("MDPRESS_CACHE_DIR")
162-
os.Unsetenv("MDPRESS_DISABLE_CACHE")
161+
t.Setenv("MDPRESS_CACHE_DIR", "")
162+
t.Setenv("MDPRESS_DISABLE_CACHE", "")
163163

164164
cacheDir = ""
165165
noCache = false

internal/config/config.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,15 @@ func DefaultConfig() *BookConfig {
172172
// It also auto-detects GLOSSARY.md and LANGS.md.
173173
func Load(path string) (*BookConfig, error) {
174174
// Limit config size to guard against malformed or malicious YAML inputs.
175+
// Check size via os.Stat before reading to avoid loading large files into memory.
175176
const maxConfigSize = 10 * 1024 * 1024 // 10MB
176-
fi, err := os.Stat(path)
177+
info, err := os.Stat(path)
177178
if err != nil {
178-
return nil, fmt.Errorf("failed to read config file: %w (ensure %s exists and is readable)", err, path)
179+
return nil, fmt.Errorf("failed to stat config file: %w (ensure %s exists and is readable)", err, path)
179180
}
180-
if fi.Size() > maxConfigSize {
181-
return nil, fmt.Errorf("config file is too large (%d bytes; max allowed is %d bytes)", fi.Size(), maxConfigSize)
181+
if info.Size() > int64(maxConfigSize) {
182+
return nil, fmt.Errorf("config file is too large (%d bytes; max allowed is %d bytes)", info.Size(), maxConfigSize)
182183
}
183-
184184
data, err := os.ReadFile(path)
185185
if err != nil {
186186
return nil, fmt.Errorf("failed to read config file: %w (ensure %s exists and is readable)", err, path)
@@ -295,8 +295,17 @@ func (c *BookConfig) Validate() error {
295295
return nil
296296
}
297297

298+
const maxChapterNestingDepth = 20
299+
298300
// validateChapters recursively validates chapter definitions and their nested sections.
299301
func (c *BookConfig) validateChapters(chapters []ChapterDef, prefix string) error {
302+
return c.validateChaptersDepth(chapters, prefix, 0)
303+
}
304+
305+
func (c *BookConfig) validateChaptersDepth(chapters []ChapterDef, prefix string, depth int) error {
306+
if depth > maxChapterNestingDepth {
307+
return fmt.Errorf("chapter nesting exceeds maximum depth of %d", maxChapterNestingDepth)
308+
}
300309
for i, ch := range chapters {
301310
label := fmt.Sprintf("%s%d", prefix, i+1)
302311
if ch.File == "" {
@@ -309,7 +318,7 @@ func (c *BookConfig) validateChapters(chapters []ChapterDef, prefix string) erro
309318
}
310319
// Recursively validate nested sections.
311320
if len(ch.Sections) > 0 {
312-
if err := c.validateChapters(ch.Sections, label+"."); err != nil {
321+
if err := c.validateChaptersDepth(ch.Sections, label+".", depth+1); err != nil {
313322
return fmt.Errorf("nested section validation failed: %w", err)
314323
}
315324
}

internal/config/discover.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,11 @@ func Discover(ctx context.Context, dir string) (*BookConfig, error) {
106106
}
107107
}
108108

109-
// If book.json was found and we got this far, return the config from it
109+
// If book.json was found and we got this far, validate and return the config.
110110
if hasBookJSON {
111+
if err := cfg.Validate(); err != nil {
112+
return nil, fmt.Errorf("book.json config validation failed: %w", err)
113+
}
111114
return cfg, nil
112115
}
113116

@@ -237,9 +240,11 @@ func fileNameToTitle(path string) string {
237240
"-", " ",
238241
).Replace(name)
239242

240-
// Uppercase the first letter.
241-
if len(name) > 0 {
242-
name = strings.ToUpper(name[:1]) + name[1:]
243+
// Uppercase the first rune (safe for multi-byte UTF-8).
244+
runes := []rune(name)
245+
if len(runes) > 0 {
246+
runes[0] = unicode.ToUpper(runes[0])
247+
name = string(runes)
243248
}
244249

245250
return name
@@ -427,8 +432,10 @@ func inferBookTitle(h1Title, content, dir string) string {
427432
dirName := filepath.Base(dir)
428433
dirName = strings.ReplaceAll(dirName, "_", " ")
429434
dirName = strings.ReplaceAll(dirName, "-", " ")
430-
if len(dirName) > 0 {
431-
dirName = strings.ToUpper(dirName[:1]) + dirName[1:]
435+
dr := []rune(dirName)
436+
if len(dr) > 0 {
437+
dr[0] = unicode.ToUpper(dr[0])
438+
dirName = string(dr)
432439
}
433440
return dirName
434441
}

internal/markdown/diagnostics.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ func isMatchingBracket(open, close rune) bool {
408408
}
409409

410410
func runeCountBytes(b []byte) int {
411-
return utf8.RuneCountInString(string(b))
411+
return utf8.RuneCount(b)
412412
}
413413

414414
// longHeadingThreshold 是触发"标题文本过长"警告的字符数阈值。

internal/markdown/extensions.go

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,40 @@ package markdown
33
import (
44
"fmt"
55
"strings"
6-
"sync"
76

87
"github.com/yuin/goldmark/ast"
98
"github.com/yuin/goldmark/parser"
109
"github.com/yuin/goldmark/text"
1110
)
1211

13-
// headingIDTransformer 为标题自动生成唯一 ID 属性
14-
type headingIDTransformer struct {
15-
usedIDs map[string]int
16-
mu sync.Mutex
17-
}
12+
// headingIDTransformer 为标题自动生成唯一 ID 属性.
13+
// It is safe for concurrent use: each Transform call uses a local map so
14+
// multiple documents can be parsed in parallel without interference.
15+
type headingIDTransformer struct{}
1816

1917
func newHeadingIDTransformer() parser.ASTTransformer {
20-
return &headingIDTransformer{
21-
usedIDs: make(map[string]int),
22-
}
18+
return &headingIDTransformer{}
2319
}
2420

25-
// Transform 遍历 AST,为所有标题节点生成 ID
21+
// Transform 遍历 AST,为所有标题节点生成 ID.
22+
// A fresh usedIDs map is created per call so heading IDs are scoped to a
23+
// single document and concurrent Transform calls don't interfere.
2624
func (t *headingIDTransformer) Transform(node *ast.Document, reader text.Reader, pc parser.Context) {
25+
usedIDs := make(map[string]int)
2726
source := reader.Source()
28-
if err := ast.Walk(node, func(n ast.Node, entering bool) (ast.WalkStatus, error) {
27+
_ = ast.Walk(node, func(n ast.Node, entering bool) (ast.WalkStatus, error) {
2928
if !entering {
3029
return ast.WalkContinue, nil
3130
}
3231
if heading, ok := n.(*ast.Heading); ok {
33-
t.processHeading(heading, source)
32+
processHeading(heading, source, usedIDs)
3433
}
3534
return ast.WalkContinue, nil
36-
}); err != nil {
37-
return
38-
}
35+
})
3936
}
4037

4138
// processHeading 为单个标题节点设置 ID 属性
42-
func (t *headingIDTransformer) processHeading(heading *ast.Heading, source []byte) {
39+
func processHeading(heading *ast.Heading, source []byte, usedIDs map[string]int) {
4340
if _, ok := heading.AttributeString("id"); ok {
4441
return
4542
}
@@ -49,27 +46,24 @@ func (t *headingIDTransformer) processHeading(heading *ast.Heading, source []byt
4946
return
5047
}
5148

52-
id := t.generateUniqueID(headingText)
49+
id := generateUniqueID(headingText, usedIDs)
5350
heading.SetAttributeString("id", []byte(id))
5451
}
5552

5653
// generateUniqueID 生成唯一的标题 ID,遇到重复自动添加后缀
57-
func (t *headingIDTransformer) generateUniqueID(text string) string {
54+
func generateUniqueID(text string, usedIDs map[string]int) string {
5855
baseID := generateHeadingID(text)
5956
if baseID == "" {
6057
baseID = "heading"
6158
}
6259

63-
t.mu.Lock()
64-
defer t.mu.Unlock()
65-
66-
count, exists := t.usedIDs[baseID]
60+
count, exists := usedIDs[baseID]
6761
if !exists {
68-
t.usedIDs[baseID] = 1
62+
usedIDs[baseID] = 1
6963
return baseID
7064
}
7165

72-
t.usedIDs[baseID] = count + 1
66+
usedIDs[baseID] = count + 1
7367
return fmt.Sprintf("%s-%d", baseID, count+1)
7468
}
7569

0 commit comments

Comments
 (0)