@@ -3,25 +3,55 @@ package software.amazon.smithy.kotlin.codegen.rendering.smoketests
33import software.amazon.smithy.codegen.core.Symbol
44import software.amazon.smithy.kotlin.codegen.core.*
55import software.amazon.smithy.kotlin.codegen.integration.SectionId
6+ import software.amazon.smithy.kotlin.codegen.integration.SectionKey
67import software.amazon.smithy.kotlin.codegen.model.getTrait
78import software.amazon.smithy.kotlin.codegen.model.hasTrait
9+ import software.amazon.smithy.kotlin.codegen.model.isStringEnumShape
10+ import software.amazon.smithy.kotlin.codegen.rendering.endpoints.EndpointParametersGenerator
11+ import software.amazon.smithy.kotlin.codegen.rendering.endpoints.EndpointProviderGenerator
12+ import software.amazon.smithy.kotlin.codegen.rendering.protocol.stringToNumber
13+ import software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestSectionIds.ClientConfig.EndpointParams
14+ import software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestSectionIds.ClientConfig.EndpointProvider
15+ import software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestSectionIds.ClientConfig.Name
16+ import software.amazon.smithy.kotlin.codegen.rendering.smoketests.SmokeTestSectionIds.ClientConfig.Value
817import software.amazon.smithy.kotlin.codegen.rendering.util.format
918import software.amazon.smithy.kotlin.codegen.utils.dq
1019import software.amazon.smithy.kotlin.codegen.utils.toCamelCase
20+ import software.amazon.smithy.kotlin.codegen.utils.toPascalCase
1121import software.amazon.smithy.kotlin.codegen.utils.topDownOperations
12- import software.amazon.smithy.model.shapes.OperationShape
22+ import software.amazon.smithy.model.node.*
23+ import software.amazon.smithy.model.shapes.*
1324import software.amazon.smithy.smoketests.traits.SmokeTestCase
1425import software.amazon.smithy.smoketests.traits.SmokeTestsTrait
1526import kotlin.jvm.optionals.getOrNull
1627
17- object SmokeTestsRunner : SectionId
18- object SmokeTestAdditionalEnvVars : SectionId
19- object SmokeTestDefaultConfig : SectionId
20- object SmokeTestRegionDefault : SectionId
21- object SmokeTestHttpEngineOverride : SectionId
28+ // Section IDs
29+ object SmokeTestSectionIds {
30+ object AdditionalEnvironmentVariables : SectionId
31+ object DefaultClientConfig : SectionId
32+ object HttpEngineOverride : SectionId
33+ object ServiceFilter : SectionId
34+ object SkipTags : SectionId
35+ object ClientConfig : SectionId {
36+ val Name : SectionKey <String > = SectionKey (" aws.smithy.kotlin#SmokeTestClientConfigName" )
37+ val Value : SectionKey <String > = SectionKey (" aws.smithy.kotlin#SmokeTestClientConfigValue" )
38+ val EndpointProvider : SectionKey <Symbol > = SectionKey (" aws.smithy.kotlin#SmokeTestEndpointProvider" )
39+ val EndpointParams : SectionKey <Symbol > = SectionKey (" aws.smithy.kotlin#SmokeTestClientEndpointParams" )
40+ }
41+ }
2242
23- const val SKIP_TAGS = " AWS_SMOKE_TEST_SKIP_TAGS"
24- const val SERVICE_FILTER = " AWS_SMOKE_TEST_SERVICE_IDS"
43+ /* *
44+ * Env var for smoke test runners.
45+ * Should be a comma-delimited list of strings that correspond to tags on the test cases.
46+ * If a test case is tagged with one of the tags indicated by SMOKE_TEST_SKIP_TAGS, it MUST be skipped by the smoke test runner.
47+ */
48+ const val SKIP_TAGS = " SMOKE_TEST_SKIP_TAGS"
49+
50+ /* *
51+ * Env var for smoke test runners.
52+ * Should be a comma-separated list of service identifiers to test.
53+ */
54+ const val SERVICE_FILTER = " SMOKE_TEST_SERVICE_IDS"
2555
2656/* *
2757 * Renders smoke tests runner for a service
@@ -30,36 +60,45 @@ class SmokeTestsRunnerGenerator(
3060 private val writer : KotlinWriter ,
3161 ctx : CodegenContext ,
3262) {
33- private val model = ctx.model
34- private val sdkId = ctx.settings.sdkId
35- private val symbolProvider = ctx.symbolProvider
36- private val service = symbolProvider.toSymbol(model.expectShape(ctx.settings.service))
37- private val operations = ctx.model.topDownOperations(ctx.settings.service).filter { it.hasTrait<SmokeTestsTrait >() }
38-
3963 internal fun render () {
40- writer.declareSection(SmokeTestsRunner ) {
41- write(" private var exitCode = 0" )
42- write(
43- " private val skipTags = #T.System.getenv(#S)?.let { it.split(#S).map { it.trim() }.toSet() } ?: emptySet()" ,
44- RuntimeTypes .Core .Utils .PlatformProvider ,
45- SKIP_TAGS ,
46- " ," ,
47- )
48- write(
49- " private val serviceFilter = #T.System.getenv(#S)?.let { it.split(#S).map { it.trim() }.toSet() } ?: emptySet()" ,
50- RuntimeTypes .Core .Utils .PlatformProvider ,
51- SERVICE_FILTER ,
52- " ," ,
53- )
54- declareSection(SmokeTestAdditionalEnvVars )
55- write(" " )
56- withBlock(" public suspend fun main() {" , " }" ) {
57- renderFunctionCalls()
58- write(" #T(exitCode)" , RuntimeTypes .Core .SmokeTests .exitProcess)
59- }
60- write(" " )
61- renderFunctions()
64+ writer.write(" private var exitCode = 0" )
65+ renderEnvironmentVariables()
66+ writer.declareSection(SmokeTestSectionIds .AdditionalEnvironmentVariables )
67+ writer.write(" " )
68+ writer.withBlock(" public suspend fun main() {" , " }" ) {
69+ renderFunctionCalls()
70+ write(" #T(exitCode)" , RuntimeTypes .Core .SmokeTests .exitProcess)
71+ }
72+ writer.write(" " )
73+ renderFunctions()
74+ }
75+
76+ private fun renderEnvironmentVariables () {
77+ // Skip tags
78+ writer.writeInline(
79+ " private val skipTags = #T.System.getenv(" ,
80+ RuntimeTypes .Core .Utils .PlatformProvider ,
81+ )
82+ writer.declareSection(SmokeTestSectionIds .SkipTags ) {
83+ writer.writeInline(" #S" , SKIP_TAGS )
84+ }
85+ writer.write(
86+ " )?.let { it.split(#S).map { it.trim() }.toSet() } ?: emptySet()" ,
87+ " ," ,
88+ )
89+
90+ // Service filter
91+ writer.writeInline(
92+ " private val serviceFilter = #T.System.getenv(" ,
93+ RuntimeTypes .Core .Utils .PlatformProvider ,
94+ )
95+ writer.declareSection(SmokeTestSectionIds .ServiceFilter ) {
96+ writer.writeInline(" #S" , SERVICE_FILTER )
6297 }
98+ writer.write(
99+ " )?.let { it.split(#S).map { it.trim() }.toSet() } ?: emptySet()" ,
100+ " ," ,
101+ )
63102 }
64103
65104 private fun renderFunctionCalls () {
@@ -98,32 +137,45 @@ class SmokeTestsRunnerGenerator(
98137 renderClient(testCase)
99138 renderOperation(operation, testCase)
100139 }
101- withBlock(" catch (e : Exception) {" , " }" ) {
140+ withBlock(" catch (exception : Exception) {" , " }" ) {
102141 renderCatchBlock(testCase)
103142 }
104143 }
105144 }
106145
107146 private fun renderClient (testCase : SmokeTestCase ) {
108147 writer.withInlineBlock(" #L {" , " }" , service) {
109- if (testCase.vendorParams.isPresent) {
110- testCase.vendorParams.get().members.forEach { vendorParam ->
111- if (vendorParam.key.value == " region" ) {
112- writeInline(" #L = " , vendorParam.key.value.toCamelCase())
113- declareSection(SmokeTestRegionDefault )
114- write(" #L" , vendorParam.value.format())
115- } else {
116- write(" #L = #L" , vendorParam.key.value.toCamelCase(), vendorParam.value.format())
117- }
118- }
119- } else {
120- declareSection(SmokeTestDefaultConfig )
121- }
122- val expectingSpecificError = testCase.expectation.failure.getOrNull()?.errorId?.getOrNull() != null
123- if (! expectingSpecificError) {
124- write(" interceptors.add(#T())" , RuntimeTypes .HttpClient .Interceptors .SmokeTestsInterceptor )
148+ renderClientConfig(testCase)
149+ }
150+ }
151+
152+ private fun renderClientConfig (testCase : SmokeTestCase ) {
153+ if (! testCase.expectingSpecificError) {
154+ writer.write(" interceptors.add(#T())" , RuntimeTypes .HttpClient .Interceptors .SmokeTestsInterceptor )
155+ }
156+
157+ writer.declareSection(SmokeTestSectionIds .HttpEngineOverride )
158+
159+ if (! testCase.hasClientConfig) {
160+ writer.declareSection(SmokeTestSectionIds .DefaultClientConfig )
161+ return
162+ }
163+
164+ testCase.clientConfig!! .forEach { config ->
165+ val name = config.key.value.toCamelCase()
166+ val value = config.value.format()
167+
168+ writer.declareSection(
169+ SmokeTestSectionIds .ClientConfig ,
170+ mapOf (
171+ Name to name,
172+ Value to value,
173+ EndpointProvider to EndpointProviderGenerator .getSymbol(settings),
174+ EndpointParams to EndpointParametersGenerator .getSymbol(settings),
175+ ),
176+ ) {
177+ writer.writeInline(" #L = #L" , name, value)
125178 }
126- declareSection(SmokeTestHttpEngineOverride )
127179 }
128180 }
129181
@@ -133,30 +185,97 @@ class SmokeTestsRunnerGenerator(
133185 writer.withBlock(" .#T { client ->" , " }" , RuntimeTypes .Core .IO .use) {
134186 withBlock(" client.#L(" , " )" , operation.defaultName()) {
135187 withBlock(" #L {" , " }" , operationSymbol) {
136- testCase.params.get().members.forEach { member ->
137- write(" #L = #L" , member.key.value.toCamelCase(), member.value.format())
138- }
188+ renderOperationParameters(operation, testCase)
139189 }
140190 }
141191 }
142192 }
143193
194+ private fun renderOperationParameters (operation : OperationShape , testCase : SmokeTestCase ) {
195+ if (! testCase.hasOperationParameters) return
196+
197+ val paramsToShapes = mapOperationParametersToModeledShapes(operation)
198+
199+ testCase.operationParameters.forEach { param ->
200+ val paramName = param.key.value.toCamelCase()
201+ writer.writeInline(" #L = " , paramName)
202+ val paramShape = paramsToShapes[paramName] ? : throw IllegalArgumentException (" Unable to find shape for operation parameter '$paramName ' in smoke test '${testCase.functionName} '." )
203+ renderOperationParameter(paramName, param.value, paramShape, testCase)
204+ }
205+ }
206+
144207 private fun renderCatchBlock (testCase : SmokeTestCase ) {
145- val expected = if (testCase.expectation.isFailure) {
208+ val expectedException = if (testCase.expectation.isFailure) {
146209 getFailureCriterion(testCase)
147210 } else {
148211 RuntimeTypes .HttpClient .Interceptors .SmokeTestsSuccessException
149212 }
150213
151- writer.write(" val success = e is #T" , expected)
152- writer.write(" val status = if (success) #S else #S" , " ok" , " not ok" )
214+ writer.write(" val success: Boolean = exception is #T" , expectedException)
215+ writer.write(" val status: String = if (success) #S else #S" , " ok" , " not ok" )
216+
153217 printTestResult(
154218 sdkId.filter { ! it.isWhitespace() },
155219 testCase.id,
156220 testCase.expectation.isFailure,
157221 writer,
158222 )
159- writer.write(" if (!success) exitCode = 1" )
223+
224+ writer.withBlock(" if (!success) {" , " }" ) {
225+ write(" #T(exception)" , RuntimeTypes .Core .SmokeTests .printExceptionStackTrace)
226+ write(" exitCode = 1" )
227+ }
228+ }
229+
230+ // Helpers
231+ /* *
232+ * Renders a [SmokeTestCase] operation parameter
233+ */
234+ private fun renderOperationParameter (
235+ paramName : String ,
236+ node : Node ,
237+ shape : Shape ,
238+ testCase : SmokeTestCase ,
239+ ) {
240+ when {
241+ // String enum
242+ node is StringNode && shape.isStringEnumShape -> {
243+ val enumSymbol = symbolProvider.toSymbol(shape)
244+ val enumValue = node.value.toPascalCase()
245+ writer.write(" #T.#L" , enumSymbol, enumValue)
246+ }
247+ // Int enum
248+ node is NumberNode && shape is IntEnumShape -> {
249+ val enumSymbol = symbolProvider.toSymbol(shape)
250+ val enumValue = node.format()
251+ writer.write(" #T.fromValue(#L.toInt())" , enumSymbol, enumValue)
252+ }
253+ // Number
254+ node is NumberNode && shape is NumberShape -> writer.write(" #L.#L" , node.format(), stringToNumber(shape))
255+ // Object
256+ node is ObjectNode -> {
257+ val shapeSymbol = symbolProvider.toSymbol(shape)
258+ writer.withBlock(" #T {" , " }" , shapeSymbol) {
259+ node.members.forEach { member ->
260+ val memberName = member.key.value.toCamelCase()
261+ val memberShape = shape.allMembers[member.key.value] ? : throw IllegalArgumentException (" Unable to find shape for operation parameter '$paramName ' in smoke test '${testCase.functionName} '." )
262+ writer.writeInline(" #L = " , memberName)
263+ renderOperationParameter(memberName, member.value, memberShape, testCase)
264+ }
265+ }
266+ }
267+ // List
268+ node is ArrayNode && shape is CollectionShape -> {
269+ writer.withBlock(" listOf(" , " )" ) {
270+ node.elements.forEach { element ->
271+ renderOperationParameter(paramName, element, model.expectShape(shape.member.target), testCase)
272+ writer.write(" ," )
273+ }
274+ }
275+ }
276+ // Everything else
277+ else -> writer.write(" #L" , node.format())
278+ }
160279 }
161280
162281 /* *
@@ -184,10 +303,56 @@ class SmokeTestsRunnerGenerator(
184303 val testResult = " $status $service $testCase - $expectation $directive "
185304 writer.write(" println(#S)" , testResult)
186305 }
187- }
188306
189- /* *
190- * Derives a function name for a [SmokeTestCase]
191- */
192- private val SmokeTestCase .functionName: String
193- get() = this .id.toCamelCase()
307+ /* *
308+ * Maps an operations parameters to their shapes
309+ */
310+ private fun mapOperationParametersToModeledShapes (operation : OperationShape ): Map <String , Shape > =
311+ model.getShape(operation.inputShape).get().allMembers.map { (key, value) ->
312+ key.toCamelCase() to model.getShape(value.target).get()
313+ }.toMap()
314+
315+ /* *
316+ * Derives a function name for a [SmokeTestCase]
317+ */
318+ private val SmokeTestCase .functionName: String
319+ get() = this .id.toCamelCase()
320+
321+ /* *
322+ * Get the operation parameters for a [SmokeTestCase]
323+ */
324+ private val SmokeTestCase .operationParameters: Map <StringNode , Node >
325+ get() = this .params.get().members
326+
327+ /* *
328+ * Checks if there are operation parameters for a [SmokeTestCase]
329+ */
330+ private val SmokeTestCase .hasOperationParameters: Boolean
331+ get() = this .params.isPresent
332+
333+ /* *
334+ * Check if a [SmokeTestCase] is expecting a specific error
335+ */
336+ private val SmokeTestCase .expectingSpecificError: Boolean
337+ get() = this .expectation.failure.getOrNull()?.errorId?.getOrNull() != null
338+
339+ /* *
340+ * Checks if a [SmokeTestCase] requires client configuration
341+ */
342+ private val SmokeTestCase .hasClientConfig: Boolean
343+ get() = this .vendorParams.isPresent
344+
345+ /* *
346+ * Get the client configuration required for a [SmokeTestCase]
347+ */
348+ private val SmokeTestCase .clientConfig: MutableMap <StringNode , Node >?
349+ get() = this .vendorParams.get().members
350+
351+ // Constants
352+ private val model = ctx.model
353+ private val settings = ctx.settings
354+ private val sdkId = settings.sdkId
355+ private val symbolProvider = ctx.symbolProvider
356+ private val service = symbolProvider.toSymbol(model.expectShape(settings.service))
357+ private val operations = model.topDownOperations(settings.service).filter { it.hasTrait<SmokeTestsTrait >() }
358+ }
0 commit comments