Skip to content

Commit a9aa47c

Browse files
authored
fix: ignore __type when deserializing union for AWS Json protocols (#964)
1 parent 08721de commit a9aa47c

File tree

5 files changed

+218
-2
lines changed

5 files changed

+218
-2
lines changed

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ object RuntimeTypes {
237237
val JsonSerialName = symbol("JsonSerialName")
238238
val JsonSerializer = symbol("JsonSerializer")
239239
val JsonDeserializer = symbol("JsonDeserializer")
240+
val IgnoreKey = symbol("IgnoreKey")
240241
}
241242

242243
object SerdeXml : RuntimeTypePackage(KotlinDependency.SERDE_XML) {

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ open class JsonParserGenerator(
2424

2525
open val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS
2626

27+
open fun descriptorGenerator(
28+
ctx: ProtocolGenerator.GenerationContext,
29+
shape: Shape,
30+
members: List<MemberShape>,
31+
writer: KotlinWriter,
32+
): JsonSerdeDescriptorGenerator = JsonSerdeDescriptorGenerator(
33+
ctx.toRenderingContext(protocolGenerator, shape, writer),
34+
members,
35+
supportsJsonNameTrait,
36+
)
37+
2738
override fun operationDeserializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, members: List<MemberShape>): Symbol {
2839
val outputSymbol = op.output.get().let { ctx.symbolProvider.toSymbol(ctx.model.expectShape(it)) }
2940
return op.bodyDeserializer(ctx.settings) { writer ->
@@ -127,7 +138,7 @@ open class JsonParserGenerator(
127138
members: List<MemberShape>,
128139
writer: KotlinWriter,
129140
) {
130-
JsonSerdeDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members, supportsJsonNameTrait).render()
141+
descriptorGenerator(ctx, shape, members, writer).render()
131142
if (shape.isUnionShape) {
132143
val name = ctx.symbolProvider.toSymbol(shape).name
133144
DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render()

runtime/serde/serde-json/common/src/aws/smithy/kotlin/runtime/serde/json/JsonDeserializer.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,13 @@ private class JsonFieldIterator(
198198
val token = reader.nextTokenOf<JsonToken.Name>()
199199
val propertyName = token.value
200200
val field = descriptor.fields.find { it.serialName == propertyName }
201-
field?.index ?: Deserializer.FieldIterator.UNKNOWN_FIELD
201+
202+
if (IgnoreKey(propertyName) in descriptor.traits) {
203+
reader.skipNext() // the value of the ignored key
204+
return findNextFieldIndex()
205+
} else {
206+
field?.index ?: Deserializer.FieldIterator.UNKNOWN_FIELD
207+
}
202208
}
203209
}
204210

runtime/serde/serde-json/common/src/aws/smithy/kotlin/runtime/serde/json/JsonFieldTraits.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,9 @@ public data class JsonSerialName(public val name: String) : FieldTrait
2222
@InternalApi
2323
public val SdkFieldDescriptor.serialName: String
2424
get() = expectTrait<JsonSerialName>().name
25+
26+
/**
27+
* Indicates to deserializers to ignore field/key
28+
*/
29+
@InternalApi
30+
public data class IgnoreKey(public val key: String) : FieldTrait
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package aws.smithy.kotlin.runtime.serde.json
7+
8+
import aws.smithy.kotlin.runtime.serde.SdkFieldDescriptor
9+
import aws.smithy.kotlin.runtime.serde.SdkObjectDescriptor
10+
import aws.smithy.kotlin.runtime.serde.SerialKind
11+
import aws.smithy.kotlin.runtime.serde.deserializeStruct
12+
import kotlin.test.Test
13+
import kotlin.test.assertEquals
14+
15+
class JsonDeserializerIgnoresKeysTest {
16+
private val X_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("x"))
17+
private val Y_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("y"))
18+
private val Z_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("z"))
19+
private val OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
20+
trait(IgnoreKey("z"))
21+
field(X_DESCRIPTOR)
22+
field(Y_DESCRIPTOR)
23+
field(Z_DESCRIPTOR)
24+
}
25+
26+
@Test
27+
fun itIgnoresKeys() {
28+
val payload = """
29+
{
30+
"x": 1,
31+
"y": 2,
32+
"z": 3
33+
}
34+
""".trimIndent().encodeToByteArray()
35+
36+
val deserializer = JsonDeserializer(payload)
37+
var x: Int? = null
38+
var y: Int? = null
39+
var z: Int? = null
40+
deserializer.deserializeStruct(OBJ_DESCRIPTOR) {
41+
loop@ while (true) {
42+
when (findNextFieldIndex()) {
43+
X_DESCRIPTOR.index -> x = deserializeInt()
44+
Y_DESCRIPTOR.index -> y = deserializeInt()
45+
Z_DESCRIPTOR.index -> z = deserializeInt()
46+
null -> break@loop
47+
}
48+
}
49+
}
50+
51+
assertEquals(1, x)
52+
assertEquals(2, y)
53+
assertEquals(null, z)
54+
}
55+
56+
@Test
57+
fun itIgnoresKeysOutOfOrder() {
58+
val payload = """
59+
{
60+
"z": 3,
61+
"x": 1,
62+
"y": 2
63+
}
64+
""".trimIndent().encodeToByteArray()
65+
66+
val deserializer = JsonDeserializer(payload)
67+
var x: Int? = null
68+
var y: Int? = null
69+
var z: Int? = null
70+
deserializer.deserializeStruct(OBJ_DESCRIPTOR) {
71+
loop@ while (true) {
72+
when (findNextFieldIndex()) {
73+
X_DESCRIPTOR.index -> x = deserializeInt()
74+
Y_DESCRIPTOR.index -> y = deserializeInt()
75+
Z_DESCRIPTOR.index -> z = deserializeInt()
76+
null -> break@loop
77+
}
78+
}
79+
}
80+
81+
assertEquals(1, x)
82+
assertEquals(2, y)
83+
assertEquals(null, z)
84+
}
85+
86+
@Test
87+
fun itIgnoresKeysManyTimes() {
88+
val payload = """
89+
{
90+
"x": 1,
91+
"y": 2,
92+
"z": 3,
93+
"z": 3,
94+
"z": 3
95+
}
96+
""".trimIndent().encodeToByteArray()
97+
98+
val deserializer = JsonDeserializer(payload)
99+
var x: Int? = null
100+
var y: Int? = null
101+
var z: Int? = null
102+
deserializer.deserializeStruct(OBJ_DESCRIPTOR) {
103+
loop@ while (true) {
104+
when (findNextFieldIndex()) {
105+
X_DESCRIPTOR.index -> x = deserializeInt()
106+
Y_DESCRIPTOR.index -> y = deserializeInt()
107+
Z_DESCRIPTOR.index -> z = deserializeInt()
108+
null -> break@loop
109+
}
110+
}
111+
}
112+
113+
assertEquals(1, x)
114+
assertEquals(2, y)
115+
assertEquals(null, z)
116+
}
117+
118+
private val MISSING_KEYS_OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
119+
trait(IgnoreKey("x"))
120+
field(Y_DESCRIPTOR)
121+
}
122+
123+
@Test
124+
fun itIgnoresKeysNotInModel() {
125+
val payload = """
126+
{
127+
"x": 1,
128+
"y": 2
129+
}
130+
""".trimIndent().encodeToByteArray()
131+
132+
val deserializer = JsonDeserializer(payload)
133+
var x: Int? = null
134+
var y: Int? = null
135+
deserializer.deserializeStruct(MISSING_KEYS_OBJ_DESCRIPTOR) {
136+
loop@ while (true) {
137+
when (findNextFieldIndex()) {
138+
Y_DESCRIPTOR.index -> y = deserializeInt()
139+
null -> break@loop
140+
else -> x = deserializeInt()
141+
}
142+
}
143+
}
144+
145+
assertEquals(null, x)
146+
assertEquals(2, y)
147+
}
148+
149+
private val W_DESCRIPTOR = SdkFieldDescriptor(SerialKind.Integer, JsonSerialName("w"))
150+
private val MULT_KEYS_OBJ_DESCRIPTOR = SdkObjectDescriptor.build {
151+
trait(IgnoreKey("w"))
152+
trait(IgnoreKey("z"))
153+
field(W_DESCRIPTOR)
154+
field(X_DESCRIPTOR)
155+
field(Y_DESCRIPTOR)
156+
field(Z_DESCRIPTOR)
157+
}
158+
159+
@Test
160+
fun itIgnoresMultipleKeys() {
161+
val payload = """
162+
{
163+
"w": 0,
164+
"x": 1,
165+
"y": 2,
166+
"z": 3
167+
}
168+
""".trimIndent().encodeToByteArray()
169+
170+
val deserializer = JsonDeserializer(payload)
171+
var w: Int? = null
172+
var x: Int? = null
173+
var y: Int? = null
174+
var z: Int? = null
175+
deserializer.deserializeStruct(MULT_KEYS_OBJ_DESCRIPTOR) {
176+
loop@ while (true) {
177+
when (findNextFieldIndex()) {
178+
W_DESCRIPTOR.index -> w = deserializeInt()
179+
X_DESCRIPTOR.index -> x = deserializeInt()
180+
Y_DESCRIPTOR.index -> y = deserializeInt()
181+
Z_DESCRIPTOR.index -> z = deserializeInt()
182+
null -> break@loop
183+
}
184+
}
185+
}
186+
187+
assertEquals(null, w)
188+
assertEquals(1, x)
189+
assertEquals(2, y)
190+
assertEquals(null, z)
191+
}
192+
}

0 commit comments

Comments
 (0)