|
5 | 5 |
|
6 | 6 | package software.amazon.smithy.kotlin.codegen |
7 | 7 |
|
| 8 | +import org.junit.jupiter.api.extension.ExtensionContext |
8 | 9 | import org.junit.jupiter.params.ParameterizedTest |
| 10 | +import org.junit.jupiter.params.provider.Arguments |
| 11 | +import org.junit.jupiter.params.provider.ArgumentsProvider |
| 12 | +import org.junit.jupiter.params.provider.ArgumentsSource |
9 | 13 | import org.junit.jupiter.params.provider.CsvSource |
10 | 14 | import software.amazon.smithy.codegen.core.CodegenException |
11 | 15 | import software.amazon.smithy.kotlin.codegen.test.TestModelDefault |
12 | 16 | import software.amazon.smithy.kotlin.codegen.test.toSmithyModel |
| 17 | +import software.amazon.smithy.kotlin.codegen.utils.dq |
13 | 18 | import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode |
| 19 | +import software.amazon.smithy.model.knowledge.ServiceIndex |
14 | 20 | import software.amazon.smithy.model.node.Node |
15 | 21 | import software.amazon.smithy.model.shapes.ShapeId |
16 | | -import java.lang.IllegalArgumentException |
17 | | -import kotlin.test.Test |
18 | | -import kotlin.test.assertEquals |
19 | | -import kotlin.test.assertFailsWith |
20 | | -import kotlin.test.assertFalse |
21 | | -import kotlin.test.assertTrue |
| 22 | +import java.util.stream.Stream |
| 23 | +import kotlin.test.* |
22 | 24 |
|
23 | 25 | class KotlinSettingsTest { |
24 | 26 | @Test |
@@ -330,4 +332,120 @@ class KotlinSettingsTest { |
330 | 332 |
|
331 | 333 | assertEquals(expected, apiSettings.defaultValueSerializationMode) |
332 | 334 | } |
| 335 | + |
| 336 | + @ParameterizedTest |
| 337 | + @ArgumentsSource(TestProtocolSelectionArgumentProvider::class) |
| 338 | + fun testProtocolSelection( |
| 339 | + protocolPriorityCsv: String, |
| 340 | + serviceProtocolsCsv: String, |
| 341 | + expectedProtocolName: String?, |
| 342 | + ) { |
| 343 | + val serviceProtocols = serviceProtocolsCsv.csvToProtocolList() |
| 344 | + val serviceProtocolImports = serviceProtocols.joinToString("\n") { "use $it" } |
| 345 | + val serviceProtocolTraits = serviceProtocols.joinToString("\n") { "@${it.name}" } |
| 346 | + val supportedProtocols = protocolPriorityCsv.csvToProtocolList().toSet() |
| 347 | + val protocolPriorityList = supportedProtocols.joinToString(", ") { it.toString().dq() } |
| 348 | + |
| 349 | + val model = """ |
| 350 | + |namespace com.test |
| 351 | + | |
| 352 | + |$serviceProtocolImports |
| 353 | + | |
| 354 | + |$serviceProtocolTraits |
| 355 | + |@xmlNamespace(uri: "http://test.com") // required for @awsQuery |
| 356 | + |service Test { |
| 357 | + | version: "1.0.0" |
| 358 | + |} |
| 359 | + """.trimMargin().toSmithyModel() |
| 360 | + val service = model.serviceShapes.single() |
| 361 | + val serviceIndex = ServiceIndex.of(model) |
| 362 | + |
| 363 | + val contents = """ |
| 364 | + { |
| 365 | + "package": { |
| 366 | + "name": "name", |
| 367 | + "version": "1.0.0" |
| 368 | + }, |
| 369 | + "api": { |
| 370 | + "protocolResolutionPriority": [ $protocolPriorityList ] |
| 371 | + } |
| 372 | + } |
| 373 | + """.trimIndent() |
| 374 | + val settings = KotlinSettings.from(model, Node.parse(contents).expectObjectNode()) |
| 375 | + |
| 376 | + val expectedProtocol = expectedProtocolName?.nameToProtocol() |
| 377 | + val actualProtocol = runCatching { |
| 378 | + settings.resolveServiceProtocol(serviceIndex, service, supportedProtocols) |
| 379 | + }.getOrElse { null } |
| 380 | + |
| 381 | + assertEquals(expectedProtocol, actualProtocol) |
| 382 | + } |
| 383 | +} |
| 384 | + |
| 385 | +/** |
| 386 | + * A junit [ArgumentsProvider] which supplies protocol selection parameterized test values sourced from the Smithy RPCv2 |
| 387 | + * CBOR Support SEP § Smithy protocol selection tests. |
| 388 | + */ |
| 389 | +class TestProtocolSelectionArgumentProvider : ArgumentsProvider { |
| 390 | + companion object { |
| 391 | + private const val ALL_PROTOCOLS = "rpcv2Cbor, awsJson1_0, awsJson1_1, restJson1, restXml, awsQuery, ec2Query" |
| 392 | + private const val NO_CBOR = "awsJson1_0, awsJson1_1, restJson1, restXml, awsQuery, ec2Query" |
| 393 | + } |
| 394 | + |
| 395 | + override fun provideArguments(context: ExtensionContext?): Stream<out Arguments> = Stream.of( |
| 396 | + Arguments.of( |
| 397 | + ALL_PROTOCOLS, |
| 398 | + "rpcv2Cbor, awsJson1_0", |
| 399 | + "rpcv2Cbor", |
| 400 | + ), |
| 401 | + Arguments.of( |
| 402 | + ALL_PROTOCOLS, |
| 403 | + "rpcv2Cbor", |
| 404 | + "rpcv2Cbor", |
| 405 | + ), |
| 406 | + Arguments.of( |
| 407 | + ALL_PROTOCOLS, |
| 408 | + "rpcv2Cbor, awsJson1_0, awsQuery", |
| 409 | + "rpcv2Cbor", |
| 410 | + ), |
| 411 | + Arguments.of( |
| 412 | + ALL_PROTOCOLS, |
| 413 | + "awsJson1_0, awsQuery", |
| 414 | + "awsJson1_0", |
| 415 | + ), |
| 416 | + Arguments.of( |
| 417 | + ALL_PROTOCOLS, |
| 418 | + "awsQuery", |
| 419 | + "awsQuery", |
| 420 | + ), |
| 421 | + Arguments.of( |
| 422 | + NO_CBOR, |
| 423 | + "rpcv2Cbor, awsJson1_0", |
| 424 | + "awsJson1_0", |
| 425 | + ), |
| 426 | + Arguments.of( |
| 427 | + NO_CBOR, |
| 428 | + "rpcv2Cbor", |
| 429 | + null, |
| 430 | + ), |
| 431 | + Arguments.of( |
| 432 | + NO_CBOR, |
| 433 | + "rpcv2Cbor, awsJson1_0, awsQuery", |
| 434 | + "awsJson1_0", |
| 435 | + ), |
| 436 | + Arguments.of( |
| 437 | + NO_CBOR, |
| 438 | + "awsJson1_0, awsQuery", |
| 439 | + "awsJson1_0", |
| 440 | + ), |
| 441 | + Arguments.of( |
| 442 | + NO_CBOR, |
| 443 | + "awsQuery", |
| 444 | + "awsQuery", |
| 445 | + ), |
| 446 | + ) |
333 | 447 | } |
| 448 | + |
| 449 | +private val allProtocols = ApiSettings().protocolResolutionPriority |
| 450 | +private fun String.nameToProtocol() = allProtocols.single { protocol -> protocol.name == this } |
| 451 | +private fun String.csvToProtocolList() = split(",").map(String::trim).map(String::nameToProtocol) |
0 commit comments