Skip to content

Commit 0c3c5da

Browse files
fix: fixed sync of extensions when creating new models
1 parent 47cbcb0 commit 0c3c5da

File tree

3 files changed

+73
-17
lines changed

3 files changed

+73
-17
lines changed

extensions/extensions.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ func (e *Extensions) SetCore(core any) {
6262
e.core = c
6363
}
6464

65+
func (e *Extensions) GetCore() *sequencedmap.Map[string, marshaller.Node[*yaml.Node]] {
66+
return e.core
67+
}
68+
6569
// UnmarshalExtensionModel will unmarshal the extension into a model and its associated core model.
6670
func UnmarshalExtensionModel[H any, L any](ctx context.Context, e *Extensions, ext string, m *H) error {
6771
if e == nil {

marshaller/extensions.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"iter"
77
"reflect"
88
"slices"
9+
"unsafe"
910

1011
"github.com/speakeasy-api/openapi/errors"
1112
"github.com/speakeasy-api/openapi/yml"
@@ -110,5 +111,14 @@ func syncExtensions(ctx context.Context, source any, target reflect.Value, mapNo
110111
}
111112
}
112113

114+
sUnderlying := getUnderlyingValue(reflect.ValueOf(source))
115+
116+
// Update the core of the source with the updated value
117+
cf, ok := sUnderlying.Type().FieldByName("core")
118+
if ok {
119+
sf := sUnderlying.FieldByIndex(cf.Index)
120+
reflect.NewAt(sf.Type(), unsafe.Pointer(sf.UnsafeAddr())).Elem().Set(target)
121+
}
122+
113123
return mapNode, nil
114124
}

marshaller/syncer_test.go

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
package marshaller
1+
package marshaller_test
22

33
import (
44
"context"
55
"fmt"
66
"reflect"
77
"testing"
88

9+
"github.com/speakeasy-api/openapi/extensions"
10+
"github.com/speakeasy-api/openapi/extensions/core"
911
"github.com/speakeasy-api/openapi/internal/testutils"
12+
"github.com/speakeasy-api/openapi/marshaller"
1013
"github.com/speakeasy-api/openapi/pointer"
1114
"github.com/stretchr/testify/assert"
1215
"github.com/stretchr/testify/require"
@@ -15,23 +18,23 @@ import (
1518

1619
func TestSyncValue_String(t *testing.T) {
1720
target := ""
18-
outNode, err := SyncValue(context.Background(), "some-value", &target, nil, false)
21+
outNode, err := marshaller.SyncValue(context.Background(), "some-value", &target, nil, false)
1922
require.NoError(t, err)
2023
assert.Equal(t, testutils.CreateStringYamlNode("some-value", 0, 0), outNode)
2124
assert.Equal(t, "some-value", target)
2225
}
2326

2427
func TestSyncValue_StringPtrSet(t *testing.T) {
2528
target := pointer.From("")
26-
outNode, err := SyncValue(context.Background(), pointer.From("some-value"), &target, nil, false)
29+
outNode, err := marshaller.SyncValue(context.Background(), pointer.From("some-value"), &target, nil, false)
2730
require.NoError(t, err)
2831
assert.Equal(t, testutils.CreateStringYamlNode("some-value", 0, 0), outNode)
2932
assert.Equal(t, "some-value", *target)
3033
}
3134

3235
func TestSyncValue_StringPtrNil(t *testing.T) {
3336
var target *string
34-
outNode, err := SyncValue(context.Background(), pointer.From("some-value"), &target, nil, false)
37+
outNode, err := marshaller.SyncValue(context.Background(), pointer.From("some-value"), &target, nil, false)
3538
require.NoError(t, err)
3639
assert.Equal(t, testutils.CreateStringYamlNode("some-value", 0, 0), outNode)
3740
assert.Equal(t, "some-value", *target)
@@ -57,7 +60,7 @@ func (t *TestStructSyncerCore[T]) SyncChanges(ctx context.Context, model any, va
5760
}
5861

5962
var err error
60-
t.RootNode, err = SyncValue(ctx, mv.FieldByName("Val").Interface(), &t.Val, valueNode, false)
63+
t.RootNode, err = marshaller.SyncValue(ctx, mv.FieldByName("Val").Interface(), &t.Val, valueNode, false)
6164
return t.RootNode, err
6265
}
6366

@@ -66,7 +69,7 @@ func TestSyncValue_StructPtr_CustomSyncer(t *testing.T) {
6669

6770
source := &TestStructSyncer[int]{Val: pointer.From(1)}
6871

69-
outNode, err := SyncValue(context.Background(), source, &target, nil, false)
72+
outNode, err := marshaller.SyncValue(context.Background(), source, &target, nil, false)
7073
require.NoError(t, err)
7174
node := testutils.CreateIntYamlNode(1, 0, 0)
7275
assert.Equal(t, node, outNode)
@@ -79,7 +82,7 @@ func TestSyncValue_Struct_CustomSyncer(t *testing.T) {
7982

8083
source := TestStructSyncer[int]{Val: pointer.From(1)}
8184

82-
outNode, err := SyncValue(context.Background(), source, &target, nil, false)
85+
outNode, err := marshaller.SyncValue(context.Background(), source, &target, nil, false)
8386
require.NoError(t, err)
8487
node := testutils.CreateIntYamlNode(1, 0, 0)
8588
assert.Equal(t, node, outNode)
@@ -96,10 +99,10 @@ type TestStruct struct {
9699
}
97100

98101
type TestStructCore struct {
99-
Int Node[int] `key:"int"`
100-
Str Node[string] `key:"str"`
101-
StrPtr Node[*string] `key:"strPtr"`
102-
BoolPtr Node[*bool] `key:"boolPtr"`
102+
Int marshaller.Node[int] `key:"int"`
103+
Str marshaller.Node[string] `key:"str"`
104+
StrPtr marshaller.Node[*string] `key:"strPtr"`
105+
BoolPtr marshaller.Node[*bool] `key:"boolPtr"`
103106

104107
RootNode *yaml.Node
105108
}
@@ -112,7 +115,7 @@ func TestSyncChanges_Struct(t *testing.T) {
112115
BoolPtr: pointer.From(true),
113116
}
114117

115-
outNode, err := SyncValue(context.Background(), &source, &source.core, nil, false)
118+
outNode, err := marshaller.SyncValue(context.Background(), &source, &source.core, nil, false)
116119
require.NoError(t, err)
117120

118121
node := testutils.CreateMapYamlNode([]*yaml.Node{
@@ -140,7 +143,7 @@ func TestSyncChanges_StructWithOptionalsUnset(t *testing.T) {
140143
Str: "some-string",
141144
}
142145

143-
outNode, err := SyncValue(context.Background(), &source, &source.core, nil, false)
146+
outNode, err := marshaller.SyncValue(context.Background(), &source, &source.core, nil, false)
144147
require.NoError(t, err)
145148

146149
node := testutils.CreateMapYamlNode([]*yaml.Node{
@@ -166,7 +169,7 @@ func TestSyncChanges_StructPtr(t *testing.T) {
166169
BoolPtr: pointer.From(true),
167170
}
168171

169-
outNode, err := SyncValue(context.Background(), &source, &source.core, nil, false)
172+
outNode, err := marshaller.SyncValue(context.Background(), &source, &source.core, nil, false)
170173
require.NoError(t, err)
171174

172175
node := testutils.CreateMapYamlNode([]*yaml.Node{
@@ -195,7 +198,7 @@ type TestStructNested struct {
195198
}
196199

197200
type TestStructNestedCore struct {
198-
TestStruct Node[TestStructCore] `key:"testStruct"`
201+
TestStruct marshaller.Node[TestStructCore] `key:"testStruct"`
199202

200203
RootNode *yaml.Node
201204
}
@@ -210,7 +213,7 @@ func TestSyncChanges_NestedStruct(t *testing.T) {
210213
},
211214
}
212215

213-
outNode, err := SyncValue(context.Background(), &source, &source.core, nil, false)
216+
outNode, err := marshaller.SyncValue(context.Background(), &source, &source.core, nil, false)
214217
require.NoError(t, err)
215218

216219
nestedNode := testutils.CreateMapYamlNode([]*yaml.Node{
@@ -242,8 +245,47 @@ type TestInt int
242245

243246
func TestSyncValue_TypeDefinition(t *testing.T) {
244247
var target TestInt
245-
outNode, err := SyncValue(context.Background(), 1, &target, nil, false)
248+
outNode, err := marshaller.SyncValue(context.Background(), 1, &target, nil, false)
246249
require.NoError(t, err)
247250
assert.Equal(t, testutils.CreateIntYamlNode(1, 0, 0), outNode)
248251
assert.Equal(t, TestInt(1), target)
249252
}
253+
254+
type TestStructWithExtensions struct {
255+
Extensions *extensions.Extensions
256+
257+
core TestStructWithExtensionsCore
258+
}
259+
260+
type TestStructWithExtensionsCore struct {
261+
Extensions core.Extensions `key:"extensions"`
262+
263+
RootNode *yaml.Node
264+
}
265+
266+
func TestSyncValue_TypeWithExtensions(t *testing.T) {
267+
var source TestStructWithExtensions
268+
269+
extensionNode := testutils.CreateMapYamlNode(
270+
[]*yaml.Node{
271+
testutils.CreateStringYamlNode("name", 0, 0),
272+
testutils.CreateStringYamlNode("test", 0, 0),
273+
testutils.CreateStringYamlNode("value", 0, 0),
274+
testutils.CreateIntYamlNode(1, 0, 0),
275+
}, 0, 0)
276+
277+
source.Extensions = extensions.New(extensions.NewElem("x-speakeasy-test", extensionNode))
278+
279+
outNode, err := marshaller.SyncValue(context.Background(), &source, &source.core, nil, false)
280+
require.NoError(t, err)
281+
282+
node := testutils.CreateMapYamlNode(
283+
[]*yaml.Node{
284+
testutils.CreateStringYamlNode("x-speakeasy-test", 0, 0),
285+
extensionNode,
286+
}, 0, 0)
287+
288+
assert.Equal(t, node, outNode)
289+
assert.Equal(t, node, source.core.RootNode)
290+
assert.True(t, source.Extensions.GetCore().Has("x-speakeasy-test"))
291+
}

0 commit comments

Comments
 (0)