66import io.kotest.matchers.string.shouldContainOnlyOnce
77import org.junit.jupiter.api.Assertions
88import org.junit.jupiter.api.Test
9+ import software.amazon.smithy.model.Model
910import software.amazon.smithy.swift.codegen.model.AddOperationShapes
11+ import software.amazon.smithy.swift.codegen.model.HashableShapeTransformer
12+ import software.amazon.smithy.swift.codegen.model.NestedShapeTransformer
13+ import software.amazon.smithy.swift.codegen.model.RecursiveShapeBoxer
1014
1115class UnionEncodeGeneratorTests {
1216 var model = javaClass.getResource(" http-binding-protocol-generator-test.smithy" ).asSmithy()
1317 private fun newTestContext (): TestContext {
14- val settings = model.defaultSettings()
15- model = AddOperationShapes .execute(model, settings.getService(model), settings.moduleName)
18+ model = preprocessModel(model)
1619 return model.newTestContext()
1720 }
21+ private fun preprocessModel (model : Model ): Model {
22+ val settings = model.defaultSettings()
23+ var resolvedModel = model
24+ resolvedModel = AddOperationShapes .execute(resolvedModel, settings.getService(resolvedModel), settings.moduleName)
25+ resolvedModel = RecursiveShapeBoxer .transform(resolvedModel)
26+ resolvedModel = HashableShapeTransformer .transform(resolvedModel)
27+ resolvedModel = NestedShapeTransformer .transform(resolvedModel, settings.getService(resolvedModel))
28+ return resolvedModel
29+ }
1830 val newTestContext = newTestContext()
1931 init {
2032 newTestContext.generator.generateSerializers(newTestContext.generationCtx)
@@ -60,7 +72,7 @@ class UnionEncodeGeneratorTests {
6072 contents.shouldSyntacticSanityCheck()
6173 val expectedContents =
6274 """
63- extension MyUnion: Swift.Codable {
75+ extension ExampleClientTypes. MyUnion: Swift.Codable {
6476 enum CodingKeys: Swift.String, Swift.CodingKey {
6577 case blobvalue = "blobValue"
6678 case booleanvalue = "booleanValue"
@@ -141,7 +153,7 @@ class UnionEncodeGeneratorTests {
141153 self = .inheritedtimestamp(inheritedtimestamp)
142154 return
143155 }
144- let enumvalueDecoded = try values.decodeIfPresent(FooEnum.self, forKey: .enumvalue)
156+ let enumvalueDecoded = try values.decodeIfPresent(ExampleClientTypes. FooEnum.self, forKey: .enumvalue)
145157 if let enumvalue = enumvalueDecoded {
146158 self = .enumvalue(enumvalue)
147159 return
@@ -174,7 +186,7 @@ class UnionEncodeGeneratorTests {
174186 self = .mapvalue(mapvalue)
175187 return
176188 }
177- let structurevalueDecoded = try values.decodeIfPresent(GreetingWithErrorsOutput.self, forKey: .structurevalue)
189+ let structurevalueDecoded = try values.decodeIfPresent(ExampleClientTypes. GreetingWithErrorsOutput.self, forKey: .structurevalue)
178190 if let structurevalue = structurevalueDecoded {
179191 self = .structurevalue(structurevalue)
180192 return
@@ -185,4 +197,48 @@ class UnionEncodeGeneratorTests {
185197 """ .trimIndent()
186198 contents.shouldContainOnlyOnce(expectedContents)
187199 }
200+
201+ @Test
202+ fun `it generates codable conformance for a recursive union` () {
203+ val contents = getModelFileContents(" example" , " IndirectEnum+Codable.swift" , newTestContext.manifest)
204+ contents.shouldSyntacticSanityCheck()
205+ val expectedContents =
206+ """
207+ extension ExampleClientTypes.IndirectEnum: Swift.Codable {
208+ enum CodingKeys: Swift.String, Swift.CodingKey {
209+ case other
210+ case sdkUnknown
211+ case some
212+ }
213+
214+ public func encode(to encoder: Swift.Encoder) throws {
215+ var container = encoder.container(keyedBy: CodingKeys.self)
216+ switch self {
217+ case let .other(other):
218+ try container.encode(other, forKey: .other)
219+ case let .some(some):
220+ try container.encode(some, forKey: .some)
221+ case let .sdkUnknown(sdkUnknown):
222+ try container.encode(sdkUnknown, forKey: .sdkUnknown)
223+ }
224+ }
225+
226+ public init (from decoder: Swift.Decoder) throws {
227+ let values = try decoder.container(keyedBy: CodingKeys.self)
228+ let someDecoded = try values.decodeIfPresent(ExampleClientTypes.IndirectEnum.self, forKey: .some)
229+ if let some = someDecoded {
230+ self = .some(some)
231+ return
232+ }
233+ let otherDecoded = try values.decodeIfPresent(Swift.String.self, forKey: .other)
234+ if let other = otherDecoded {
235+ self = .other(other)
236+ return
237+ }
238+ self = .sdkUnknown("")
239+ }
240+ }
241+ """ .trimIndent()
242+ contents.shouldContainOnlyOnce(expectedContents)
243+ }
188244}
0 commit comments