Skip to content

Commit b442ce4

Browse files
authored
fix: Fix encoding of nested maps when generating encoders (#492)
1 parent 7b892d6 commit b442ce4

File tree

8 files changed

+164
-171
lines changed

8 files changed

+164
-171
lines changed

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/serde/json/MemberShapeDecodeGenerator.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,17 @@ abstract class MemberShapeDecodeGenerator(
7171
private fun determineSymbolForShape(currShape: Shape, topLevel: Boolean): String {
7272
var mappedSymbol = when (currShape) {
7373
is MapShape -> {
74+
val mapIsSparse = currShape.hasTrait<SparseTrait>()
7475
val targetShape = ctx.model.expectShape(currShape.value.target)
7576
val valueEvaluated = determineSymbolForShape(targetShape, topLevel)
76-
val terminator = if (topLevel) "?" else ""
77+
val terminator = if (topLevel || mapIsSparse) "?" else ""
7778
"[${SwiftTypes.String}: $valueEvaluated$terminator]"
7879
}
7980
is ListShape -> {
81+
val listIsSparse = currShape.hasTrait<SparseTrait>()
8082
val targetShape = ctx.model.expectShape(currShape.member.target)
8183
val nestedShape = determineSymbolForShape(targetShape, topLevel)
82-
val terminator = if (topLevel) "?" else ""
84+
val terminator = if (topLevel || listIsSparse) "?" else ""
8385
"[$nestedShape$terminator]"
8486
}
8587
is SetShape -> {

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/serde/json/MemberShapeEncodeGenerator.kt

Lines changed: 36 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import software.amazon.smithy.swift.codegen.integration.serde.getDefaultValueOfS
2929
import software.amazon.smithy.swift.codegen.model.hasTrait
3030
import software.amazon.smithy.swift.codegen.model.isBoxed
3131
import software.amazon.smithy.swift.codegen.removeSurroundingBackticks
32+
import software.amazon.smithy.swift.codegen.utils.toLowerCamelCase
3233

3334
/*
3435
Includes functions to help render conformance to Encodable protocol for shapes
@@ -86,28 +87,14 @@ abstract class MemberShapeEncodeGenerator(
8687
renderEncodeList(ctx, memberName, topLevelContainerName, targetShape, level)
8788
} else {
8889
writer.write("var \$L = $containerName.nestedUnkeyedContainer()", topLevelContainerName)
89-
val isSparse = targetShape.hasTrait<SparseTrait>()
90-
if (isSparse) {
91-
writer.openBlock("if let \$L = \$L {", "}", memberName, memberName) {
92-
renderEncodeList(ctx, memberName, topLevelContainerName, targetShape, level)
93-
}
94-
} else {
95-
renderEncodeList(ctx, memberName, topLevelContainerName, targetShape, level)
96-
}
90+
renderEncodeList(ctx, memberName, topLevelContainerName, targetShape, level)
9791
}
9892
}
9993
// this only gets called in a recursive loop where there is a map nested deeply inside a list
10094
is MapShape -> {
10195
val topLevelContainerName = "${memberName}Container"
10296
writer.write("var \$L = $containerName.nestedContainer(keyedBy: \$N.self)", topLevelContainerName, ClientRuntimeTypes.Serde.Key)
103-
val isSparse = targetShape.hasTrait<SparseTrait>()
104-
if (isSparse) {
105-
writer.openBlock("if let \$L = \$L {", "}", memberName, memberName) {
106-
renderEncodeMap(ctx, memberName, topLevelContainerName, targetShape, level)
107-
}
108-
} else {
109-
renderEncodeMap(ctx, memberName, topLevelContainerName, targetShape, level)
110-
}
97+
renderEncodeMap(ctx, memberName, topLevelContainerName, targetShape, level)
11198
}
11299
else -> {
113100
renderSimpleShape(targetShape, memberName, containerName, null, false)
@@ -151,24 +138,22 @@ abstract class MemberShapeEncodeGenerator(
151138
ctx: ProtocolGenerator.GenerationContext,
152139
collectionName: String,
153140
topLevelContainerName: String,
154-
targetShape: Shape,
141+
listShape: CollectionShape,
155142
level: Int = 0
156143
) {
144+
val listIsSparse = listShape.hasTrait<SparseTrait>()
145+
val targetShape = ctx.model.expectShape(listShape.member.target)
157146
val iteratorName = "${targetShape.id.name.lowercase()}$level"
158147
writer.openBlock("for $iteratorName in $collectionName {", "}") {
159-
when (targetShape) {
160-
is CollectionShape -> {
161-
val nestedTarget = ctx.model.expectShape(targetShape.member.target)
162-
renderEncodeListMember(nestedTarget, iteratorName, topLevelContainerName, level + 1)
148+
if (listIsSparse) {
149+
writer.openBlock("guard let \$L = \$L else {", "}", iteratorName, iteratorName) {
150+
writer.write("try \$L.encodeNil()", topLevelContainerName)
151+
writer.write("continue")
163152
}
164-
is MapShape -> {
165-
val nestedTarget = ctx.model.expectShape(targetShape.value.target)
166-
renderEncodeMapMember(
167-
nestedTarget,
168-
"${ClientRuntimeTypes.Serde.Key}(stringValue: $dictKey)",
169-
topLevelContainerName,
170-
level + 1
171-
)
153+
}
154+
when (targetShape) {
155+
is CollectionShape, is MapShape -> {
156+
renderEncodeListMember(targetShape, iteratorName, topLevelContainerName, level + 1)
172157
}
173158
else -> {
174159
val isBoxed = ctx.symbolProvider.toSymbol(targetShape).isBoxed() && targetShape.hasTrait<SparseTrait>()
@@ -186,31 +171,31 @@ abstract class MemberShapeEncodeGenerator(
186171

187172
// Render encoding of a member of Map type
188173
fun renderEncodeMapMember(targetShape: Shape, memberName: String, containerName: String, level: Int = 0) {
174+
val keyName = if (level == 0) ".$memberName" else "${ClientRuntimeTypes.Serde.Key}(stringValue: $dictKey${level - 1})"
189175
when (targetShape) {
190176
is CollectionShape -> {
191177
val topLevelContainerName = "${memberName}Container"
192-
writer.write("var \$L = $containerName.nestedUnkeyedContainer(forKey: \$N($dictKey${level - 1}))", topLevelContainerName, ClientRuntimeTypes.Serde.Key)
178+
writer.write("var \$L = $containerName.nestedUnkeyedContainer(forKey: \$L)", topLevelContainerName, keyName)
193179
renderEncodeList(ctx, memberName, topLevelContainerName, targetShape, level)
194180
}
195181
is MapShape -> {
196182
val topLevelContainerName = "${memberName}Container"
197183
writer.write(
198-
"var \$L = $containerName.nestedContainer(keyedBy: \$N.self, forKey: .\$L)",
184+
"var \$L = $containerName.nestedContainer(keyedBy: \$N.self, forKey: \$L)",
199185
topLevelContainerName,
200186
ClientRuntimeTypes.Serde.Key,
201-
memberName
187+
keyName
202188
)
203-
renderEncodeMap(ctx, memberName, topLevelContainerName, targetShape.value, level)
189+
renderEncodeMap(ctx, memberName, topLevelContainerName, targetShape, level)
204190
}
205191
else -> {
206192
val isBoxed = ctx.symbolProvider.toSymbol(targetShape).isBoxed() && targetShape.hasTrait<SparseTrait>()
207-
val keyEnumName = if (level == 0) ".$memberName" else "${ClientRuntimeTypes.Serde.Key}(stringValue: $dictKey${level - 1})"
208193
if (isBoxed && level == 0) {
209194
writer.openBlock("if let \$L = \$L {", "}", memberName, memberName) {
210-
renderSimpleShape(targetShape, memberName, containerName, keyEnumName, isBoxed)
195+
renderSimpleShape(targetShape, memberName, containerName, keyName, isBoxed)
211196
}
212197
} else {
213-
renderSimpleShape(targetShape, memberName, containerName, keyEnumName, isBoxed)
198+
renderSimpleShape(targetShape, memberName, containerName, keyName, isBoxed)
214199
}
215200
}
216201
}
@@ -221,37 +206,27 @@ abstract class MemberShapeEncodeGenerator(
221206
ctx: ProtocolGenerator.GenerationContext,
222207
mapName: String,
223208
topLevelContainerName: String,
224-
valueTargetShape: Shape,
209+
mapShape: MapShape,
225210
level: Int = 0
226211
) {
227-
val valueIterator = "${valueTargetShape.id.name.toLowerCase()}$level"
228-
val target = when (valueTargetShape) {
229-
is MemberShape -> ctx.model.expectShape(valueTargetShape.target)
230-
else -> valueTargetShape
231-
}
232-
writer.openBlock("for ($dictKey$level, $valueIterator) in $mapName {", "}") {
233-
when (target) {
234-
is CollectionShape -> {
235-
val nestedTarget = ctx.model.expectShape(target.member.target)
236-
renderEncodeMapMember(
237-
nestedTarget,
238-
valueIterator,
239-
topLevelContainerName,
240-
level + 1
241-
)
212+
val keyIterator = "$dictKey$level"
213+
val valueIterator = "${mapShape.id.name.toLowerCamelCase()}$level"
214+
val target = ctx.model.expectShape(mapShape.value.target)
215+
val mapIsSparse = mapShape.hasTrait<SparseTrait>()
216+
writer.openBlock("for ($keyIterator, $valueIterator) in $mapName {", "}") {
217+
if (mapIsSparse) {
218+
writer.openBlock("guard let \$L = \$L else {", "}", valueIterator, valueIterator) {
219+
writer.write("try \$L.encodeNil(forKey: \$L(stringValue: \$L))", topLevelContainerName, ClientRuntimeTypes.Serde.Key, keyIterator)
220+
writer.write("continue")
242221
}
243-
is MapShape -> {
244-
val nestedTarget = ctx.model.expectShape(target.value.target)
245-
renderEncodeMapMember(
246-
nestedTarget,
247-
valueIterator,
248-
topLevelContainerName,
249-
level + 1
250-
)
222+
}
223+
when (target) {
224+
is CollectionShape, is MapShape -> {
225+
renderEncodeMapMember(target, valueIterator, topLevelContainerName, level + 1)
251226
}
252227
else -> {
253228
val keyEnumName = "${ClientRuntimeTypes.Serde.Key}(stringValue: $dictKey$level)"
254-
renderSimpleShape(valueTargetShape, valueIterator, topLevelContainerName, keyEnumName, valueTargetShape.hasTrait(BoxTrait::class.java))
229+
renderSimpleShape(target, valueIterator, topLevelContainerName, keyEnumName, target.hasTrait(BoxTrait::class.java))
255230
}
256231
}
257232
}

smithy-swift-codegen/src/test/kotlin/StructDecodeGenerationTests.kt

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ class StructDecodeGenerationTests {
7777
fun `it decodes nested documents with aggregate shapes`() {
7878
val contents = getModelFileContents("example", "Nested4+Codable.swift", newTestContext.manifest)
7979
contents.shouldSyntacticSanityCheck()
80-
val expectedContents =
81-
"""
80+
val expectedContents = """
8281
extension ExampleClientTypes.Nested4: Swift.Codable {
8382
enum CodingKeys: Swift.String, Swift.CodingKey {
8483
case intList
@@ -91,23 +90,26 @@ class StructDecodeGenerationTests {
9190
var encodeContainer = encoder.container(keyedBy: CodingKeys.self)
9291
if let intList = intList {
9392
var intListContainer = encodeContainer.nestedUnkeyedContainer(forKey: .intList)
94-
for intlist0 in intList {
95-
try intListContainer.encode(intlist0)
93+
for integer0 in intList {
94+
try intListContainer.encode(integer0)
9695
}
9796
}
9897
if let intMap = intMap {
9998
var intMapContainer = encodeContainer.nestedContainer(keyedBy: ClientRuntime.Key.self, forKey: .intMap)
100-
for (dictKey0, intmap0) in intMap {
101-
try intMapContainer.encode(intmap0, forKey: ClientRuntime.Key(stringValue: dictKey0))
99+
for (dictKey0, intMap0) in intMap {
100+
try intMapContainer.encode(intMap0, forKey: ClientRuntime.Key(stringValue: dictKey0))
102101
}
103102
}
104103
if let member1 = self.member1 {
105104
try encodeContainer.encode(member1, forKey: .member1)
106105
}
107106
if let stringMap = stringMap {
108107
var stringMapContainer = encodeContainer.nestedContainer(keyedBy: ClientRuntime.Key.self, forKey: .stringMap)
109-
for (dictKey0, nestedstringmap0) in stringMap {
110-
try stringMapContainer.encode(nestedstringmap0, forKey: ClientRuntime.Key(stringValue: dictKey0))
108+
for (dictKey0, nestedStringMap0) in stringMap {
109+
var nestedStringMap0Container = stringMapContainer.nestedUnkeyedContainer(forKey: ClientRuntime.Key(stringValue: dictKey0))
110+
for string1 in nestedStringMap0 {
111+
try nestedStringMap0Container.encode(string1)
112+
}
111113
}
112114
}
113115
}
@@ -158,7 +160,7 @@ class StructDecodeGenerationTests {
158160
stringMap = stringMapDecoded0
159161
}
160162
}
161-
""".trimIndent()
163+
""".trimIndent()
162164
contents.shouldContainOnlyOnce(expectedContents)
163165
}
164166

smithy-swift-codegen/src/test/kotlin/StructEncodeGenerationIsolatedTests.kt

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ class StructEncodeGenerationIsolatedTests {
9696
val contents = getFileContents(context.manifest, "/example/models/JsonListsInput+Encodable.swift")
9797
contents.shouldSyntacticSanityCheck()
9898

99-
val expectedContents =
100-
"""
99+
val expectedContents = """
101100
extension JsonListsInput: Swift.Encodable {
102101
enum CodingKeys: Swift.String, Swift.CodingKey {
103102
case nestedStringList
@@ -109,28 +108,28 @@ class StructEncodeGenerationIsolatedTests {
109108
var encodeContainer = encoder.container(keyedBy: CodingKeys.self)
110109
if let nestedStringList = nestedStringList {
111110
var nestedStringListContainer = encodeContainer.nestedUnkeyedContainer(forKey: .nestedStringList)
112-
for nestedstringlist0 in nestedStringList {
113-
var nestedstringlist0Container = nestedStringListContainer.nestedUnkeyedContainer()
114-
for stringlist1 in nestedstringlist0 {
115-
try nestedstringlist0Container.encode(stringlist1)
111+
for stringlist0 in nestedStringList {
112+
var stringlist0Container = nestedStringListContainer.nestedUnkeyedContainer()
113+
for string1 in stringlist0 {
114+
try stringlist0Container.encode(string1)
116115
}
117116
}
118117
}
119118
if let stringList = stringList {
120119
var stringListContainer = encodeContainer.nestedUnkeyedContainer(forKey: .stringList)
121-
for stringlist0 in stringList {
122-
try stringListContainer.encode(stringlist0)
120+
for string0 in stringList {
121+
try stringListContainer.encode(string0)
123122
}
124123
}
125124
if let stringSet = stringSet {
126125
var stringSetContainer = encodeContainer.nestedUnkeyedContainer(forKey: .stringSet)
127-
for stringset0 in stringSet {
128-
try stringSetContainer.encode(stringset0)
126+
for string0 in stringSet {
127+
try stringSetContainer.encode(string0)
129128
}
130129
}
131130
}
132131
}
133-
""".trimIndent()
132+
""".trimIndent()
134133
contents.shouldContainOnlyOnce(expectedContents)
135134
}
136135

0 commit comments

Comments
 (0)