Skip to content

Commit 277a77c

Browse files
authored
fix: correct deserialization of nested map/list/types in unions (#1144)
1 parent 932b579 commit 277a77c

File tree

4 files changed

+28
-14
lines changed

4 files changed

+28
-14
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "b786851c-9427-40cd-b3fa-ee375011d931",
3+
"type": "bugfix",
4+
"description": "Correct deserialization of nested map/list types in unions",
5+
"issues": [
6+
"awslabs/smithy-kotlin#1126"
7+
]
8+
}

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ open class DeserializeStructGenerator(
4040
/**
4141
* Enables overriding the codegen output of the final value resulting
4242
* from the deserialization of a non-primitive type.
43-
* @param memberShape [MemberShape] associated with entry
43+
* @param forMemberShape [MemberShape] associated with entry, if any
4444
* @param defaultCollectionName the default value produced by this class.
4545
*/
46-
open fun collectionReturnExpression(memberShape: MemberShape, defaultCollectionName: String): String = defaultCollectionName
46+
open fun collectionReturnExpression(forMemberShape: MemberShape?, defaultCollectionName: String): String =
47+
defaultCollectionName
4748

4849
/**
4950
* Enables overriding of the lhs expression into which a deserialization operation's
@@ -292,7 +293,7 @@ open class DeserializeStructGenerator(
292293
val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
293294
val nextNestingLevel = nestingLevel + 1
294295
val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
295-
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
296+
val collectionReturnExpression = collectionReturnExpression(null, memberName)
296297

297298
writeKeyVal(keyShape, keySymbol, keyName)
298299
writer.withBlock("val $valueName =", "") {
@@ -346,7 +347,7 @@ open class DeserializeStructGenerator(
346347
val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
347348
val nextNestingLevel = nestingLevel + 1
348349
val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
349-
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
350+
val collectionReturnExpression = collectionReturnExpression(null, memberName)
350351

351352
writeKeyVal(keyShape, keySymbol, keyName)
352353
writer.withBlock("val $valueName =", "") {
@@ -516,7 +517,7 @@ open class DeserializeStructGenerator(
516517
val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
517518
val nextNestingLevel = nestingLevel + 1
518519
val mapName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
519-
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, mapName)
520+
val collectionReturnExpression = collectionReturnExpression(null, mapName)
520521

521522
writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
522523
write(
@@ -555,7 +556,7 @@ open class DeserializeStructGenerator(
555556
val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
556557
val nextNestingLevel = nestingLevel + 1
557558
val listName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
558-
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, listName)
559+
val collectionReturnExpression = collectionReturnExpression(null, listName)
559560

560561
writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
561562
write(

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeUnionGenerator.kt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,13 @@ class DeserializeUnionGenerator(
8787
override fun deserializationResultName(defaultName: String): String = "value"
8888

8989
// Return the type that deserializes the incoming value. Example: `MyAggregateUnion.IntList`
90-
override fun collectionReturnExpression(memberShape: MemberShape, defaultCollectionName: String): String {
91-
val unionTypeName = memberShape.unionTypeName(ctx)
92-
return "$unionTypeName($defaultCollectionName)"
93-
}
90+
override fun collectionReturnExpression(forMemberShape: MemberShape?, defaultCollectionName: String) =
91+
if (forMemberShape != null && forMemberShape in members) {
92+
// We're returning a top-level collection for a member value—nest it inside a union variant
93+
val unionTypeName = forMemberShape.unionTypeName(ctx)
94+
"$unionTypeName($defaultCollectionName)"
95+
} else {
96+
// We're returning a nested collection type—don't nest it inside a union variant
97+
super.collectionReturnExpression(null, defaultCollectionName)
98+
}
9499
}

codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeUnionGeneratorTest.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,11 @@ class DeserializeUnionGeneratorTest {
187187
val v2 = if (nextHasValue()) { deserializeBarUnionDocument(deserializer) } else { deserializeNull(); continue }
188188
map2[k2] = v2
189189
}
190-
FooUnion.StrMapVal(map2)
190+
map2
191191
}
192192
col1.add(el1)
193193
}
194-
FooUnion.StrMapVal(col1)
194+
col1
195195
}
196196
} else { deserializeNull(); continue }
197197
@@ -269,7 +269,7 @@ class DeserializeUnionGeneratorTest {
269269
val el1 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
270270
col1.add(el1)
271271
}
272-
MyAggregateUnion.ListOfIntList(col1)
272+
col1
273273
}
274274
col0.add(el0)
275275
}
@@ -288,7 +288,7 @@ class DeserializeUnionGeneratorTest {
288288
val el1 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
289289
col1.add(el1)
290290
}
291-
MyAggregateUnion.MapOfLists(col1)
291+
col1
292292
}
293293
} else { deserializeNull(); continue }
294294

0 commit comments

Comments
 (0)