Skip to content

Commit 4734f4d

Browse files
fix: prevent race condition panic when unmarshalling YAML with duplicate keys (#110)
1 parent a9fb29b commit 4734f4d

File tree

3 files changed

+375
-1
lines changed

3 files changed

+375
-1
lines changed

marshaller/duplicate_key_test.go

Lines changed: 290 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,290 @@
1+
package marshaller_test
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/speakeasy-api/openapi/marshaller"
8+
testmodels "github.com/speakeasy-api/openapi/marshaller/tests"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
// joinErrors renders each error in errs with Error() and concatenates the
// resulting messages with single spaces, giving tests one string to run
// substring assertions against. An empty or nil slice yields "".
func joinErrors(errs []error) string {
	messages := make([]string, 0, len(errs))
	for _, err := range errs {
		messages = append(messages, err.Error())
	}
	return strings.Join(messages, " ")
}
21+
22+
func TestUnmarshal_DuplicateKey_ReturnsValidationError(t *testing.T) {
23+
t.Parallel()
24+
25+
tests := []struct {
26+
name string
27+
yaml string
28+
expectedErrors int
29+
errorContains []string
30+
}{
31+
{
32+
name: "single duplicate key",
33+
yaml: `stringField: "first value"
34+
boolField: true
35+
stringField: "second value"
36+
intField: 42
37+
float64Field: 3.14
38+
`,
39+
expectedErrors: 1,
40+
errorContains: []string{"stringField", "duplicate"},
41+
},
42+
{
43+
name: "multiple duplicate keys",
44+
yaml: `stringField: "first value"
45+
boolField: true
46+
stringField: "second value"
47+
intField: 42
48+
boolField: false
49+
float64Field: 3.14
50+
`,
51+
expectedErrors: 2,
52+
errorContains: []string{"stringField", "boolField", "duplicate"},
53+
},
54+
{
55+
name: "same key three times",
56+
yaml: `stringField: "first value"
57+
boolField: true
58+
stringField: "second value"
59+
intField: 42
60+
stringField: "third value"
61+
float64Field: 3.14
62+
`,
63+
expectedErrors: 2,
64+
errorContains: []string{"stringField", "duplicate"},
65+
},
66+
{
67+
name: "no duplicates",
68+
yaml: `stringField: "test string"
69+
boolField: true
70+
intField: 42
71+
float64Field: 3.14
72+
`,
73+
expectedErrors: 0,
74+
errorContains: nil,
75+
},
76+
}
77+
78+
for _, tt := range tests {
79+
t.Run(tt.name, func(t *testing.T) {
80+
t.Parallel()
81+
82+
reader := strings.NewReader(tt.yaml)
83+
model := &testmodels.TestPrimitiveHighModel{}
84+
validationErrs, err := marshaller.Unmarshal(t.Context(), reader, model)
85+
require.NoError(t, err, "unmarshal should not return a fatal error")
86+
87+
assert.Len(t, validationErrs, tt.expectedErrors, "should have expected number of validation errors")
88+
89+
if tt.errorContains != nil {
90+
errStr := joinErrors(validationErrs)
91+
for _, contains := range tt.errorContains {
92+
assert.Contains(t, errStr, contains, "validation error should contain expected text")
93+
}
94+
}
95+
})
96+
}
97+
}
98+
99+
func TestUnmarshal_DuplicateKey_LastValueWins(t *testing.T) {
100+
t.Parallel()
101+
102+
tests := []struct {
103+
name string
104+
yaml string
105+
expectedValue string
106+
}{
107+
{
108+
name: "last value wins for string field",
109+
yaml: `stringField: "first value"
110+
boolField: true
111+
stringField: "second value"
112+
intField: 42
113+
float64Field: 3.14
114+
`,
115+
expectedValue: "second value",
116+
},
117+
{
118+
name: "last value wins with three occurrences",
119+
yaml: `stringField: "first value"
120+
boolField: true
121+
stringField: "second value"
122+
intField: 42
123+
stringField: "third value"
124+
float64Field: 3.14
125+
`,
126+
expectedValue: "third value",
127+
},
128+
}
129+
130+
for _, tt := range tests {
131+
t.Run(tt.name, func(t *testing.T) {
132+
t.Parallel()
133+
134+
reader := strings.NewReader(tt.yaml)
135+
model := &testmodels.TestPrimitiveHighModel{}
136+
validationErrs, err := marshaller.Unmarshal(t.Context(), reader, model)
137+
require.NoError(t, err, "unmarshal should not return a fatal error")
138+
139+
// We expect validation errors for duplicates, but the model should still be populated
140+
assert.NotEmpty(t, validationErrs, "should have validation errors for duplicate keys")
141+
142+
// Per YAML spec, the last value should win
143+
assert.Equal(t, tt.expectedValue, model.StringField, "last value should be used")
144+
})
145+
}
146+
}
147+
148+
func TestUnmarshal_DuplicateKey_NestedModel(t *testing.T) {
149+
t.Parallel()
150+
151+
yaml := `nestedModelValue:
152+
stringField: "first nested"
153+
boolField: true
154+
stringField: "second nested"
155+
intField: 100
156+
float64Field: 1.23
157+
eitherModelOrPrimitive: 999
158+
`
159+
160+
reader := strings.NewReader(yaml)
161+
model := &testmodels.TestComplexHighModel{}
162+
validationErrs, err := marshaller.Unmarshal(t.Context(), reader, model)
163+
require.NoError(t, err, "unmarshal should not return a fatal error")
164+
165+
// Should have validation error for nested duplicate
166+
assert.NotEmpty(t, validationErrs, "should have validation errors for nested duplicate keys")
167+
168+
// Last value should win
169+
assert.Equal(t, "second nested", model.NestedModelValue.StringField, "last value should be used in nested model")
170+
}
171+
172+
func TestUnmarshal_DuplicateKey_WithExtensions(t *testing.T) {
173+
t.Parallel()
174+
175+
yaml := `stringField: "test string"
176+
boolField: true
177+
x-custom: "first extension"
178+
intField: 42
179+
x-custom: "second extension"
180+
float64Field: 3.14
181+
`
182+
183+
reader := strings.NewReader(yaml)
184+
model := &testmodels.TestPrimitiveHighModel{}
185+
validationErrs, err := marshaller.Unmarshal(t.Context(), reader, model)
186+
require.NoError(t, err, "unmarshal should not return a fatal error")
187+
188+
// Should have validation error for duplicate extension key
189+
assert.NotEmpty(t, validationErrs, "should have validation errors for duplicate extension keys")
190+
191+
errStr := joinErrors(validationErrs)
192+
assert.Contains(t, errStr, "x-custom", "validation error should mention the duplicate extension key")
193+
}
194+
195+
func TestUnmarshal_DuplicateKey_RaceCondition(t *testing.T) {
196+
t.Parallel()
197+
198+
// This test verifies that duplicate keys don't cause race conditions
199+
// by testing concurrent unmarshalling with duplicate keys
200+
yaml := `stringField: "value1"
201+
boolField: true
202+
stringField: "value2"
203+
intField: 42
204+
stringField: "value3"
205+
float64Field: 3.14
206+
boolField: false
207+
intField: 100
208+
`
209+
210+
// Run multiple times to increase chance of catching race condition
211+
for i := 0; i < 10; i++ {
212+
t.Run("iteration", func(t *testing.T) {
213+
t.Parallel()
214+
215+
reader := strings.NewReader(yaml)
216+
model := &testmodels.TestPrimitiveHighModel{}
217+
validationErrs, err := marshaller.Unmarshal(t.Context(), reader, model)
218+
require.NoError(t, err, "unmarshal should not return a fatal error")
219+
220+
// Should have validation errors for duplicates
221+
assert.NotEmpty(t, validationErrs, "should have validation errors")
222+
223+
// Values should be consistent (last value wins)
224+
assert.Equal(t, "value3", model.StringField, "string field should have last value")
225+
assert.False(t, model.BoolField, "bool field should have last value")
226+
assert.Equal(t, 100, model.IntField, "int field should have last value")
227+
})
228+
}
229+
}
230+
231+
func TestUnmarshal_DuplicateKey_EmbeddedMap(t *testing.T) {
232+
t.Parallel()
233+
234+
yaml := `dynamicKey1: "value1"
235+
dynamicKey2: "value2"
236+
dynamicKey1: "value3"
237+
`
238+
239+
reader := strings.NewReader(yaml)
240+
model := &testmodels.TestEmbeddedMapHighModel{}
241+
validationErrs, err := marshaller.Unmarshal(t.Context(), reader, model)
242+
require.NoError(t, err, "unmarshal should not return a fatal error")
243+
244+
// Should have validation error for duplicate key
245+
assert.NotEmpty(t, validationErrs, "should have validation errors for duplicate key in embedded map")
246+
247+
errStr := joinErrors(validationErrs)
248+
assert.Contains(t, errStr, "dynamicKey1", "validation error should mention the duplicate key")
249+
250+
// Last value should win
251+
val, exists := model.Get("dynamicKey1")
252+
assert.True(t, exists, "key should exist in map")
253+
assert.Equal(t, "value3", val, "last value should be used")
254+
}
255+
256+
func TestUnmarshal_DuplicateKey_EmbeddedMapWithFields(t *testing.T) {
257+
t.Parallel()
258+
259+
yaml := `name: "test name"
260+
dynamicKey1:
261+
stringField: "first nested"
262+
boolField: true
263+
intField: 100
264+
float64Field: 1.23
265+
dynamicKey1:
266+
stringField: "second nested"
267+
boolField: false
268+
intField: 200
269+
float64Field: 4.56
270+
`
271+
272+
reader := strings.NewReader(yaml)
273+
model := &testmodels.TestEmbeddedMapWithFieldsHighModel{}
274+
validationErrs, err := marshaller.Unmarshal(t.Context(), reader, model)
275+
require.NoError(t, err, "unmarshal should not return a fatal error")
276+
277+
// Should have validation error for duplicate key
278+
assert.NotEmpty(t, validationErrs, "should have validation errors for duplicate key in embedded map with fields")
279+
280+
errStr := joinErrors(validationErrs)
281+
assert.Contains(t, errStr, "dynamicKey1", "validation error should mention the duplicate key")
282+
283+
// Last value should win
284+
val, exists := model.Get("dynamicKey1")
285+
assert.True(t, exists, "key should exist in map")
286+
require.NotNil(t, val, "value should not be nil")
287+
assert.Equal(t, "second nested", val.StringField, "last value's string field should be used")
288+
assert.False(t, val.BoolField, "last value's bool field should be used")
289+
assert.Equal(t, 200, val.IntField, "last value's int field should be used")
290+
}

marshaller/sequencedmap.go

Lines changed: 48 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,41 @@ func unmarshalSequencedMap(ctx context.Context, parentName string, node *yaml.No
3535

3636
target.Init()
3737

38+
// Pre-scan for duplicate keys to detect them before concurrent processing
39+
type keyInfo struct {
40+
firstLine int
41+
lastIndex int
42+
}
43+
seenKeys := make(map[string]*keyInfo)
44+
indicesToSkip := make(map[int]bool)
45+
var duplicateKeyErrs []error
46+
47+
for i := 0; i < len(resolvedNode.Content); i += 2 {
48+
keyNode := resolvedNode.Content[i]
49+
resolvedKeyNode := yml.ResolveAlias(keyNode)
50+
if resolvedKeyNode == nil {
51+
continue
52+
}
53+
key := resolvedKeyNode.Value
54+
55+
if existing, ok := seenKeys[key]; ok {
56+
// This is a duplicate key - mark the previous occurrence for skipping
57+
indicesToSkip[existing.lastIndex] = true
58+
// Create validation error for the earlier occurrence
59+
duplicateKeyErrs = append(duplicateKeyErrs, validation.NewValidationError(
60+
validation.NewValueValidationError("mapping key %q at line %d is a duplicate; previous definition at line %d", key, keyNode.Line, existing.firstLine),
61+
keyNode,
62+
))
63+
// Update to point to current (last) occurrence
64+
existing.lastIndex = i / 2
65+
} else {
66+
seenKeys[key] = &keyInfo{
67+
firstLine: keyNode.Line,
68+
lastIndex: i / 2,
69+
}
70+
}
71+
}
72+
3873
g, ctx := errgroup.WithContext(ctx)
3974

4075
numJobs := len(resolvedNode.Content) / 2
@@ -49,6 +84,11 @@ func unmarshalSequencedMap(ctx context.Context, parentName string, node *yaml.No
4984

5085
for i := 0; i < len(resolvedNode.Content); i += 2 {
5186
g.Go(func() error {
87+
// Skip duplicate keys (all but the last occurrence)
88+
if indicesToSkip[i/2] {
89+
return nil
90+
}
91+
5292
keyNode := resolvedNode.Content[i]
5393
valueNode := resolvedNode.Content[i+1]
5494

@@ -96,14 +136,21 @@ func unmarshalSequencedMap(ctx context.Context, parentName string, node *yaml.No
96136
return nil, err
97137
}
98138

99-
for _, keyPair := range valuesToSet {
139+
for i, keyPair := range valuesToSet {
140+
// Skip entries that were marked as duplicates
141+
if indicesToSkip[i] {
142+
continue
143+
}
100144
if err := target.SetUntyped(keyPair.key, keyPair.value); err != nil {
101145
return nil, err
102146
}
103147
}
104148

105149
var allValidationErrs []error
106150

151+
// Add duplicate key validation errors first
152+
allValidationErrs = append(allValidationErrs, duplicateKeyErrs...)
153+
107154
for _, jobValidationErrs := range jobsValidationErrs {
108155
allValidationErrs = append(allValidationErrs, jobValidationErrs...)
109156
}

0 commit comments

Comments
 (0)