@@ -7,7 +7,22 @@ package software.amazon.smithy.swift.codegen.codegencomponents
77
88import org.junit.jupiter.api.Assertions.assertEquals
99import org.junit.jupiter.api.Test
10+ import org.junit.jupiter.api.assertThrows
11+ import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
12+ import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
13+ import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait
14+ import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
15+ import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
16+ import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
17+ import software.amazon.smithy.model.Model
18+ import software.amazon.smithy.model.knowledge.ServiceIndex
19+ import software.amazon.smithy.model.shapes.ServiceShape
1020import software.amazon.smithy.model.shapes.ShapeId
21+ import software.amazon.smithy.model.shapes.StringShape
22+ import software.amazon.smithy.model.traits.ProtocolDefinitionTrait
23+ import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
24+ import software.amazon.smithy.swift.codegen.SwiftSettings
25+ import software.amazon.smithy.swift.codegen.UnresolvableProtocolException
1126import software.amazon.smithy.swift.codegen.asSmithy
1227import software.amazon.smithy.swift.codegen.defaultSettings
1328
@@ -25,4 +40,217 @@ class SwiftSettingsTest {
2540 assertEquals(" https://github.com/aws-amplify/amplify-codegen.git" , settings.gitRepo)
2641 assertEquals(false , settings.mergeModels)
2742 }
43+
44+ // Smithy Protocol Selection Tests
45+
46+ // Row 1: SDK supports all protocols
47+ private val allProtocolsSupported =
48+ setOf (
49+ Rpcv2CborTrait .ID ,
50+ AwsJson1_0Trait .ID ,
51+ AwsJson1_1Trait .ID ,
52+ RestJson1Trait .ID ,
53+ RestXmlTrait .ID ,
54+ AwsQueryTrait .ID ,
55+ Ec2QueryTrait .ID ,
56+ )
57+
58+ @Test
59+ fun `when SDK supports all protocols and service has rpcv2Cbor and awsJson1_0 then resolves rpcv2Cbor` () {
60+ val settings = createTestSettings()
61+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID , AwsJson1_0Trait .ID ))
62+ val serviceIndex = createServiceIndex(service)
63+
64+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
65+
66+ assertEquals(Rpcv2CborTrait .ID , resolvedProtocol)
67+ }
68+
69+ @Test
70+ fun `when SDK supports all protocols and service has only rpcv2Cbor then resolves rpcv2Cbor` () {
71+ val settings = createTestSettings()
72+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID ))
73+ val serviceIndex = createServiceIndex(service)
74+
75+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
76+
77+ assertEquals(Rpcv2CborTrait .ID , resolvedProtocol)
78+ }
79+
80+ @Test
81+ fun `when SDK supports all protocols and service has rpcv2Cbor awsJson1_0 and awsQuery then resolves rpcv2Cbor` () {
82+ val settings = createTestSettings()
83+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID , AwsJson1_0Trait .ID , AwsQueryTrait .ID ))
84+ val serviceIndex = createServiceIndex(service)
85+
86+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
87+
88+ assertEquals(Rpcv2CborTrait .ID , resolvedProtocol)
89+ }
90+
91+ @Test
92+ fun `when SDK supports all protocols and service has awsJson1_0 and awsQuery then resolves awsJson1_0` () {
93+ val settings = createTestSettings()
94+ val service = createServiceWithProtocols(setOf (AwsJson1_0Trait .ID , AwsQueryTrait .ID ))
95+ val serviceIndex = createServiceIndex(service)
96+
97+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
98+
99+ assertEquals(AwsJson1_0Trait .ID , resolvedProtocol)
100+ }
101+
102+ @Test
103+ fun `when SDK supports all protocols and service has only awsQuery then resolves awsQuery` () {
104+ val settings = createTestSettings()
105+ val service = createServiceWithProtocols(setOf (AwsQueryTrait .ID ))
106+ val serviceIndex = createServiceIndex(service)
107+
108+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, allProtocolsSupported)
109+
110+ assertEquals(AwsQueryTrait .ID , resolvedProtocol)
111+ }
112+
113+ // Row 2: SDK does not support rpcv2Cbor
114+ private val withoutRpcv2CborSupport =
115+ setOf (
116+ AwsJson1_0Trait .ID ,
117+ AwsJson1_1Trait .ID ,
118+ RestJson1Trait .ID ,
119+ RestXmlTrait .ID ,
120+ AwsQueryTrait .ID ,
121+ Ec2QueryTrait .ID ,
122+ )
123+
124+ @Test
125+ fun `when SDK does not support rpcv2Cbor and service has rpcv2Cbor and awsJson1_0 then resolves awsJson1_0` () {
126+ val settings = createTestSettings()
127+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID , AwsJson1_0Trait .ID ))
128+ val serviceIndex = createServiceIndex(service)
129+
130+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
131+
132+ assertEquals(AwsJson1_0Trait .ID , resolvedProtocol)
133+ }
134+
135+ @Test
136+ fun `when SDK does not support rpcv2Cbor and service has only rpcv2Cbor then throws exception` () {
137+ val settings = createTestSettings()
138+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID ))
139+ val serviceIndex = createServiceIndex(service)
140+
141+ assertThrows<UnresolvableProtocolException > {
142+ settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
143+ }
144+ }
145+
146+ @Test
147+ fun `when SDK does not support rpcv2Cbor and service has rpcv2Cbor awsJson1_0 and awsQuery then resolves awsJson1_0` () {
148+ val settings = createTestSettings()
149+ val service = createServiceWithProtocols(setOf (Rpcv2CborTrait .ID , AwsJson1_0Trait .ID , AwsQueryTrait .ID ))
150+ val serviceIndex = createServiceIndex(service)
151+
152+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
153+
154+ assertEquals(AwsJson1_0Trait .ID , resolvedProtocol)
155+ }
156+
157+ @Test
158+ fun `when SDK does not support rpcv2Cbor and service has awsJson1_0 and awsQuery then resolves awsJson1_0` () {
159+ val settings = createTestSettings()
160+ val service = createServiceWithProtocols(setOf (AwsJson1_0Trait .ID , AwsQueryTrait .ID ))
161+ val serviceIndex = createServiceIndex(service)
162+
163+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
164+
165+ assertEquals(AwsJson1_0Trait .ID , resolvedProtocol)
166+ }
167+
168+ @Test
169+ fun `when SDK does not support rpcv2Cbor and service has only awsQuery then resolves awsQuery` () {
170+ val settings = createTestSettings()
171+ val service = createServiceWithProtocols(setOf (AwsQueryTrait .ID ))
172+ val serviceIndex = createServiceIndex(service)
173+
174+ val resolvedProtocol = settings.resolveServiceProtocol(serviceIndex, service, withoutRpcv2CborSupport)
175+
176+ assertEquals(AwsQueryTrait .ID , resolvedProtocol)
177+ }
178+
179+ // Helper functions
180+
181+ private fun createTestSettings (): SwiftSettings =
182+ SwiftSettings (
183+ service = ShapeId .from(" test#TestService" ),
184+ moduleName = " TestModule" ,
185+ moduleVersion = " 1.0.0" ,
186+ moduleDescription = " Test module" ,
187+ author = " Test Author" ,
188+ homepage = " https://test.com" ,
189+ sdkId = " Test" ,
190+ gitRepo = " https://github.com/test/test.git" ,
191+ swiftVersion = " 5.7" ,
192+ mergeModels = false ,
193+ copyrightNotice = " // Test copyright" ,
194+ )
195+
196+ private fun createServiceWithProtocols (protocols : Set <ShapeId >): ServiceShape {
197+ var builder =
198+ ServiceShape
199+ .builder()
200+ .id(" test#TestService" )
201+ .version(" 1.0" )
202+
203+ // Apply the actual protocol traits to the service
204+ for (protocolId in protocols) {
205+ when (protocolId) {
206+ Rpcv2CborTrait .ID -> builder = builder.addTrait(Rpcv2CborTrait .builder().build())
207+ AwsJson1_0Trait .ID -> builder = builder.addTrait(AwsJson1_0Trait .builder().build())
208+ AwsJson1_1Trait .ID -> builder = builder.addTrait(AwsJson1_1Trait .builder().build())
209+ RestJson1Trait .ID -> builder = builder.addTrait(RestJson1Trait .builder().build())
210+ RestXmlTrait .ID -> builder = builder.addTrait(RestXmlTrait .builder().build())
211+ AwsQueryTrait .ID -> builder = builder.addTrait(AwsQueryTrait ())
212+ Ec2QueryTrait .ID -> builder = builder.addTrait(Ec2QueryTrait ())
213+ }
214+ }
215+
216+ return builder.build()
217+ }
218+
219+ private fun createServiceIndex (service : ServiceShape ): ServiceIndex {
220+ val modelBuilder = Model .builder()
221+
222+ // Add the service shape
223+ modelBuilder.addShape(service)
224+
225+ // Add protocol definition shapes to the model
226+ // These are needed for ServiceIndex to recognize the protocols
227+ val protocolShapes =
228+ listOf (
229+ Rpcv2CborTrait .ID ,
230+ AwsJson1_0Trait .ID ,
231+ AwsJson1_1Trait .ID ,
232+ RestJson1Trait .ID ,
233+ RestXmlTrait .ID ,
234+ AwsQueryTrait .ID ,
235+ Ec2QueryTrait .ID ,
236+ )
237+
238+ for (protocolId in protocolShapes) {
239+ // Create a shape that represents the protocol definition
240+ // and add the ProtocolDefinitionTrait to it
241+ val protocolShape =
242+ StringShape
243+ .builder()
244+ .id(protocolId)
245+ .addTrait(
246+ ProtocolDefinitionTrait
247+ .builder()
248+ .build(),
249+ ).build()
250+ modelBuilder.addShape(protocolShape)
251+ }
252+
253+ val model = modelBuilder.build()
254+ return ServiceIndex .of(model)
255+ }
28256}
0 commit comments