Skip to content

Commit b9b6443

Browse files
author
Ed Paulosky
authored
fix: Properly SerDe recusive union shapes (enums) (#474)
* Prevents Union shapes from being boxed during preprocessing * Fixes up tests
1 parent 60de8a4 commit b9b6443

File tree

4 files changed

+84
-9
lines changed

4 files changed

+84
-9
lines changed

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/model/RecursiveShapeBoxer.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@ import software.amazon.smithy.model.shapes.MapShape
1212
import software.amazon.smithy.model.shapes.MemberShape
1313
import software.amazon.smithy.model.shapes.SetShape
1414
import software.amazon.smithy.model.shapes.Shape
15+
import software.amazon.smithy.model.shapes.UnionShape
1516
import software.amazon.smithy.model.transform.ModelTransformer
1617
import software.amazon.smithy.swift.codegen.customtraits.SwiftBoxTrait
1718

1819
object RecursiveShapeBoxer {
1920
/**
2021
* Transform a model which may contain recursive shapes into a model annotated with [SwiftBoxTrait]
2122
*
22-
* When recursive shapes do NOT go through a List, Map, or Set, they must be boxed in Swift. This function will
23+
* When recursive shapes do NOT go through a List, Map, Union, or Set, they must be boxed in Swift. This function will
2324
* iteratively find loops & add the `SwiftBoxTrait` trait in a deterministic way until it reaches a fixed point.
2425
*
2526
* This function MUST be deterministic (always choose the same shapes to `Box`). If it is not, that is a bug. Even so
@@ -84,6 +85,7 @@ object RecursiveShapeBoxer {
8485
when (it) {
8586
is ListShape,
8687
is MapShape,
88+
is UnionShape,
8789
is SetShape -> true
8890
else -> it.hasTrait(SwiftBoxTrait::class.java)
8991
}

smithy-swift-codegen/src/test/kotlin/UnionEncodeGeneratorTests.kt

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,27 @@
66
import io.kotest.matchers.string.shouldContainOnlyOnce
77
import org.junit.jupiter.api.Assertions
88
import org.junit.jupiter.api.Test
9+
import software.amazon.smithy.model.Model
910
import software.amazon.smithy.swift.codegen.model.AddOperationShapes
11+
import software.amazon.smithy.swift.codegen.model.HashableShapeTransformer
12+
import software.amazon.smithy.swift.codegen.model.NestedShapeTransformer
13+
import software.amazon.smithy.swift.codegen.model.RecursiveShapeBoxer
1014

1115
class UnionEncodeGeneratorTests {
1216
var model = javaClass.getResource("http-binding-protocol-generator-test.smithy").asSmithy()
1317
private fun newTestContext(): TestContext {
14-
val settings = model.defaultSettings()
15-
model = AddOperationShapes.execute(model, settings.getService(model), settings.moduleName)
18+
model = preprocessModel(model)
1619
return model.newTestContext()
1720
}
21+
private fun preprocessModel(model: Model): Model {
22+
val settings = model.defaultSettings()
23+
var resolvedModel = model
24+
resolvedModel = AddOperationShapes.execute(resolvedModel, settings.getService(resolvedModel), settings.moduleName)
25+
resolvedModel = RecursiveShapeBoxer.transform(resolvedModel)
26+
resolvedModel = HashableShapeTransformer.transform(resolvedModel)
27+
resolvedModel = NestedShapeTransformer.transform(resolvedModel, settings.getService(resolvedModel))
28+
return resolvedModel
29+
}
1830
val newTestContext = newTestContext()
1931
init {
2032
newTestContext.generator.generateSerializers(newTestContext.generationCtx)
@@ -60,7 +72,7 @@ class UnionEncodeGeneratorTests {
6072
contents.shouldSyntacticSanityCheck()
6173
val expectedContents =
6274
"""
63-
extension MyUnion: Swift.Codable {
75+
extension ExampleClientTypes.MyUnion: Swift.Codable {
6476
enum CodingKeys: Swift.String, Swift.CodingKey {
6577
case blobvalue = "blobValue"
6678
case booleanvalue = "booleanValue"
@@ -141,7 +153,7 @@ class UnionEncodeGeneratorTests {
141153
self = .inheritedtimestamp(inheritedtimestamp)
142154
return
143155
}
144-
let enumvalueDecoded = try values.decodeIfPresent(FooEnum.self, forKey: .enumvalue)
156+
let enumvalueDecoded = try values.decodeIfPresent(ExampleClientTypes.FooEnum.self, forKey: .enumvalue)
145157
if let enumvalue = enumvalueDecoded {
146158
self = .enumvalue(enumvalue)
147159
return
@@ -174,7 +186,7 @@ class UnionEncodeGeneratorTests {
174186
self = .mapvalue(mapvalue)
175187
return
176188
}
177-
let structurevalueDecoded = try values.decodeIfPresent(GreetingWithErrorsOutput.self, forKey: .structurevalue)
189+
let structurevalueDecoded = try values.decodeIfPresent(ExampleClientTypes.GreetingWithErrorsOutput.self, forKey: .structurevalue)
178190
if let structurevalue = structurevalueDecoded {
179191
self = .structurevalue(structurevalue)
180192
return
@@ -185,4 +197,48 @@ class UnionEncodeGeneratorTests {
185197
""".trimIndent()
186198
contents.shouldContainOnlyOnce(expectedContents)
187199
}
200+
201+
@Test
202+
fun `it generates codable conformance for a recursive union`() {
203+
val contents = getModelFileContents("example", "IndirectEnum+Codable.swift", newTestContext.manifest)
204+
contents.shouldSyntacticSanityCheck()
205+
val expectedContents =
206+
"""
207+
extension ExampleClientTypes.IndirectEnum: Swift.Codable {
208+
enum CodingKeys: Swift.String, Swift.CodingKey {
209+
case other
210+
case sdkUnknown
211+
case some
212+
}
213+
214+
public func encode(to encoder: Swift.Encoder) throws {
215+
var container = encoder.container(keyedBy: CodingKeys.self)
216+
switch self {
217+
case let .other(other):
218+
try container.encode(other, forKey: .other)
219+
case let .some(some):
220+
try container.encode(some, forKey: .some)
221+
case let .sdkUnknown(sdkUnknown):
222+
try container.encode(sdkUnknown, forKey: .sdkUnknown)
223+
}
224+
}
225+
226+
public init (from decoder: Swift.Decoder) throws {
227+
let values = try decoder.container(keyedBy: CodingKeys.self)
228+
let someDecoded = try values.decodeIfPresent(ExampleClientTypes.IndirectEnum.self, forKey: .some)
229+
if let some = someDecoded {
230+
self = .some(some)
231+
return
232+
}
233+
let otherDecoded = try values.decodeIfPresent(Swift.String.self, forKey: .other)
234+
if let other = otherDecoded {
235+
self = .other(other)
236+
return
237+
}
238+
self = .sdkUnknown("")
239+
}
240+
}
241+
""".trimIndent()
242+
contents.shouldContainOnlyOnce(expectedContents)
243+
}
188244
}

smithy-swift-codegen/src/test/kotlin/serde/xml/UnionEncodeXMLGenerationTests.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ class UnionEncodeXMLGenerationTests {
7474
let datavalueDecoded = try containerValues.decode(ClientRuntime.Data.self, forKey: .datavalue)
7575
self = .datavalue(datavalueDecoded)
7676
case .unionvalue:
77-
let unionvalueDecoded = try containerValues.decode(Box<RestXmlProtocolClientTypes.XmlUnionShape>.self, forKey: .unionvalue)
78-
self = .unionvalue(unionvalueDecoded.value)
77+
let unionvalueDecoded = try containerValues.decode(RestXmlProtocolClientTypes.XmlUnionShape.self, forKey: .unionvalue)
78+
self = .unionvalue(unionvalueDecoded)
7979
case .structvalue:
8080
let structvalueDecoded = try containerValues.decode(RestXmlProtocolClientTypes.XmlNestedUnionStruct.self, forKey: .structvalue)
8181
self = .structvalue(structvalueDecoded)

smithy-swift-codegen/src/test/resources/http-binding-protocol-generator-test.smithy

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ service Example {
1818
ListInput,
1919
MapInput,
2020
EnumInput,
21+
IndirectEnumOperation,
2122
TimestampInput,
2223
BlobInput,
2324
EmptyInputAndEmptyOutput,
@@ -1029,6 +1030,12 @@ operation JsonUnions {
10291030
output: UnionInputOutput,
10301031
}
10311032

1033+
@http(uri: "/IndirectEnumOperation", method: "POST")
1034+
operation IndirectEnumOperation {
1035+
input: IndirectEnumInputOutput
1036+
output: IndirectEnumInputOutput
1037+
}
1038+
10321039
@timestampFormat("http-date")
10331040
timestamp CommonTimestamp
10341041

@@ -1362,4 +1369,14 @@ structure IdempotencyTokenWithoutHttpPayloadTraitOnTokenInput {
13621369
@httpHeader("token")
13631370
@idempotencyToken
13641371
token: String,
1365-
}
1372+
}
1373+
1374+
union IndirectEnum {
1375+
some: IndirectEnum
1376+
other: String
1377+
}
1378+
1379+
structure IndirectEnumInputOutput {
1380+
value: IndirectEnum
1381+
}
1382+

0 commit comments

Comments
 (0)