Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,191 @@ func computeNewlineTargets(before, after string) []int {

return result
}

// writeYAMLFilesSurgical writes files using surgical text replacement instead of
// re-serializing the YAML. This preserves the original formatting of the file.
// It walks the modified yaml.Node tree and applies text replacements based on
// line/column positions.
func (r loadResults) writeYAMLFilesSurgical(outPath string) error {
var merr error

for pth, f := range r {
outFile := outPath
if strings.HasSuffix(outPath, "/") {
outFile = filepath.Join(outPath, pth)
}
if outFile == "" {
outFile = pth
}

final := applySurgicalReplacements(f.contents, f.node)

if err := atomic.Write(pth, outFile, strings.NewReader(final)); err != nil {
merr = errors.Join(merr, fmt.Errorf("failed to save file %s: %w", outFile, err))
continue
}
}

return merr
}

// applySurgicalReplacements walks the yaml.Node tree and applies text replacements
// to the original content based on the node's line/column positions.
func applySurgicalReplacements(contents string, node *yaml.Node) string {
// Collect all replacements from the node tree
var replacements []surgicalReplacement
collectReplacements(node, contents, &replacements)

if len(replacements) == 0 {
return contents
}

// Sort by line descending, then column descending, so we can apply
// replacements without affecting positions of subsequent ones
slices.SortFunc(replacements, func(a, b surgicalReplacement) int {
if a.line != b.line {
return b.line - a.line
}
return b.col - a.col
})

lines := strings.Split(contents, "\n")

for _, rep := range replacements {
if rep.line < 1 || rep.line > len(lines) {
continue
}

lineIdx := rep.line - 1
line := lines[lineIdx]

// Find the old value starting from the column position
colIdx := rep.col - 1
if colIdx < 0 || colIdx >= len(line) {
continue
}

// Find the old value in the line
valueStart := strings.Index(line[colIdx:], rep.oldValue)
if valueStart == -1 {
continue
}
valueStart += colIdx
valueEnd := valueStart + len(rep.oldValue)

// Find where any existing comment starts (after the value)
commentStart := -1
rest := line[valueEnd:]
for i := 0; i < len(rest); i++ {
if rest[i] == '#' {
commentStart = valueEnd + i
break
}
}

// Build the new line: keep prefix, add new value, then new comment
var newLine string
if commentStart != -1 {
// There's an existing comment - replace value and everything after
newLine = line[:valueStart] + rep.newValue
} else {
// No existing comment - just replace value
newLine = line[:valueStart] + rep.newValue + line[valueEnd:]
}

// Add new comment if specified
if rep.newComment != "" {
// Check if the comment already includes the # prefix
if strings.HasPrefix(rep.newComment, "#") {
newLine = newLine + " " + rep.newComment
} else {
newLine = newLine + " # " + rep.newComment
}
}

lines[lineIdx] = newLine
}

return strings.Join(lines, "\n")
}

type surgicalReplacement struct {
line int
col int
oldValue string
newValue string
newComment string
}

// collectReplacements walks the node tree and collects replacements for nodes
// that have been modified by the parser (detected by presence of "ratchet:" in LineComment).
func collectReplacements(node *yaml.Node, contents string, replacements *[]surgicalReplacement) {
if node == nil {
return
}

// Only process scalar nodes that have been modified by Pin/Update/Upgrade
// These are identified by having "ratchet:" in the LineComment
if node.Kind == yaml.ScalarNode && node.Line > 0 && node.Column > 0 &&
strings.Contains(node.LineComment, "ratchet:") {

lines := strings.Split(contents, "\n")
if node.Line <= len(lines) {
line := lines[node.Line-1]
colIdx := node.Column - 1

if colIdx >= 0 && colIdx < len(line) {
origValue := extractValueAtPosition(line, colIdx)

if origValue != "" && origValue != node.Value {
*replacements = append(*replacements, surgicalReplacement{
line: node.Line,
col: node.Column,
oldValue: origValue,
newValue: node.Value,
newComment: node.LineComment,
})
}
}
}
}

// Recurse into children
for _, child := range node.Content {
collectReplacements(child, contents, replacements)
}
}

// extractValueAtPosition extracts the YAML value at the given position in the line.
// Handles both quoted and unquoted values.
func extractValueAtPosition(line string, col int) string {
if col >= len(line) {
return ""
}

rest := line[col:]

// Handle quoted strings
if len(rest) > 0 && (rest[0] == '\'' || rest[0] == '"') {
quote := rest[0]
end := 1
for end < len(rest) {
if rest[end] == byte(quote) && (end == 0 || rest[end-1] != '\\') {
return rest[1:end]
}
end++
}
}

// Handle unquoted values - read until whitespace or comment
end := 0
for end < len(rest) {
ch := rest[end]
if ch == ' ' || ch == '\t' || ch == '#' || ch == '\n' || ch == '\r' {
break
}
end++
}

return rest[:end]
}
100 changes: 98 additions & 2 deletions command/command_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package command

import (
"bytes"
"io/fs"
"os"
"strings"
"testing"
"testing/fstest"

"github.com/braydonk/yaml"
"github.com/google/go-cmp/cmp"
)

Expand All @@ -28,6 +31,11 @@ func Test_loadYAMLFiles(t *testing.T) {
"gitlabci.yml": "",
"no-trailing-newline.yml": "no-trailing-newline.golden.yml",
"tekton.yml": "",

// These files demonstrate the YAML marshaling bug from PR #125 where
// comments get misplaced. Uncomment to see the failures:
// "github-pr125.yml": "",
// "github-codeql-pr125.yml": "",
}

for input, expected := range cases {
Expand All @@ -52,8 +60,8 @@ func Test_loadYAMLFiles(t *testing.T) {
t.Fatal(err)
}

if got != string(want) {
t.Errorf("expected\n\n%s\n\nto be\n\n%s\n", got, want)
if diff := cmp.Diff(string(want), got); diff != "" {
t.Errorf("round-trip mismatch (-want +got):\n%s", diff)
}
})
}
Expand Down Expand Up @@ -224,3 +232,91 @@ test-code-job1:
})
}
}

// Test_applySurgicalReplacements_preservesFormatting tests that the surgical
// replacement approach preserves original YAML formatting (PR #125).
func Test_applySurgicalReplacements_preservesFormatting(t *testing.T) {
t.Parallel()

fsys := os.DirFS("../testdata")

cases := []struct {
name string
file string
modifyFn func(node *yaml.Node)
}{
{
name: "github-pr125.yml",
file: "github-pr125.yml",
modifyFn: func(node *yaml.Node) {
walkAndPin(node, "actions/checkout@v2", "actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab", "# ratchet:actions/checkout@v2")
walkAndPin(node, "actions/github-script@v6", "actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea", "# ratchet:actions/github-script@v6")
},
},
{
name: "github-codeql-pr125.yml",
file: "github-codeql-pr125.yml",
modifyFn: func(node *yaml.Node) {
walkAndPin(node, "actions/checkout@v5", "actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683", "# ratchet:actions/checkout@v5")
walkAndPin(node, "github/codeql-action/init@v3", "github/codeql-action/init@aa578102511db1f4524ed59b8cc2bae4f6e88195", "# ratchet:github/codeql-action/init@v3")
walkAndPin(node, "github/codeql-action/autobuild@v3", "github/codeql-action/autobuild@aa578102511db1f4524ed59b8cc2bae4f6e88195", "# ratchet:github/codeql-action/autobuild@v3")
walkAndPin(node, "github/codeql-action/analyze@v3", "github/codeql-action/analyze@aa578102511db1f4524ed59b8cc2bae4f6e88195", "# ratchet:github/codeql-action/analyze@v3")
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

original, err := fsys.(fs.ReadFileFS).ReadFile(tc.file)
if err != nil {
t.Fatal(err)
}

var node yaml.Node
dec := yaml.NewDecoder(bytes.NewReader(original))
dec.SetScanBlockScalarAsLiteral(true)
if err := dec.Decode(&node); err != nil {
t.Fatal(err)
}

tc.modifyFn(&node)
got := applySurgicalReplacements(string(original), &node)

// Verify line count unchanged (surgical replacement shouldn't add/remove lines)
originalLines := strings.Split(string(original), "\n")
gotLines := strings.Split(got, "\n")
if len(originalLines) != len(gotLines) {
t.Errorf("line count changed: original=%d, got=%d", len(originalLines), len(gotLines))
}

// Verify non-uses lines are unchanged, uses lines have ratchet comments
for i := 0; i < len(originalLines) && i < len(gotLines); i++ {
if strings.Contains(originalLines[i], "uses:") {
if !strings.Contains(gotLines[i], "ratchet:") {
t.Errorf("line %d: expected ratchet comment, got %q", i+1, gotLines[i])
}
continue
}
if originalLines[i] != gotLines[i] {
t.Errorf("line %d: unexpected change\n original: %q\n got: %q", i+1, originalLines[i], gotLines[i])
}
}
})
}
}

// walkAndPin simulates what Pin does when finding a matching action reference.
func walkAndPin(node *yaml.Node, oldValue, newValue, comment string) {
if node == nil {
return
}
if node.Kind == yaml.ScalarNode && node.Value == oldValue {
node.Value = newValue
node.LineComment = comment
}
for _, child := range node.Content {
walkAndPin(child, oldValue, newValue, comment)
}
}
19 changes: 14 additions & 5 deletions command/pin.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ FLAGS
`

type PinCommand struct {
flagConcurrency int64
flagParser string
flagOut string
flagConcurrency int64
flagParser string
flagOut string
flagExperimentalPreserveYAML bool
}

func (c *PinCommand) Desc() string {
Expand All @@ -56,6 +57,8 @@ func (c *PinCommand) Flags() *flag.FlagSet {
"maximum number of concurrent resolutions")
f.StringVar(&c.flagParser, "parser", "actions", "parser to use")
f.StringVar(&c.flagOut, "out", "", "output path (defaults to input file)")
f.BoolVar(&c.flagExperimentalPreserveYAML, "experimental-preserve-formatting", false,
"(experimental) preserve original YAML formatting")

return f
}
Expand Down Expand Up @@ -89,8 +92,14 @@ func (c *PinCommand) Run(ctx context.Context, originalArgs []string) error {
return fmt.Errorf("failed to pin refs: %w", err)
}

if err := loadResult.writeYAMLFiles(c.flagOut); err != nil {
return fmt.Errorf("failed to save files: %w", err)
if c.flagExperimentalPreserveYAML {
if err := loadResult.writeYAMLFilesSurgical(c.flagOut); err != nil {
return fmt.Errorf("failed to save files: %w", err)
}
} else {
if err := loadResult.writeYAMLFiles(c.flagOut); err != nil {
return fmt.Errorf("failed to save files: %w", err)
}
}

return nil
Expand Down
15 changes: 12 additions & 3 deletions command/unpin.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ FLAGS
`

type UnpinCommand struct {
flagOut string
flagOut string
flagExperimentalPreserveYAML bool
}

func (c *UnpinCommand) Desc() string {
Expand All @@ -50,6 +51,8 @@ func (c *UnpinCommand) Flags() *flag.FlagSet {
}

f.StringVar(&c.flagOut, "out", "", "output path (defaults to input file)")
f.BoolVar(&c.flagExperimentalPreserveYAML, "experimental-preserve-formatting", false,
"(experimental) preserve original YAML formatting")

return f
}
Expand All @@ -73,8 +76,14 @@ func (c *UnpinCommand) Run(ctx context.Context, originalArgs []string) error {
return fmt.Errorf("failed to pin refs: %w", err)
}

if err := loadResult.writeYAMLFiles(c.flagOut); err != nil {
return fmt.Errorf("failed to save files: %w", err)
if c.flagExperimentalPreserveYAML {
if err := loadResult.writeYAMLFilesSurgical(c.flagOut); err != nil {
return fmt.Errorf("failed to save files: %w", err)
}
} else {
if err := loadResult.writeYAMLFiles(c.flagOut); err != nil {
return fmt.Errorf("failed to save files: %w", err)
}
}

return nil
Expand Down
Loading