Skip to content

Commit 3c0da29

Browse files
authored
chore: add unit test for protocol selection behavior (#1290)
1 parent bcad3b4 commit 3c0da29

File tree

1 file changed

+124
-6
lines changed

1 file changed

+124
-6
lines changed

codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettingsTest.kt

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,22 @@
55

66
package software.amazon.smithy.kotlin.codegen
77

8+
import org.junit.jupiter.api.extension.ExtensionContext
89
import org.junit.jupiter.params.ParameterizedTest
10+
import org.junit.jupiter.params.provider.Arguments
11+
import org.junit.jupiter.params.provider.ArgumentsProvider
12+
import org.junit.jupiter.params.provider.ArgumentsSource
913
import org.junit.jupiter.params.provider.CsvSource
1014
import software.amazon.smithy.codegen.core.CodegenException
1115
import software.amazon.smithy.kotlin.codegen.test.TestModelDefault
1216
import software.amazon.smithy.kotlin.codegen.test.toSmithyModel
17+
import software.amazon.smithy.kotlin.codegen.utils.dq
1318
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
19+
import software.amazon.smithy.model.knowledge.ServiceIndex
1420
import software.amazon.smithy.model.node.Node
1521
import software.amazon.smithy.model.shapes.ShapeId
16-
import java.lang.IllegalArgumentException
17-
import kotlin.test.Test
18-
import kotlin.test.assertEquals
19-
import kotlin.test.assertFailsWith
20-
import kotlin.test.assertFalse
21-
import kotlin.test.assertTrue
22+
import java.util.stream.Stream
23+
import kotlin.test.*
2224

2325
class KotlinSettingsTest {
2426
@Test
@@ -330,4 +332,120 @@ class KotlinSettingsTest {
330332

331333
assertEquals(expected, apiSettings.defaultValueSerializationMode)
332334
}
335+
336+
@ParameterizedTest
337+
@ArgumentsSource(TestProtocolSelectionArgumentProvider::class)
338+
fun testProtocolSelection(
339+
protocolPriorityCsv: String,
340+
serviceProtocolsCsv: String,
341+
expectedProtocolName: String?,
342+
) {
343+
val serviceProtocols = serviceProtocolsCsv.csvToProtocolList()
344+
val serviceProtocolImports = serviceProtocols.joinToString("\n") { "use $it" }
345+
val serviceProtocolTraits = serviceProtocols.joinToString("\n") { "@${it.name}" }
346+
val supportedProtocols = protocolPriorityCsv.csvToProtocolList().toSet()
347+
val protocolPriorityList = supportedProtocols.joinToString(", ") { it.toString().dq() }
348+
349+
val model = """
350+
|namespace com.test
351+
|
352+
|$serviceProtocolImports
353+
|
354+
|$serviceProtocolTraits
355+
|@xmlNamespace(uri: "http://test.com") // required for @awsQuery
356+
|service Test {
357+
| version: "1.0.0"
358+
|}
359+
""".trimMargin().toSmithyModel()
360+
val service = model.serviceShapes.single()
361+
val serviceIndex = ServiceIndex.of(model)
362+
363+
val contents = """
364+
{
365+
"package": {
366+
"name": "name",
367+
"version": "1.0.0"
368+
},
369+
"api": {
370+
"protocolResolutionPriority": [ $protocolPriorityList ]
371+
}
372+
}
373+
""".trimIndent()
374+
val settings = KotlinSettings.from(model, Node.parse(contents).expectObjectNode())
375+
376+
val expectedProtocol = expectedProtocolName?.nameToProtocol()
377+
val actualProtocol = runCatching {
378+
settings.resolveServiceProtocol(serviceIndex, service, supportedProtocols)
379+
}.getOrElse { null }
380+
381+
assertEquals(expectedProtocol, actualProtocol)
382+
}
383+
}
384+
385+
/**
386+
* A junit [ArgumentsProvider] which supplies protocol selection parameterized test values sourced from the Smithy RPCv2
387+
* CBOR Support SEP § Smithy protocol selection tests.
388+
*/
389+
class TestProtocolSelectionArgumentProvider : ArgumentsProvider {
390+
companion object {
391+
private const val ALL_PROTOCOLS = "rpcv2Cbor, awsJson1_0, awsJson1_1, restJson1, restXml, awsQuery, ec2Query"
392+
private const val NO_CBOR = "awsJson1_0, awsJson1_1, restJson1, restXml, awsQuery, ec2Query"
393+
}
394+
395+
override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> = Stream.of(
396+
Arguments.of(
397+
ALL_PROTOCOLS,
398+
"rpcv2Cbor, awsJson1_0",
399+
"rpcv2Cbor",
400+
),
401+
Arguments.of(
402+
ALL_PROTOCOLS,
403+
"rpcv2Cbor",
404+
"rpcv2Cbor",
405+
),
406+
Arguments.of(
407+
ALL_PROTOCOLS,
408+
"rpcv2Cbor, awsJson1_0, awsQuery",
409+
"rpcv2Cbor",
410+
),
411+
Arguments.of(
412+
ALL_PROTOCOLS,
413+
"awsJson1_0, awsQuery",
414+
"awsJson1_0",
415+
),
416+
Arguments.of(
417+
ALL_PROTOCOLS,
418+
"awsQuery",
419+
"awsQuery",
420+
),
421+
Arguments.of(
422+
NO_CBOR,
423+
"rpcv2Cbor, awsJson1_0",
424+
"awsJson1_0",
425+
),
426+
Arguments.of(
427+
NO_CBOR,
428+
"rpcv2Cbor",
429+
null,
430+
),
431+
Arguments.of(
432+
NO_CBOR,
433+
"rpcv2Cbor, awsJson1_0, awsQuery",
434+
"awsJson1_0",
435+
),
436+
Arguments.of(
437+
NO_CBOR,
438+
"awsJson1_0, awsQuery",
439+
"awsJson1_0",
440+
),
441+
Arguments.of(
442+
NO_CBOR,
443+
"awsQuery",
444+
"awsQuery",
445+
),
446+
)
333447
}
448+
449+
private val allProtocols = ApiSettings().protocolResolutionPriority
450+
private fun String.nameToProtocol() = allProtocols.single { protocol -> protocol.name == this }
451+
private fun String.csvToProtocolList() = split(",").map(String::trim).map(String::nameToProtocol)

0 commit comments

Comments
 (0)