diff --git a/.brazil.json b/.brazil.json index 0908ffedb4..987379aec3 100644 --- a/.brazil.json +++ b/.brazil.json @@ -5,10 +5,12 @@ "com.squareup.okhttp3:okhttp-coroutines:5.*": "OkHttp3Coroutines-5.x", "com.squareup.okhttp3:okhttp:5.*": "OkHttp3-5.x", + "com.squareup.okhttp3:okhttp-jvm:5.*": "OkHttp3-5.x", "com.squareup.okio:okio-jvm:3.*": "OkioJvm-3.x", "io.opentelemetry:opentelemetry-api:1.*": "Maven-io-opentelemetry_opentelemetry-api-1.x", "io.opentelemetry:opentelemetry-extension-kotlin:1.*": "Maven-io-opentelemetry_opentelemetry-extension-kotlin-1.x", "org.slf4j:slf4j-api:2.*": "Maven-org-slf4j_slf4j-api-2.x", + "aws.sdk.kotlin.crt:aws-crt-kotlin:0.10.*": "AwsCrtKotlin-0.10.x", "aws.sdk.kotlin.crt:aws-crt-kotlin:0.9.*": "AwsCrtKotlin-0.9.x", "aws.sdk.kotlin.crt:aws-crt-kotlin:0.8.*": "AwsCrtKotlin-0.8.x", "com.squareup.okhttp3:okhttp:4.*": "OkHttp3-4.x", diff --git a/.github/workflows/merge-main.yml b/.github/workflows/merge-main.yml index db1df3a142..ae4af7a055 100644 --- a/.github/workflows/merge-main.yml +++ b/.github/workflows/merge-main.yml @@ -12,4 +12,4 @@ jobs: uses: awslabs/aws-kotlin-repo-tools/.github/actions/merge-main@main with: ci-user-pat: ${{ secrets.CI_USER_PAT }} - exempt-branches: # Add any if required \ No newline at end of file + exempt-branches: # Add any if required diff --git a/CHANGELOG.md b/CHANGELOG.md index ad34f18d2e..eb355ef808 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## [1.5.1] - 07/17/2025 + +## [1.5.0] - 07/17/2025 + +### Features +* Upgrade to Kotlin 2.2.0 +* [#1413](https://github.com/awslabs/aws-sdk-kotlin/issues/1413) ⚠️ **IMPORTANT**: Refactor endpoint discoverer classes into interfaces so custom implementations may be provided + +### Fixes +* [#1311](https://github.com/smithy-lang/smithy-kotlin/issues/1311) Reimplement idle connection monitoring using `okhttp3.EventListener` instead of now-internal `okhttp3.ConnectionListener` +* [#1608](https://github.com/awslabs/aws-sdk-kotlin/issues/1608) Switch to always serialize defaults in requests. Previously fields were not serialized in requests if they weren't `@required` and hadn't been changed from the default value. +* [#1413](https://github.com/awslabs/aws-sdk-kotlin/issues/1413) Favor `endpointUrl` instead of endpoint discovery if both are provided + +### Miscellaneous +* Add telemetry provider configuration to `DefaultAwsSigner` + ## [1.4.23] - 07/15/2025 ## [1.4.22] - 07/02/2025 diff --git a/bom/build.gradle.kts b/bom/build.gradle.kts index 51b05616c1..8af09dbf58 100644 --- a/bom/build.gradle.kts +++ b/bom/build.gradle.kts @@ -7,8 +7,6 @@ import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension import org.jetbrains.kotlin.gradle.plugin.KotlinMultiplatformPluginWrapper import org.jetbrains.kotlin.gradle.plugin.KotlinTarget import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinMetadataTarget -import org.jetbrains.kotlin.gradle.targets.js.KotlinJsTarget -import java.util.* plugins { `maven-publish` @@ -52,7 +50,6 @@ fun createBomConstraintsAndVersionCatalog() { fun Project.artifactId(target: KotlinTarget): String = when (target) { is KotlinMetadataTarget -> name - is KotlinJsTarget -> "$name-js" else -> "$name-${target.targetName.lowercase()}" } diff --git a/build.gradle.kts b/build.gradle.kts index bdec55641a..8b7a33571e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -114,6 +114,7 @@ apiValidation { "nullability-tests", "paginator-tests", "waiter-tests", + "service-codegen-tests", "compile", "slf4j-1x-consumer", "slf4j-2x-consumer", diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/customization/RegionSupport.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/customization/RegionSupport.kt index f144c54f9f..9b25eaaa29 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/customization/RegionSupport.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/customization/RegionSupport.kt @@ -40,7 +40,27 @@ class RegionSupport : KotlinIntegration { name = "region" symbol = KotlinTypes.String.toBuilder().nullable().build() documentation = """ - The region to sign with and make requests to. + The AWS region to sign with and make requests to. When specified, this static region configuration + takes precedence over other region resolution methods. + + The region resolution order is: + 1. Static region (if specified) + 2. Custom region provider (if configured) + 3. Default region provider chain + """.trimIndent() + } + + val RegionProviderProp: ConfigProperty = ConfigProperty { + name = "regionProvider" + symbol = RuntimeTypes.SmithyClient.Region.RegionProvider + documentation = """ + An optional region provider that determines the AWS region for client operations. When specified, this provider + takes precedence over the default region provider chain, unless a static region is explicitly configured. + + The region resolution order is: + 1. Static region (if specified) + 2. Custom region provider (if configured) + 3. Default region provider chain """.trimIndent() } } @@ -57,7 +77,7 @@ class RegionSupport : KotlinIntegration { return supportsSigv4 || hasRegionBuiltin || isAwsSdk } - override fun additionalServiceConfigProps(ctx: CodegenContext): List = listOf(RegionProp) + override fun additionalServiceConfigProps(ctx: CodegenContext): List = listOf(RegionProp, RegionProviderProp) override fun customizeEndpointResolution(ctx: ProtocolGenerator.GenerationContext): EndpointCustomization = object : EndpointCustomization { diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt index 48d47a7c18..935eba31ab 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt @@ -7,17 +7,28 @@ package software.amazon.smithy.kotlin.codegen.aws.protocols import software.amazon.smithy.aws.traits.protocols.AwsQueryErrorTrait import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AbstractQueryFormUrlSerializerGenerator import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AwsHttpBindingProtocolGenerator import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator -import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.RenderingContext +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.withBlock import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes -import software.amazon.smithy.kotlin.codegen.model.* -import software.amazon.smithy.kotlin.codegen.rendering.protocol.* -import software.amazon.smithy.kotlin.codegen.rendering.serde.* +import software.amazon.smithy.kotlin.codegen.model.buildSymbol +import software.amazon.smithy.kotlin.codegen.model.getTrait +import software.amazon.smithy.kotlin.codegen.model.hasTrait +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingContext +import software.amazon.smithy.kotlin.codegen.rendering.serde.FormUrlSerdeDescriptorGenerator +import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataParserGenerator +import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataSerializerGenerator +import software.amazon.smithy.kotlin.codegen.rendering.serde.XmlParserGenerator import software.amazon.smithy.model.shapes.* -import software.amazon.smithy.model.traits.* +import software.amazon.smithy.model.traits.XmlFlattenedTrait +import software.amazon.smithy.model.traits.XmlNameTrait /** * Handles generating the aws.protocols#awsQuery protocol for services. @@ -45,7 +56,7 @@ class AwsQuery : QueryHttpBindingProtocolGenerator() { writer: KotlinWriter, ) { writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""") - writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponseNoSuspend) + writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponse) } } @@ -76,6 +87,14 @@ private class AwsQuerySerializerGenerator( members: List, writer: KotlinWriter, ): FormUrlSerdeDescriptorGenerator = AwsQuerySerdeFormUrlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members) + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + TODO("Used for service-codegen. Not yet implemented") + } } private class AwsQueryXmlParserGenerator( diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt index 10be8bfc6e..38e74d010d 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt @@ -10,7 +10,10 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AbstractQueryFormUrlSerializerGenerator import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator -import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.RenderingContext +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.withBlock import software.amazon.smithy.kotlin.codegen.model.buildSymbol import software.amazon.smithy.kotlin.codegen.model.getTrait import software.amazon.smithy.kotlin.codegen.model.isNullable @@ -39,7 +42,7 @@ class Ec2Query : QueryHttpBindingProtocolGenerator() { writer: KotlinWriter, ) { writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""") - writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseEc2QueryErrorResponseNoSuspend) + writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseEc2QueryErrorResponse) } } @@ -95,6 +98,14 @@ private class Ec2QuerySerializerGenerator( members: List, writer: KotlinWriter, ): FormUrlSerdeDescriptorGenerator = Ec2QuerySerdeFormUrlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members) + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + TODO("Used for service-codegen. Not yet implemented") + } } private class Ec2QueryParserGenerator( diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestJson1.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestJson1.kt index dcf2f0a814..db6591738f 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestJson1.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestJson1.kt @@ -40,9 +40,11 @@ class RestJson1 : JsonHttpBindingProtocolGenerator() { writer: KotlinWriter, ) { super.renderSerializeHttpBody(ctx, op, writer) + if (ctx.settings.build.generateServiceProject) return val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) - if (!resolver.hasHttpBody(op)) return + + if (!resolver.hasHttpRequestBody(op)) return // restjson1 has some different semantics and expectations around empty structures bound via @httpPayload trait // * empty structures get serialized to `{}` diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestXml.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestXml.kt index 7c1f45b92c..c7c834c317 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestXml.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestXml.kt @@ -63,7 +63,7 @@ open class RestXml : AwsHttpBindingProtocolGenerator() { writer: KotlinWriter, ) { writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""") - writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponseNoSuspend) + writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponse) } } diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RpcV2Cbor.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RpcV2Cbor.kt index 117574da40..ff09bc0d5b 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RpcV2Cbor.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RpcV2Cbor.kt @@ -103,19 +103,30 @@ class RpcV2Cbor : AwsHttpBindingProtocolGenerator() { if (!op.hasHttpBody(ctx)) return // payload member(s) - val requestBindings = resolver.requestBindings(op) - val httpPayload = requestBindings.firstOrNull { it.location == HttpBinding.Location.PAYLOAD } + val bindings = if (ctx.settings.build.generateServiceProject) { + resolver.responseBindings(op) + } else { + resolver.requestBindings(op) + } + val httpPayload = bindings.firstOrNull { it.location == HttpBinding.Location.PAYLOAD } + if (httpPayload != null) { renderExplicitHttpPayloadSerializer(ctx, httpPayload, writer) } else { - val documentMembers = requestBindings.filterDocumentBoundMembers() + val documentMembers = bindings.filterDocumentBoundMembers() // Unbound document members that should be serialized into the document format for the protocol. // delegate to the generate operation body serializer function val sdg = structuredDataSerializer(ctx) val opBodySerializerFn = sdg.operationSerializer(ctx, op, documentMembers) - writer.write("builder.body = #T(context, input)", opBodySerializerFn) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = #T(context, input)", opBodySerializerFn) + } else { + writer.write("builder.body = #T(context, input)", opBodySerializerFn) + } + } + if (!ctx.settings.build.generateServiceProject) { + renderContentTypeHeader(ctx, op, writer, resolver) } - renderContentTypeHeader(ctx, op, writer, resolver) } /** diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/StaticHttpBindingResolver.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/StaticHttpBindingResolver.kt index f549d0ed4c..06155287d6 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/StaticHttpBindingResolver.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/StaticHttpBindingResolver.kt @@ -56,10 +56,16 @@ open class StaticHttpBindingResolver( /** * By default returns all inputs as [HttpBinding.Location.DOCUMENT] */ - override fun requestBindings(operationShape: OperationShape): List { - if (!operationShape.input.isPresent) return emptyList() - val input = model.expectShape(operationShape.input.get()) - return input.members().map { member -> HttpBindingDescriptor(member, HttpBinding.Location.DOCUMENT) }.toList() + override fun requestBindings(shape: Shape): List { + when (shape) { + is OperationShape -> { + if (!shape.input.isPresent) return emptyList() + val input = model.expectShape(shape.input.get()) + return input.members().map { member -> HttpBindingDescriptor(member, HttpBinding.Location.DOCUMENT) }.toList() + } + is StructureShape -> return shape.members().map { member -> member.toHttpBindingDescriptor() }.toList() + else -> error("unimplemented shape type for http response bindings: $shape") + } } /** diff --git a/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/customization/RegionSupportTest.kt b/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/customization/RegionSupportTest.kt new file mode 100644 index 0000000000..2df3e519f0 --- /dev/null +++ b/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/customization/RegionSupportTest.kt @@ -0,0 +1,73 @@ +package software.amazon.smithy.kotlin.codegen.aws.customization + +import org.junit.jupiter.api.Test +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.model.expectShape +import software.amazon.smithy.kotlin.codegen.rendering.ServiceClientConfigGenerator +import software.amazon.smithy.kotlin.codegen.test.* +import software.amazon.smithy.model.shapes.ServiceShape + +class RegionSupportTest { + @Test + fun testRegionSupportProperties() { + val model = """ + namespace com.test + + use aws.protocols#awsJson1_1 + use aws.api#service + use aws.auth#sigv4 + + @service(sdkId: "service with overrides", endpointPrefix: "service-with-overrides") + @sigv4(name: "example") + @awsJson1_1 + service Example { + version: "1.0.0", + operations: [GetFoo] + } + + operation GetFoo {} + """.toSmithyModel() + + val serviceShape = model.expectShape("com.test#Example") + + val testCtx = model.newTestContext(serviceName = "Example") + val writer = KotlinWriter("com.test") + + val renderingCtx = testCtx.toRenderingContext(writer, serviceShape) + .copy(integrations = listOf(RegionSupport())) + + ServiceClientConfigGenerator(serviceShape, detectDefaultProps = false).render(renderingCtx, renderingCtx.writer) + val contents = writer.toString() + + val expectedProps = """ + public val region: String? = builder.region + public val regionProvider: RegionProvider? = builder.regionProvider + """.formatForTest() + contents.shouldContainOnlyOnceWithDiff(expectedProps) + + val expectedImpl = """ + /** + * The AWS region to sign with and make requests to. When specified, this static region configuration + * takes precedence over other region resolution methods. + * + * The region resolution order is: + * 1. Static region (if specified) + * 2. Custom region provider (if configured) + * 3. Default region provider chain + */ + public var region: String? = null + + /** + * An optional region provider that determines the AWS region for client operations. When specified, this provider + * takes precedence over the default region provider chain, unless a static region is explicitly configured. + * + * The region resolution order is: + * 1. Static region (if specified) + * 2. Custom region provider (if configured) + * 3. Default region provider chain + */ + public var regionProvider: RegionProvider? = null + """.formatForTest(indent = " ") + contents.shouldContainOnlyOnceWithDiff(expectedImpl) + } +} diff --git a/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/AwsHttpBindingProtocolGeneratorTest.kt b/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/AwsHttpBindingProtocolGeneratorTest.kt index 19a1963844..3fbf2cfadc 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/AwsHttpBindingProtocolGeneratorTest.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/AwsHttpBindingProtocolGeneratorTest.kt @@ -123,6 +123,14 @@ class AwsHttpBindingProtocolGeneratorTest { ): Symbol { error("Unneeded for test") } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + error("Unneeded for test") + } } override val protocol: ShapeId diff --git a/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt b/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt index 164e7cb2e6..c8fae63ca8 100644 --- a/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt +++ b/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt @@ -214,6 +214,14 @@ class MockHttpProtocolGenerator(model: Model) : HttpBindingProtocolGenerator() { val symbol = ctx.symbolProvider.toSymbol(shape) name = "serialize" + StringUtils.capitalize(symbol.name) + "Payload" } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + error("Unneeded for test") + } } override fun operationErrorHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol = buildSymbol { diff --git a/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/ModelTestUtils.kt b/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/ModelTestUtils.kt index e857f91fee..4c421f7fa7 100644 --- a/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/ModelTestUtils.kt +++ b/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/ModelTestUtils.kt @@ -8,6 +8,7 @@ import software.amazon.smithy.build.MockManifest import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.kotlin.codegen.* import software.amazon.smithy.kotlin.codegen.core.CodegenContext +import software.amazon.smithy.kotlin.codegen.core.GenerationContext import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration import software.amazon.smithy.kotlin.codegen.model.OperationNormalizer @@ -122,9 +123,11 @@ fun Model.newTestContext( val manifest = MockManifest() val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model = this, rootNamespace = packageName, serviceName = serviceName, settings = settings) val service = this.getShape(ShapeId.from("$packageName#$serviceName")).get().asServiceShape().get() - val delegator = KotlinDelegator(settings, this, manifest, provider, integrations) - val ctx = ProtocolGenerator.GenerationContext( + val codegenCtx = GenerationContext(this, provider, settings, generator, integrations) + val delegator = KotlinDelegator(codegenCtx, manifest, integrations) + + val generationCtx = ProtocolGenerator.GenerationContext( settings, this, service, @@ -133,7 +136,8 @@ fun Model.newTestContext( generator.protocol, delegator, ) - return TestContext(ctx, manifest, generator) + + return TestContext(generationCtx, manifest, generator) } fun TestContext.toCodegenContext() = object : CodegenContext { @@ -173,7 +177,7 @@ fun Model.defaultSettings( sdkId: String = TestModelDefault.SDK_ID, generateDefaultBuildFiles: Boolean = false, nullabilityCheckMode: CheckMode = CheckMode.CLIENT_CAREFUL, - defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT, + defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.DEFAULT, ): KotlinSettings { val serviceId = if (serviceName == null) { this.inferService() diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/CodegenVisitor.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/CodegenVisitor.kt index 4adf42b91f..1447363d4a 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/CodegenVisitor.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/CodegenVisitor.kt @@ -20,6 +20,7 @@ import software.amazon.smithy.kotlin.codegen.model.hasTrait import software.amazon.smithy.kotlin.codegen.rendering.* import software.amazon.smithy.kotlin.codegen.rendering.protocol.ApplicationProtocol import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator +import software.amazon.smithy.kotlin.codegen.service.AbstractStubGenerator import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.ServiceIndex import software.amazon.smithy.model.neighbor.Walker @@ -87,12 +88,12 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default() { integration.decorateSymbolProvider(settings, model, provider) } - writers = KotlinDelegator(settings, model, fileManifest, symbolProvider, integrations) - protocolGenerator = resolveProtocolGenerator(integrations, model, service, settings) - applicationProtocol = protocolGenerator?.applicationProtocol ?: ApplicationProtocol.createDefaultHttpApplicationProtocol() - baseGenerationContext = GenerationContext(model, symbolProvider, settings, protocolGenerator, integrations) + + writers = KotlinDelegator(baseGenerationContext, fileManifest, integrations) + + applicationProtocol = protocolGenerator?.applicationProtocol ?: ApplicationProtocol.createDefaultHttpApplicationProtocol() } private fun resolveProtocolGenerator( @@ -114,7 +115,12 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default() { } fun execute() { - logger.info("Generating Kotlin client for service ${settings.service}") + val generateServiceProject = settings.build.generateServiceProject + if (generateServiceProject) { + logger.info("Generating Kotlin service ${settings.service}") + } else { + logger.info("Generating Kotlin client for service ${settings.service}") + } logger.info("Walking shapes from ${settings.service} to find shapes to generate") val modelWithoutTraits = ModelTransformer.create().getModelWithoutTraitShapes(model) @@ -138,11 +144,18 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default() { logger.info("[${service.id}] Generating service client for protocol $protocol") generateProtocolClient(ctx) - logger.info("[${service.id}] Generating endpoint provider for protocol $protocol") - generateEndpointsSources(ctx) + if (!generateServiceProject) { + logger.info("[${service.id}] Generating endpoint provider for protocol $protocol") + generateEndpointsSources(ctx) + + logger.info("[${service.id}] Generating auth scheme provider for protocol $protocol") + generateAuthSchemeProvider(ctx) + } + } - logger.info("[${service.id}] Generating auth scheme provider for protocol $protocol") - generateAuthSchemeProvider(ctx) + if (generateServiceProject) { + val serviceStubGenerator: AbstractStubGenerator = settings.serviceStub.framework.getServiceFrameworkGenerator(baseGenerationContext, writers, fileManifest) + serviceStubGenerator.render() } writers.finalize() diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt index 3ee6e0633b..01025bf77a 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt @@ -5,14 +5,10 @@ package software.amazon.smithy.kotlin.codegen -import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait -import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait -import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait -import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait -import software.amazon.smithy.aws.traits.protocols.RestJson1Trait -import software.amazon.smithy.aws.traits.protocols.RestXmlTrait +import software.amazon.smithy.aws.traits.protocols.* import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.kotlin.codegen.lang.isValidPackageName +import software.amazon.smithy.kotlin.codegen.service.ServiceFramework import software.amazon.smithy.kotlin.codegen.utils.getOrNull import software.amazon.smithy.kotlin.codegen.utils.toCamelCase import software.amazon.smithy.model.Model @@ -24,10 +20,10 @@ import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.protocol.traits.Rpcv2CborTrait -import java.util.Optional +import java.util.* import java.util.logging.Logger +import java.util.stream.Collectors import kotlin.IllegalArgumentException -import kotlin.streams.toList // shapeId of service from which to generate an SDK private const val SERVICE = "service" @@ -37,6 +33,7 @@ private const val PACKAGE_VERSION = "version" private const val PACKAGE_DESCRIPTION = "description" private const val BUILD_SETTINGS = "build" private const val API_SETTINGS = "api" +private const val SERVICE_STUB_SETTINGS = "serviceStub" // Optional specification of sdkId for models that provide them, otherwise Service's shape id name is used private const val SDK_ID = "sdkId" @@ -61,6 +58,7 @@ data class KotlinSettings( val sdkId: String, val build: BuildSettings = BuildSettings.Default, val api: ApiSettings = ApiSettings.Default, + val serviceStub: ServiceStubSettings = ServiceStubSettings.Default, val debug: Boolean = false, ) { @@ -103,7 +101,7 @@ data class KotlinSettings( * @return Returns the extracted settings */ fun from(model: Model, config: ObjectNode): KotlinSettings { - config.warnIfAdditionalProperties(listOf(SERVICE, PACKAGE_SETTINGS, BUILD_SETTINGS, SDK_ID, API_SETTINGS)) + config.warnIfAdditionalProperties(listOf(SERVICE, PACKAGE_SETTINGS, BUILD_SETTINGS, SDK_ID, API_SETTINGS, SERVICE_STUB_SETTINGS)) val serviceId = config.getStringMember(SERVICE) .map(StringNode::expectShapeId) @@ -123,6 +121,7 @@ data class KotlinSettings( val sdkId = config.getStringMemberOrDefault(SDK_ID, serviceId.name) val build = config.getObjectMember(BUILD_SETTINGS) val api = config.getObjectMember(API_SETTINGS) + val serviceStub = config.getObjectMember(SERVICE_STUB_SETTINGS) val debug = config.getBooleanMemberOrDefault("debug", false) return KotlinSettings( serviceId, @@ -130,6 +129,7 @@ data class KotlinSettings( sdkId, BuildSettings.fromNode(build), ApiSettings.fromNode(api), + ServiceStubSettings.fromNode(serviceStub), debug, ) } @@ -164,7 +164,7 @@ fun Model.inferService(): ShapeId { val services = shapes(ServiceShape::class.java) .map(Shape::getId) .sorted() - .toList() + .collect(Collectors.toList()) return when { services.isEmpty() -> { @@ -190,23 +190,27 @@ fun Model.inferService(): ShapeId { * @param optInAnnotations Kotlin opt-in annotations. See: * https://kotlinlang.org/docs/reference/opt-in-requirements.html * @param generateMultiplatformProject Flag indicating to generate a Kotlin multiplatform or JVM project + * @param generateServiceProject Flag indicating to generate a Kotlin service project */ data class BuildSettings( val generateFullProject: Boolean = false, val generateDefaultBuildFiles: Boolean = true, val optInAnnotations: List? = null, val generateMultiplatformProject: Boolean = false, + val generateServiceProject: Boolean = false, ) { companion object { const val ROOT_PROJECT = "rootProject" const val GENERATE_DEFAULT_BUILD_FILES = "generateDefaultBuildFiles" const val ANNOTATIONS = "optInAnnotations" const val GENERATE_MULTIPLATFORM_MODULE = "multiplatform" + const val GENERATE_SERVICE_PROJECT = "generateServiceProject" fun fromNode(node: Optional): BuildSettings = node.map { val generateFullProject = node.get().getBooleanMemberOrDefault(ROOT_PROJECT, false) val generateBuildFiles = node.get().getBooleanMemberOrDefault(GENERATE_DEFAULT_BUILD_FILES, true) val generateMultiplatformProject = node.get().getBooleanMemberOrDefault(GENERATE_MULTIPLATFORM_MODULE, false) + val generateServiceProject = node.get().getBooleanMemberOrDefault(GENERATE_SERVICE_PROJECT, false) val annotations = node.get().getArrayMember(ANNOTATIONS).map { it.elements.mapNotNull { node -> node.asStringNode().map { stringNode -> @@ -214,7 +218,7 @@ data class BuildSettings( }.orNull() } }.orNull() - BuildSettings(generateFullProject, generateBuildFiles, annotations, generateMultiplatformProject) + BuildSettings(generateFullProject, generateBuildFiles, annotations, generateMultiplatformProject, generateServiceProject) }.orElse(Default) /** @@ -273,10 +277,15 @@ enum class DefaultValueSerializationMode(val value: String) { override fun toString(): String = value companion object { + /** + * The default value serialization mode, which is [ALWAYS] + */ + val DEFAULT = ALWAYS + fun fromValue(value: String): DefaultValueSerializationMode = - values().find { - it.value == value - } ?: throw IllegalArgumentException("$value is not a valid DefaultValueSerializationMode, expected one of ${values().map { it.value }}") + requireNotNull(entries.find { it.value.equals(value, ignoreCase = true) }) { + "$value is not a valid DefaultValueSerializationMode, expected one of ${values().map { it.value }}" + } } } @@ -291,7 +300,7 @@ enum class DefaultValueSerializationMode(val value: String) { data class ApiSettings( val visibility: Visibility = Visibility.PUBLIC, val nullabilityCheckMode: CheckMode = CheckMode.CLIENT_CAREFUL, - val defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT, + val defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.DEFAULT, val enableEndpointAuthProvider: Boolean = false, val protocolResolutionPriority: Set = DEFAULT_PROTOCOL_RESOLUTION_PRIORITY, ) { @@ -315,7 +324,7 @@ data class ApiSettings( node.get() .getStringMemberOrDefault( DEFAULT_VALUE_SERIALIZATION_MODE, - DefaultValueSerializationMode.WHEN_DIFFERENT.value, + DefaultValueSerializationMode.DEFAULT.value, ), ) val enableEndpointAuthProvider = node.get().getBooleanMemberOrDefault(ENABLE_ENDPOINT_AUTH_PROVIDER, false) @@ -335,3 +344,29 @@ data class ApiSettings( val Default: ApiSettings = ApiSettings() } } + +/** + * Contains configurations settings for a Kotlin service project + * @param framework Enum representing the server framework of the service. + */ +data class ServiceStubSettings( + val framework: ServiceFramework = ServiceFramework.KTOR, +) { + companion object { + const val SERVER_FRAMEWORK = "serverFramework" + + fun fromNode(node: Optional): ServiceStubSettings = node.map { + val serverFramework = node.get() + .getStringMember(SERVER_FRAMEWORK) + .map { ServiceFramework.fromValue(it.value) } + .getOrNull() ?: ServiceFramework.KTOR + + ServiceStubSettings(serverFramework) + }.orElse(Default) + + /** + * Default service stub settings + */ + val Default: ServiceStubSettings = ServiceStubSettings() + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt index 2b943da042..186dcec017 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/AbstractCodeWriterExt.kt @@ -150,6 +150,15 @@ inline fun , reified V> AbstractCodeWriter.getConte */ inline fun , reified V> AbstractCodeWriter.getContextValue(key: SectionKey): V = getContextValue(key.name) +/** + * Convenience function to set a typed value in the context + * @param key + */ +inline fun , reified V> AbstractCodeWriter.putContextValue( + key: SectionKey, + value: V, +): W = putContext(key.name, value) + /** * Convenience function to set context only if there is no value already associated with the given [key] */ diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/CodegenContext.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/CodegenContext.kt index 06d8b7497d..a3f714d741 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/CodegenContext.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/CodegenContext.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.kotlin.codegen.core import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.kotlin.codegen.KotlinSettings import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +import software.amazon.smithy.kotlin.codegen.integration.SectionKey import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.Shape @@ -16,6 +17,10 @@ import software.amazon.smithy.model.shapes.Shape * Common codegen properties required across different codegen contexts */ interface CodegenContext { + companion object { + val Key = SectionKey("CodegenContext") + } + val model: Model val symbolProvider: SymbolProvider val settings: KotlinSettings diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt index 9ae10da763..ea0b6459f1 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegator.kt @@ -6,11 +6,9 @@ package software.amazon.smithy.kotlin.codegen.core import software.amazon.smithy.build.FileManifest import software.amazon.smithy.codegen.core.* -import software.amazon.smithy.kotlin.codegen.KotlinSettings import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration import software.amazon.smithy.kotlin.codegen.model.SymbolProperty import software.amazon.smithy.kotlin.codegen.utils.namespaceToPath -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.Shape import java.nio.file.Paths @@ -21,13 +19,10 @@ const val DEFAULT_TEST_SOURCE_SET_ROOT = "./src/test/kotlin/" * Manages writers for Kotlin files. */ class KotlinDelegator( - private val settings: KotlinSettings, - private val model: Model, + private val ctx: CodegenContext, val fileManifest: FileManifest, - private val symbolProvider: SymbolProvider, private val integrations: List = listOf(), ) { - private val writers: MutableMap = mutableMapOf() // Tracks dependencies for source not provided by codegen that may reside in the service source tree. @@ -91,7 +86,7 @@ class KotlinDelegator( shape: Shape, block: (KotlinWriter) -> Unit, ) { - val symbol = symbolProvider.toSymbol(shape) + val symbol = ctx.symbolProvider.toSymbol(shape) useSymbolWriter(symbol, block) } @@ -151,7 +146,9 @@ class KotlinDelegator( val needsNewline = writers.containsKey(formattedFilename) val writer = writers.getOrPut(formattedFilename) { val kotlinWriter = KotlinWriter(namespace) - if (settings.debug) kotlinWriter.enableStackTraceComments(true) + kotlinWriter.putContextValue(CodegenContext.Key, ctx) + + if (ctx.settings.debug) kotlinWriter.enableStackTraceComments(true) // Register all integrations [SectionWriterBindings] on the writer. integrations.forEach { integration -> diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt index b575efae60..3fe6d45c02 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt @@ -37,7 +37,11 @@ private fun getDefaultRuntimeVersion(): String { // publishing info const val RUNTIME_GROUP: String = "aws.smithy.kotlin" val RUNTIME_VERSION: String = System.getProperty("smithy.kotlin.codegen.clientRuntimeVersion", getDefaultRuntimeVersion()) -val KOTLIN_COMPILER_VERSION: String = System.getProperty("smithy.kotlin.codegen.kotlinCompilerVersion", "2.1.0") +val KOTLIN_COMPILER_VERSION: String = System.getProperty("smithy.kotlin.codegen.kotlinCompilerVersion", "2.2.0") +val KTOR_VERSION: String = System.getProperty("smithy.kotlin.codegen.ktorVersion", "3.2.2") +val SERIALIZATION_PLUGIN: String = System.getProperty("smithy.kotlin.codegen.SerializationPlugin", "2.0.20") +val KOTLINX_VERSION: String = System.getProperty("smithy.kotlin.codegen.ktorKotlinxVersion", "1.9.0") +val KTOR_LOGGING_BACKEND_VERSION: String = System.getProperty("smithy.kotlin.codegen.ktorLoggingBackendVersion", "1.4.14") enum class SourceSet { CommonMain, @@ -134,6 +138,25 @@ data class KotlinDependency( // External third-party dependencies val KOTLIN_STDLIB = KotlinDependency(GradleConfiguration.Implementation, "kotlin", "org.jetbrains.kotlin", "kotlin-stdlib", KOTLIN_COMPILER_VERSION) val KOTLIN_TEST = KotlinDependency(GradleConfiguration.TestImplementation, "kotlin.test", "org.jetbrains.kotlin", "kotlin-test", KOTLIN_COMPILER_VERSION) + val KOTLIN_TEST_IMPL = KOTLIN_TEST.copy(config = GradleConfiguration.Implementation) + + // Ktor server dependencies + // FIXME: version numbers should not be hardcoded, they should be setting dynamically based on the Gradle library versions + val KTOR_SERVER_CORE = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server", "io.ktor", "ktor-server-core", KTOR_VERSION) + val KTOR_SERVER_UTILS = KotlinDependency(GradleConfiguration.Implementation, "io.ktor", "io.ktor", "ktor-server-core", KTOR_VERSION) + val KTOR_SERVER_NETTY = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.netty", "io.ktor", "ktor-server-netty", KTOR_VERSION) + val KTOR_SERVER_CIO = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.cio", "io.ktor", "ktor-server-cio", KTOR_VERSION) + val KTOR_SERVER_JETTY_JAKARTA = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.jetty.jakarta", "io.ktor", "ktor-server-jetty-jakarta", KTOR_VERSION) + val KTOR_SERVER_HTTP = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.http", "io.ktor", "ktor-http-jvm", KTOR_VERSION) + val KTOR_SERVER_LOGGING = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.plugins.calllogging", "io.ktor", "ktor-server-call-logging", KTOR_VERSION) + val KTOR_SERVER_BODY_LIMIT = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.plugins", "io.ktor", "ktor-server-body-limit", KTOR_VERSION) + val KTOR_LOGGING_SLF4J = KotlinDependency(GradleConfiguration.Implementation, "org.slf4j", "ch.qos.logback", "logback-classic", KTOR_LOGGING_BACKEND_VERSION) + val KTOR_LOGGING_LOGBACK = KotlinDependency(GradleConfiguration.Implementation, "ch.qos.logback", "ch.qos.logback", "logback-classic", KTOR_LOGGING_BACKEND_VERSION) + val KTOR_SERVER_STATUS_PAGE = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.plugins.statuspages", "io.ktor", "ktor-server-status-pages-jvm", KTOR_VERSION) + val KOTLINX_CBOR_SERDE = KotlinDependency(GradleConfiguration.Implementation, "kotlinx.serialization", "org.jetbrains.kotlinx", "kotlinx-serialization-cbor", KOTLINX_VERSION) + val KOTLINX_JSON_SERDE = KotlinDependency(GradleConfiguration.Implementation, "kotlinx.serialization.json", "org.jetbrains.kotlinx", "kotlinx-serialization-json", KOTLINX_VERSION) + val KTOR_SERVER_AUTH = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.auth", "io.ktor", "ktor-server-auth", KTOR_VERSION) + val KTOR_SERVER_DOUBLE_RECEIVE = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.plugins.doublereceive", "io.ktor", "ktor-server-double-receive-jvm", KTOR_VERSION) } override fun getDependencies(): List { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt index 8313f38c9c..e8bd0d1c6e 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.kotlin.codegen.core import software.amazon.smithy.codegen.core.* import software.amazon.smithy.kotlin.codegen.KotlinSettings +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.lang.kotlinReservedWords import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.utils.dq @@ -332,6 +333,10 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli } override fun serviceShape(shape: ServiceShape): Symbol { + if (settings.build.generateServiceProject) { + // Intentionally not generating a *client symbol* for the service + return KotlinTypes.Nothing + } val serviceName = clientName(settings.sdkId) return createSymbolBuilder(shape, "${serviceName}Client") .namespace(rootNamespace, ".") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index 1cbf3b407b..1c8c4f4e8b 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -45,11 +45,11 @@ object RuntimeTypes { object HttpClient : RuntimeTypePackage(KotlinDependency.HTTP_CLIENT) { val SdkHttpClient = symbol("SdkHttpClient") - object Middleware : RuntimeTypePackage(KotlinDependency.HTTP, "middleware") { + object Middleware : RuntimeTypePackage(KotlinDependency.HTTP_CLIENT, "middleware") { val MutateHeadersMiddleware = symbol("MutateHeaders") } - object Operation : RuntimeTypePackage(KotlinDependency.HTTP, "operation") { + object Operation : RuntimeTypePackage(KotlinDependency.HTTP_CLIENT, "operation") { val AuthSchemeResolver = symbol("AuthSchemeResolver") val context = symbol("context") val EndpointResolver = symbol("EndpointResolver") @@ -68,18 +68,19 @@ object RuntimeTypes { val setResolvedEndpoint = symbol("setResolvedEndpoint") } - object Config : RuntimeTypePackage(KotlinDependency.HTTP, "config") { + object Config : RuntimeTypePackage(KotlinDependency.HTTP_CLIENT, "config") { val HttpClientConfig = symbol("HttpClientConfig") val HttpEngineConfig = symbol("HttpEngineConfig") } - object Engine : RuntimeTypePackage(KotlinDependency.HTTP, "engine") { + object Engine : RuntimeTypePackage(KotlinDependency.HTTP_CLIENT, "engine") { val HttpClientEngine = symbol("HttpClientEngine") val manage = symbol("manage", "engine.internal", isExtension = true) } - object Interceptors : RuntimeTypePackage(KotlinDependency.HTTP, "interceptors") { + object Interceptors : RuntimeTypePackage(KotlinDependency.HTTP_CLIENT, "interceptors") { val ContinueInterceptor = symbol("ContinueInterceptor") + val DiscoveredEndpointErrorInterceptor = symbol("DiscoveredEndpointErrorInterceptor") val HttpInterceptor = symbol("HttpInterceptor") val HttpChecksumRequiredInterceptor = symbol("HttpChecksumRequiredInterceptor") val FlexibleChecksumsRequestInterceptor = symbol("FlexibleChecksumsRequestInterceptor") @@ -97,7 +98,6 @@ object RuntimeTypes { } object Core : RuntimeTypePackage(KotlinDependency.CORE) { - val Clock = symbol("Clock", "time") val ExecutionContext = symbol("ExecutionContext", "operation") val ErrorMetadata = symbol("ErrorMetadata") val ServiceErrorMetadata = symbol("ServiceErrorMetadata") @@ -126,11 +126,12 @@ object RuntimeTypes { val attributesOf = symbol("attributesOf") val AttributeKey = symbol("AttributeKey") val createOrAppend = symbol("createOrAppend") + val ExpiringKeyedCache = symbol("ExpiringKeyedCache") val get = symbol("get") val mutableMultiMapOf = symbol("mutableMultiMapOf") + val PeriodicSweepCache = symbol("PeriodicSweepCache") val putIfAbsent = symbol("putIfAbsent") val putIfAbsentNotNull = symbol("putIfAbsentNotNull") - val ReadThroughCache = symbol("ReadThroughCache") val toMutableAttributes = symbol("toMutableAttributes") val emptyAttributes = symbol("emptyAttributes") } @@ -204,6 +205,7 @@ object RuntimeTypes { object Net : RuntimeTypePackage(KotlinDependency.CORE, "net") { val Host = symbol("Host") + val Scheme = symbol("Scheme") object Url : RuntimeTypePackage(KotlinDependency.CORE, "net.url") { val QueryParameters = symbol("QueryParameters") @@ -251,6 +253,10 @@ object RuntimeTypes { val Url = symbol("Url") } } + + object Region : RuntimeTypePackage(KotlinDependency.SMITHY_CLIENT, "region") { + val RegionProvider = symbol("RegionProvider") + } } object Serde : RuntimeTypePackage(KotlinDependency.SERDE) { @@ -374,6 +380,7 @@ object RuntimeTypes { val BearerTokenAuthScheme = symbol("BearerTokenAuthScheme") val BearerTokenProviderConfig = symbol("BearerTokenProviderConfig") val BearerTokenProvider = symbol("BearerTokenProvider") + val BearerToken = symbol("BearerToken") val EnvironmentBearerTokenProvider = symbol("EnvironmentBearerTokenProvider") @@ -387,6 +394,7 @@ object RuntimeTypes { val mergeAuthOptions = symbol("mergeAuthOptions") val sigV4 = symbol("sigV4") val sigV4A = symbol("sigV4A") + val SignHttpRequest = symbol("SignHttpRequest") } object AwsSigningCrt : RuntimeTypePackage(KotlinDependency.AWS_SIGNING_CRT) { @@ -453,8 +461,8 @@ object RuntimeTypes { val RestJsonErrorDeserializer = symbol("RestJsonErrorDeserializer") } object AwsXmlProtocols : RuntimeTypePackage(KotlinDependency.AWS_XML_PROTOCOLS) { - val parseRestXmlErrorResponseNoSuspend = symbol("parseRestXmlErrorResponseNoSuspend") - val parseEc2QueryErrorResponseNoSuspend = symbol("parseEc2QueryErrorResponseNoSuspend") + val parseRestXmlErrorResponse = symbol("parseRestXmlErrorResponse") + val parseEc2QueryErrorResponse = symbol("parseEc2QueryErrorResponse") } object SmithyRpcV2Protocols : RuntimeTypePackage(KotlinDependency.SMITHY_RPCV2_PROTOCOLS) { @@ -486,4 +494,134 @@ object RuntimeTypes { val sign = symbol("sign") } + + object KtorServerCore : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_CORE) { + val embeddedServer = symbol("embeddedServer", "engine") + val EmbeddedServerType = symbol("EmbeddedServer", "engine") + val ApplicationEngineFactory = symbol("ApplicationEngineFactory", "engine") + val connector = symbol("connector", "engine") + + val Application = symbol("Application", "application") + val ApplicationCallClass = symbol("ApplicationCall", "application") + val ApplicationStarting = symbol("ApplicationStarting", "application") + val ApplicationStarted = symbol("ApplicationStarted", "application") + val ApplicationStopping = symbol("ApplicationStopping", "application") + val ApplicationStopped = symbol("ApplicationStopped", "application") + val ApplicationCreateRouteScopedPlugin = symbol("createRouteScopedPlugin", "application") + val ApplicationRouteScopedPlugin = symbol("RouteScopedPlugin", "application") + val applicationCall = symbol("call", "application") + val install = symbol("install", "application") + + val BadRequestException = symbol("BadRequestException", "plugins") + } + + object KtorServerUtils : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_UTILS) { + val AttributeKey = symbol("AttributeKey", "util") + } + + object KtorServerRouting : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_CORE) { + val routing = symbol("routing", "routing") + val route = symbol("route", "routing") + val get = symbol("get", "routing") + val post = symbol("post", "routing") + val put = symbol("put", "routing") + val delete = symbol("delete", "routing") + val patch = symbol("patch", "routing") + val head = symbol("head", "routing") + val options = symbol("options", "routing") + + val requestReceive = symbol("receive", "request") + val requestUri = symbol("uri", "request") + val requestHeader = symbol("header", "request") + val requestHttpMethod = symbol("httpMethod", "request") + val requestApplicationRequest = symbol("ApplicationRequest", "request") + val requestContentLength = symbol("contentLength", "request") + val requestContentType = symbol("contentType", "request") + val requestAcceptItems = symbol("acceptItems", "request") + val requestPath = symbol("path", "request") + + val responseResponse = symbol("respond", "response") + val responseResponseText = symbol("respondText", "response") + val responseRespondBytes = symbol("respondBytes", "response") + } + + object KtorServerNetty : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_NETTY) { + val Netty = symbol("Netty") + val Configuration = symbol("Configuration", "NettyApplicationEngine") + } + + object KtorServerCio : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_CIO) { + val CIO = symbol("CIO") + val Configuration = symbol("Configuration", "CIOApplicationEngine") + } + + object KtorServerJettyJakarta : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_JETTY_JAKARTA) { + val Jetty = symbol("Jetty") + val Configuration = symbol("Configuration", "JettyApplicationEngineBase") + } + + object KtorServerHttp : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_HTTP) { + val ContentType = symbol("ContentType") + val HttpStatusCode = symbol("HttpStatusCode") + val parseAndSortHeader = symbol("parseAndSortHeader") + val HttpHeaders = symbol("HttpHeaders") + val HeadersBuilder = symbol("HeadersBuilder") + val Parameters = symbol("Parameters") + val Cbor = symbol("Cbor", "ContentType.Application") + val Json = symbol("Json", "ContentType.Application") + val Any = symbol("Any", "ContentType.Application") + val OctetStream = symbol("OctetStream", "ContentType.Application") + val PlainText = symbol("Plain", "ContentType.Text") + } + + object KtorServerLogging : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_LOGGING) { + val CallLogging = symbol("CallLogging") + } + + object KtorServerBodyLimit : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_BODY_LIMIT) { + val RequestBodyLimit = symbol("RequestBodyLimit", "bodylimit") + val PayloadTooLargeException = symbol("PayloadTooLargeException") + } + + object KtorLoggingSlf4j : RuntimeTypePackage(KotlinDependency.KTOR_LOGGING_SLF4J) { + val Level = symbol("Level", "event") + val LoggerFactory = symbol("LoggerFactory") + val ROOT_LOGGER_NAME = symbol("ROOT_LOGGER_NAME", "Logger") + } + + object KtorLoggingLogback : RuntimeTypePackage(KotlinDependency.KTOR_LOGGING_LOGBACK) { + val Level = symbol("Level", "classic") + val LoggerContext = symbol("LoggerContext", "classic") + } + + object KtorServerStatusPage : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_STATUS_PAGE) { + val StatusPages = symbol("StatusPages") + val exception = symbol("exception") + } + + object KtorServerAuth : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_AUTH) { + val Authentication = symbol("Authentication") + val authenticate = symbol("authenticate") + val Principal = symbol("Principal") + val bearer = symbol("bearer") + val AuthenticationConfig = symbol("AuthenticationConfig") + val AuthenticationProvider = symbol("AuthenticationProvider") + val AuthenticationFailedCause = symbol("AuthenticationFailedCause") + val AuthenticationContext = symbol("AuthenticationContext") + val AuthenticationStrategy = symbol("AuthenticationStrategy") + } + + object KtorServerDoubleReceive : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_DOUBLE_RECEIVE) { + val DoubleReceive = symbol("DoubleReceive") + } + + object KotlinxCborSerde : RuntimeTypePackage(KotlinDependency.KOTLINX_CBOR_SERDE) { + val Serializable = symbol("Serializable") + val Cbor = symbol("Cbor", "cbor") + val encodeToByteArray = symbol("encodeToByteArray") + } + + object KotlinxJsonSerde : RuntimeTypePackage(KotlinDependency.KOTLINX_JSON_SERDE) { + val Json = symbol("Json") + } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt index 6d2384584e..5eef2a0ef7 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt @@ -117,6 +117,7 @@ object KotlinTypes { val Duration = stdlibSymbol("Duration") val milliseconds = stdlibSymbol("milliseconds", "time.Duration.Companion") val minutes = stdlibSymbol("minutes", "time.Duration.Companion") + val seconds = stdlibSymbol("seconds", "time.Duration.Companion") } object Coroutines { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ExceptionBaseClassGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ExceptionBaseClassGenerator.kt index aa308c5343..73c789e823 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ExceptionBaseClassGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ExceptionBaseClassGenerator.kt @@ -10,7 +10,6 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.kotlin.codegen.KotlinSettings import software.amazon.smithy.kotlin.codegen.core.* import software.amazon.smithy.kotlin.codegen.integration.SectionId -import software.amazon.smithy.kotlin.codegen.integration.SectionKey import software.amazon.smithy.kotlin.codegen.model.buildSymbol import software.amazon.smithy.kotlin.codegen.model.namespace import software.amazon.smithy.model.knowledge.TopDownIndex @@ -31,12 +30,10 @@ object ExceptionBaseClassGenerator { /** * Defines a section in which code can be added to the body of the base exception type. */ - object ExceptionBaseClassSection : SectionId { - val CodegenContext: SectionKey = SectionKey("CodegenContext") - } + object ExceptionBaseClassSection : SectionId fun render(ctx: CodegenContext, writer: KotlinWriter) { - writer.declareSection(ExceptionBaseClassSection, mapOf(ExceptionBaseClassSection.CodegenContext to ctx)) { + writer.declareSection(ExceptionBaseClassSection) { ServiceExceptionBaseClassGenerator().render(ctx, writer) } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/GradleGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/GradleGenerator.kt index 07f7541a06..3b11dd16d4 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/GradleGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/GradleGenerator.kt @@ -6,7 +6,11 @@ package software.amazon.smithy.kotlin.codegen.rendering import software.amazon.smithy.build.FileManifest import software.amazon.smithy.kotlin.codegen.KotlinSettings -import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.core.InlineCodeWriter +import software.amazon.smithy.kotlin.codegen.core.InlineCodeWriterFormatter +import software.amazon.smithy.kotlin.codegen.core.KOTLIN_COMPILER_VERSION +import software.amazon.smithy.kotlin.codegen.core.KotlinDependency +import software.amazon.smithy.kotlin.codegen.core.SERIALIZATION_PLUGIN import software.amazon.smithy.utils.AbstractCodeWriter // Determines the jvmTarget version emitted to the build file @@ -27,6 +31,7 @@ fun writeGradleBuild( val isKmp = settings.build.generateMultiplatformProject val isRootModule = settings.build.generateFullProject + val generateServiceProject = settings.build.generateServiceProject val annotationRenderer: InlineCodeWriter = { val annotations = settings.build.optInAnnotations ?: emptyList() @@ -46,12 +51,18 @@ fun writeGradleBuild( val pluginName = if (isKmp) "multiplatform" else "jvm" write( - "kotlin(\"$pluginName\") #W", + "kotlin(\"$pluginName\") #W \n #W", { w: AbstractCodeWriter<*> -> if (isRootModule) { w.write("version #S", KOTLIN_COMPILER_VERSION) } }, + { w: AbstractCodeWriter<*> -> + if (generateServiceProject) { + w.write("application") + w.write("kotlin(#S) version #S", "plugin.serialization", SERIALIZATION_PLUGIN) + } + }, ) } @@ -67,10 +78,12 @@ fun writeGradleBuild( else -> renderJvmGradleBuild( writer, isRootModule, + generateServiceProject, dependencies, pluginsBodyRenderer, repositoryRenderer, annotationRenderer, + applicationRenderer("${settings.pkg.name}.MainKt"), ) } @@ -134,7 +147,7 @@ fun renderRootJvmPluginConfig(writer: GradleWriter) { """ jvm { compilations.all { - kotlinOptions.jvmTarget = #S + compilerOptions.jvmTarget = #S } testRuns["test"].executionTask.configure { useJUnitPlatform() @@ -155,10 +168,12 @@ fun renderRootJvmPluginConfig(writer: GradleWriter) { fun renderJvmGradleBuild( writer: GradleWriter, isRootModule: Boolean, + generateServiceProject: Boolean, dependencies: List, pluginsRenderer: InlineCodeWriter, repositoryRenderer: InlineCodeWriter, annotationRenderer: InlineCodeWriter, + applicationRenderer: InlineCodeWriter, ) { writer.write( """ @@ -166,6 +181,8 @@ fun renderJvmGradleBuild( #W } + #W + #W dependencies { @@ -196,6 +213,7 @@ fun renderJvmGradleBuild( """.trimIndent(), pluginsRenderer, { w: GradleWriter -> if (isRootModule) repositoryRenderer(w) }, + { w: GradleWriter -> if (generateServiceProject) applicationRenderer(w) }, { w: GradleWriter -> renderDependencies(w, scope = Scope.SOURCE, isKmp = false, dependencies = dependencies) }, annotationRenderer, { w: GradleWriter -> if (isRootModule) w.write("explicitApi()") }, @@ -245,6 +263,16 @@ private val repositoryRenderer: InlineCodeWriter = { ) } +private fun applicationRenderer(mainClass: String): InlineCodeWriter = { + write( + """ + application { + mainClass.set("$mainClass") + } + """.trimIndent(), + ) +} + class GradleWriter(parent: GradleWriter? = null) : AbstractCodeWriter() { init { trimBlankLines(parent?.trimBlankLines ?: 1) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/DefaultEndpointDiscovererGenerator.kt similarity index 55% rename from codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererGenerator.kt rename to codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/DefaultEndpointDiscovererGenerator.kt index 73b6e434da..2d21de0577 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/DefaultEndpointDiscovererGenerator.kt @@ -12,12 +12,11 @@ import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.buildSymbol import software.amazon.smithy.kotlin.codegen.model.expectShape import software.amazon.smithy.kotlin.codegen.model.expectTrait -import software.amazon.smithy.kotlin.codegen.rendering.endpoints.EndpointResolverAdapterGenerator import software.amazon.smithy.kotlin.codegen.rendering.endpoints.SdkEndpointBuiltinIntegration import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape -class EndpointDiscovererGenerator(private val ctx: CodegenContext, private val delegator: KotlinDelegator) { +class DefaultEndpointDiscovererGenerator(private val ctx: CodegenContext, private val delegator: KotlinDelegator) { private val symbol = symbolFor(ctx.settings) private val service = ctx.model.expectShape(ctx.settings.service) private val clientSymbol = ctx.symbolProvider.toSymbol(service) @@ -30,58 +29,47 @@ class EndpointDiscovererGenerator(private val ctx: CodegenContext, private val d companion object { fun symbolFor(settings: KotlinSettings): Symbol = buildSymbol { val clientName = clientName(settings.sdkId) - name = "${clientName}EndpointDiscoverer" + name = "Default${clientName}EndpointDiscoverer" namespace = "${settings.pkg.name}.endpoints" } } fun render() { delegator.applyFileWriter(symbol) { + val service = clientName(ctx.settings.sdkId) dokka( """ - A class which looks up specific endpoints for ${ctx.settings.sdkId} calls via the `$operationName` - API. These unique endpoints are cached as appropriate to avoid unnecessary latency in subsequent - calls. + A class which looks up specific endpoints for $service calls via the `$operationName` API. These + unique endpoints are cached as appropriate to avoid unnecessary latency in subsequent calls. + @param cache An [ExpiringKeyedCache] implementation used to cache discovered hosts """.trimIndent(), ) + withBlock( - "#L class #T {", + "#1L class #2T(#1L val cache: #3T = #5T(10.#6T)) : #7T {", "}", ctx.settings.api.visibility, symbol, + RuntimeTypes.Core.Collections.ExpiringKeyedCache, + RuntimeTypes.Core.Net.Host, + RuntimeTypes.Core.Collections.PeriodicSweepCache, + KotlinTypes.Time.minutes, + EndpointDiscovererInterfaceGenerator.symbolFor(ctx.settings), ) { - write( - "private val cache = #T(10.#T, #T.System)", - RuntimeTypes.Core.Collections.ReadThroughCache, - RuntimeTypes.Core.Net.Host, - KotlinTypes.Time.minutes, - RuntimeTypes.Core.Clock, - ) - write("") renderAsEndpointResolver() write("") - renderDiscoverHost() - write("") renderInvalidate() } - write("") - write( - """private val discoveryParamsKey = #T("DiscoveryParams")""", - RuntimeTypes.Core.Collections.AttributeKey, - ) - write("private data class DiscoveryParams(private val region: String?, private val identity: String)") } } private fun KotlinWriter.renderAsEndpointResolver() { withBlock( - "internal fun asEndpointResolver(client: #T, delegate: #T) = #T { request ->", + "override fun asEndpointResolver(client: #1T, delegate: #2T): #2T = #2T { request ->", "}", clientSymbol, - EndpointResolverAdapterGenerator.getSymbol(ctx.settings), RuntimeTypes.HttpClient.Operation.EndpointResolver, ) { - // Backported from https://github.com/smithy-lang/smithy-kotlin/pull/1221; replace when merging v1.5 to main withBlock("if (client.config.#L == null) {", "}", SdkEndpointBuiltinIntegration.EndpointUrlProp.propertyName) { write("val identity = request.identity") write( @@ -90,7 +78,7 @@ class EndpointDiscovererGenerator(private val ctx: CodegenContext, private val d ) write("") write("val cacheKey = DiscoveryParams(client.config.region, identity.accessKeyId)") - write("request.context[discoveryParamsKey] = cacheKey") + write("request.context[DiscoveryParamsKey] = cacheKey") write("val discoveredHost = cache.get(cacheKey) { discoverHost(client) }") write("") write("val originalEndpoint = delegate.resolve(request)") @@ -99,6 +87,7 @@ class EndpointDiscovererGenerator(private val ctx: CodegenContext, private val d write("originalEndpoint.headers,") write("originalEndpoint.attributes,") } + // If user manually specifies endpointUrl, skip endpoint discovery closeAndOpenBlock("} else {") write("delegate.resolve(request)") @@ -106,34 +95,9 @@ class EndpointDiscovererGenerator(private val ctx: CodegenContext, private val d } } - private fun KotlinWriter.renderDiscoverHost() { - openBlock( - "private suspend fun discoverHost(client: #T): #T<#T> =", - clientSymbol, - RuntimeTypes.Core.Utils.ExpiringValue, - RuntimeTypes.Core.Net.Host, - ) - // ASSUMPTION No services which use discovery include parameters to the EP operation (despite being - // possible according to the Smithy spec). - write("client.#L()", operationName) - indent() - write(".endpoints") - withBlock("?.map { ep -> #T(", ")}", RuntimeTypes.Core.Utils.ExpiringValue) { - write("#T.parse(ep.address!!),", RuntimeTypes.Core.Net.Host) - write("#T.now() + ep.cachePeriodInMinutes.#T,", RuntimeTypes.Core.Instant, KotlinTypes.Time.minutes) - } - write("?.firstOrNull()") - write( - """?: throw #T("Unable to discover any endpoints when invoking #L!")""", - RuntimeTypes.SmithyClient.Endpoints.EndpointProviderException, - operationName, - ) - dedent(2) - } - private fun KotlinWriter.renderInvalidate() { - withBlock("internal suspend fun invalidate(context: #T) {", "}", RuntimeTypes.Core.ExecutionContext) { - write("context.getOrNull(discoveryParamsKey)?.let { cache.invalidate(it) }") + withBlock("override public suspend fun invalidate(context: #T) {", "}", RuntimeTypes.Core.ExecutionContext) { + write("context.getOrNull(DiscoveryParamsKey)?.let { cache.invalidate(it) }") } } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererInterfaceGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererInterfaceGenerator.kt new file mode 100644 index 0000000000..ddc70752db --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererInterfaceGenerator.kt @@ -0,0 +1,94 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.rendering.endpoints.discovery + +import software.amazon.smithy.aws.traits.clientendpointdiscovery.ClientEndpointDiscoveryTrait +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.kotlin.codegen.KotlinSettings +import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes +import software.amazon.smithy.kotlin.codegen.model.buildSymbol +import software.amazon.smithy.kotlin.codegen.model.expectShape +import software.amazon.smithy.kotlin.codegen.model.expectTrait +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ServiceShape + +class EndpointDiscovererInterfaceGenerator(private val ctx: CodegenContext, private val delegator: KotlinDelegator) { + private val symbol = symbolFor(ctx.settings) + private val service = ctx.model.expectShape(ctx.settings.service) + private val clientSymbol = ctx.symbolProvider.toSymbol(service) + private val operationName = run { + val epDiscoveryTrait = service.expectTrait() + val operation = ctx.model.expectShape(epDiscoveryTrait.operation) + operation.defaultName() + } + + companion object { + fun symbolFor(settings: KotlinSettings): Symbol = buildSymbol { + val clientName = clientName(settings.sdkId) + name = "${clientName}EndpointDiscoverer" + namespace = "${settings.pkg.name}.endpoints" + } + } + + fun render() { + delegator.applyFileWriter(symbol) { + dokka("Represents the logic for automatically discovering endpoints for ${ctx.settings.sdkId} calls") + withBlock( + "#L interface #T {", + "}", + ctx.settings.api.visibility, + symbol, + ) { + write( + "#1L fun asEndpointResolver(client: #2T, delegate: #3T): #3T", + ctx.settings.api.visibility, + clientSymbol, + RuntimeTypes.HttpClient.Operation.EndpointResolver, + ) + write("") + renderDiscoverHost() + write("") + write("public suspend fun invalidate(context: #T)", RuntimeTypes.Core.ExecutionContext) + } + write("") + write( + "#L data class DiscoveryParams(private val region: String?, private val identity: String)", + ctx.settings.api.visibility, + ) + write( + """#1L val DiscoveryParamsKey: #2T = #2T("DiscoveryParams")""", + ctx.settings.api.visibility, + RuntimeTypes.Core.Collections.AttributeKey, + ) + } + } + + private fun KotlinWriter.renderDiscoverHost() { + openBlock( + "#L suspend fun discoverHost(client: #T): #T<#T> =", + ctx.settings.api.visibility, + clientSymbol, + RuntimeTypes.Core.Utils.ExpiringValue, + RuntimeTypes.Core.Net.Host, + ) + // ASSUMPTION No services which use discovery include parameters to the EP operation (despite being + // possible according to the Smithy spec). + write("client.#L()", operationName) + indent() + write(".endpoints") + withBlock("?.map { ep -> #T(", ")}", RuntimeTypes.Core.Utils.ExpiringValue) { + write("#T.parse(ep.address!!),", RuntimeTypes.Core.Net.Host) + write("#T.now() + ep.cachePeriodInMinutes.#T,", RuntimeTypes.Core.Instant, KotlinTypes.Time.minutes) + } + write("?.firstOrNull()") + write( + """?: throw #T("Unable to discover any endpoints when invoking #L!")""", + RuntimeTypes.SmithyClient.Endpoints.EndpointProviderException, + operationName, + ) + dedent(2) + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryIntegration.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryIntegration.kt index 62906af4d7..4e8a7446a1 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryIntegration.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryIntegration.kt @@ -22,26 +22,38 @@ import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape class EndpointDiscoveryIntegration : KotlinIntegration { - override fun additionalServiceConfigProps(ctx: CodegenContext): List { - val endpointDiscoveryOptional = ctx + companion object { + const val CLIENT_CONFIG_NAME = "endpointDiscoverer" + const val ORDER: Byte = 0 // doesn't depend on any other integrations + + fun isEnabledFor(model: Model, settings: KotlinSettings) = model + .expectShape(settings.service) + .hasTrait() + + fun isOptionalFor(ctx: CodegenContext) = ctx .model .operationShapes .none { it.getTrait()?.isRequired == true } - val discovererSymbol = EndpointDiscovererGenerator.symbolFor(ctx.settings) + } + + override fun additionalServiceConfigProps(ctx: CodegenContext): List { + val endpointDiscoveryOptional = isOptionalFor(ctx) + val interfaceSymbol = EndpointDiscovererInterfaceGenerator.symbolFor(ctx.settings) return super.additionalServiceConfigProps(ctx) + listOf( ConfigProperty { - name = "endpointDiscoverer" + name = CLIENT_CONFIG_NAME if (endpointDiscoveryOptional) { documentation = """ - The endpoint discoverer for this client, if applicable. By default, no endpoint - discovery is provided. To use endpoint discovery, set this to a valid - [${discovererSymbol.name}] instance. + The endpoint discoverer for this client, if applicable. By default, no endpoint discovery is + provided. To use endpoint discovery, set this to a valid [${interfaceSymbol.name}] instance. """.trimIndent() - symbol = discovererSymbol.asNullable() + symbol = interfaceSymbol.asNullable() } else { + val defaultImplSymbol = DefaultEndpointDiscovererGenerator.symbolFor(ctx.settings) documentation = "The endpoint discoverer for this client" - useSymbolWithNullableBuilder(discovererSymbol, "${discovererSymbol.name}()") + additionalImports = listOf(defaultImplSymbol) + useSymbolWithNullableBuilder(interfaceSymbol, "${defaultImplSymbol.name}()") } }, ) @@ -50,10 +62,11 @@ class EndpointDiscoveryIntegration : KotlinIntegration { override fun customizeMiddleware( ctx: ProtocolGenerator.GenerationContext, resolved: List, - ): List = super.customizeMiddleware(ctx, resolved) + listOf(DiscoveredEndpointMiddleware) + ): List = resolved + DiscoveredEndpointErrorMiddleware - override fun enabledForService(model: Model, settings: KotlinSettings): Boolean = - model.expectShape(settings.service).hasTrait() + override fun enabledForService(model: Model, settings: KotlinSettings): Boolean = isEnabledFor(model, settings) + + override val order = ORDER override val sectionWriters: List = listOf( SectionWriterBinding(HttpProtocolClientGenerator.EndpointResolverAdapterBinding, ::renderEndpointResolver), @@ -69,13 +82,15 @@ class EndpointDiscoveryIntegration : KotlinIntegration { null -> writer.write("#L", previousValue) true -> writer.write( - "execution.endpointResolver = config.endpointDiscoverer.asEndpointResolver(this@#L, #T(config))", + "execution.endpointResolver = config.#L.asEndpointResolver(this@#L, #T(config))", + CLIENT_CONFIG_NAME, defaultClientName, EndpointResolverAdapterGenerator.getSymbol(ctx.settings), ) false -> writer.write( - "execution.endpointResolver = config.endpointDiscoverer?.asEndpointResolver(this@#1L, #2T(config)) ?: #2T(config)", + "execution.endpointResolver = config.#1L?.asEndpointResolver(this@#2L, #3T(config)) ?: #3T(config)", + CLIENT_CONFIG_NAME, defaultClientName, EndpointResolverAdapterGenerator.getSymbol(ctx.settings), ) @@ -83,30 +98,29 @@ class EndpointDiscoveryIntegration : KotlinIntegration { } override fun writeAdditionalFiles(ctx: CodegenContext, delegator: KotlinDelegator) { - EndpointDiscovererGenerator(ctx, delegator).render() - super.writeAdditionalFiles(ctx, delegator) + EndpointDiscovererInterfaceGenerator(ctx, delegator).render() + + if (!isOptionalFor(ctx)) { + DefaultEndpointDiscovererGenerator(ctx, delegator).render() + } } } -private object DiscoveredEndpointMiddleware : ProtocolMiddleware { - override val name: String = "DiscoveredEndpointMiddleware" +private object DiscoveredEndpointErrorMiddleware : ProtocolMiddleware { + override val name: String = "DiscoveredEndpointErrorMiddleware" override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean = - op.getTrait()?.optionalError?.getOrNull() != null && + ctx.service.getTrait()?.optionalError?.getOrNull() != null && op.hasTrait() override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { - val interceptor = buildSymbol { - name = "DiscoveredEndpointErrorInterceptor" - namespace(KotlinDependency.HTTP_CLIENT, "aws.smithy.kotlin.runtime.http.interceptors") - } - val errorShapeId = ctx.service.expectTrait().optionalError.get() val errorShape = ctx.model.expectShape(errorShapeId) val errorSymbol = ctx.symbolProvider.toSymbol(errorShape) writer.write( - "config.endpointDiscoverer?.let { op.interceptors.add(#T(#T, it::invalidate)) }", - interceptor, + "config.#L?.let { op.interceptors.add(#T(#T::class, it::invalidate)) }", + EndpointDiscoveryIntegration.CLIENT_CONFIG_NAME, + RuntimeTypes.HttpClient.Interceptors.DiscoveredEndpointErrorInterceptor, errorSymbol, ) } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt index 0587f2ed24..4dae6e5c7e 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt @@ -120,6 +120,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { httpOperations.forEach { operation -> generateOperationSerializer(ctx, operation) } + + if (ctx.settings.build.generateServiceProject) { + val modeledErrors = httpOperations.flatMap { it.errors }.map { ctx.model.expectShape(it) as StructureShape }.toSet() + modeledErrors.forEach { generateExceptionSerializer(ctx, it) } + } } private fun generateDeserializers(ctx: ProtocolGenerator.GenerationContext) { @@ -131,15 +136,22 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } // generate HttpDeserialize for exception types - val modeledErrors = httpOperations.flatMap { it.errors }.map { ctx.model.expectShape(it) as StructureShape }.toSet() - modeledErrors.forEach { generateExceptionDeserializer(ctx, it) } + if (!ctx.settings.build.generateServiceProject) { + val modeledErrors = httpOperations.flatMap { it.errors }.map { ctx.model.expectShape(it) as StructureShape }.toSet() + modeledErrors.forEach { generateExceptionDeserializer(ctx, it) } + } } override fun generateProtocolClient(ctx: ProtocolGenerator.GenerationContext) { - val symbol = ctx.symbolProvider.toSymbol(ctx.service) - ctx.delegator.useFileWriter("Default${symbol.name}.kt", ctx.settings.pkg.name) { writer -> - val clientGenerator = getHttpProtocolClientGenerator(ctx) - clientGenerator.render(writer) + if (ctx.settings.build.generateServiceProject) { + require(protocolName in listOf("smithyRpcv2cbor", "awsRestjson1")) { "service project accepts only Cbor or JSON protocol" } + } + if (!ctx.settings.build.generateServiceProject) { + val symbol = ctx.symbolProvider.toSymbol(ctx.service) + ctx.delegator.useFileWriter("Default${symbol.name}.kt", ctx.settings.pkg.name) { writer -> + val clientGenerator = getHttpProtocolClientGenerator(ctx) + clientGenerator.render(writer) + } } generateSerializers(ctx) generateDeserializers(ctx) @@ -149,12 +161,17 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { * Generate request serializer (HttpSerialize) for an operation */ private fun generateOperationSerializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape) { - if (!op.input.isPresent) { + val serializationTarget = if (ctx.settings.build.generateServiceProject) { + op.output + } else { + op.input + } + if (!serializationTarget.isPresent) { return } - val inputShape = ctx.model.expectShape(op.input.get()) - val inputSymbol = ctx.symbolProvider.toSymbol(inputShape) + val serializationShape = ctx.model.expectShape(serializationTarget.get()) + val serializationSymbol = ctx.symbolProvider.toSymbol(serializationShape) // operation input shapes could be re-used across one or more operations. The protocol details may // be different though (e.g. uri/method). We need to generate a serializer/deserializer per/operation @@ -164,7 +181,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { name = op.serializerName() namespace = ctx.settings.pkg.serde - reference(inputSymbol, SymbolReference.ContextOption.DECLARE) + reference(serializationSymbol, SymbolReference.ContextOption.DECLARE) } val operationSerializerSymbols = setOf( RuntimeTypes.Http.HttpBody, @@ -175,27 +192,50 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { val serdeMeta = HttpSerdeMeta(op.isInputEventStream(ctx.model)) ctx.delegator.useSymbolWriter(serializerSymbol) { writer -> - writer - .addImport(operationSerializerSymbols) - .write("") - .openBlock("internal class #T: #T.#L<#T> {", serializerSymbol, RuntimeTypes.HttpClient.Operation.HttpSerializer, serdeMeta.variantName, inputSymbol) - .call { - val modifier = if (serdeMeta.isStreaming) "suspend " else "" - writer.openBlock( - "override #Lfun serialize(context: #T, input: #T): #T {", - modifier, - RuntimeTypes.Core.ExecutionContext, - inputSymbol, - RuntimeTypes.Http.Request.HttpRequestBuilder, - ) - .write("val builder = #T()", RuntimeTypes.Http.Request.HttpRequestBuilder) - .call { - renderHttpSerialize(ctx, op, writer) - } - .write("return builder") - .closeBlock("}") - } - .closeBlock("}") + if (ctx.settings.build.generateServiceProject) { + val serializerResultSymbol = getHttpSerializerResultSymbol(protocolName) + val defaultResponse = getHttpSerializerDefaultResponse(protocolName) + + writer + .openBlock("internal class #T {", serializerSymbol) + .call { + writer.openBlock( + "public fun serialize(context: #T, input: #T): #T {", + RuntimeTypes.Core.ExecutionContext, + serializationSymbol, + serializerResultSymbol, + ) + .write("var response: #T = $defaultResponse", serializerResultSymbol) + .call { + renderSerializeHttpBody(ctx, op, writer) + } + .write("return response") + .closeBlock("}") + } + .closeBlock("}") + } else { + writer + .addImport(operationSerializerSymbols) + .write("") + .openBlock("internal class #T: #T.#L<#T> {", serializerSymbol, RuntimeTypes.HttpClient.Operation.HttpSerializer, serdeMeta.variantName, serializationSymbol) + .call { + val modifier = if (serdeMeta.isStreaming) "suspend " else "" + writer.openBlock( + "override #Lfun serialize(context: #T, input: #T): #T {", + modifier, + RuntimeTypes.Core.ExecutionContext, + serializationSymbol, + RuntimeTypes.Http.Request.HttpRequestBuilder, + ) + .write("val builder = #T()", RuntimeTypes.Http.Request.HttpRequestBuilder) + .call { + renderHttpSerialize(ctx, op, writer) + } + .write("return builder") + .closeBlock("}") + } + .closeBlock("}") + } } } @@ -206,7 +246,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { ) { val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) val httpTrait = resolver.httpTrait(op) - val requestBindings = resolver.requestBindings(op) + val bindings = resolver.requestBindings(op) writer .addImport(RuntimeTypes.Core.ExecutionContext) @@ -218,17 +258,17 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { renderUri(ctx, op, writer) // Query Parameters - renderQueryParameters(ctx, httpTrait, requestBindings, writer) + renderQueryParameters(ctx, httpTrait, bindings, writer) } } .write("") .call { // headers - val headerBindings = requestBindings + val headerBindings = bindings .filter { it.location == HttpBinding.Location.HEADER } .sortedBy { it.memberName } - val prefixHeaderBindings = requestBindings + val prefixHeaderBindings = bindings .filter { it.location == HttpBinding.Location.PREFIX_HEADERS } if (headerBindings.isNotEmpty() || prefixHeaderBindings.isNotEmpty()) { @@ -256,6 +296,42 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } } + /** + * Generate HttpSerialize for a modeled error (exception) + */ + private fun generateExceptionSerializer(ctx: ProtocolGenerator.GenerationContext, shape: StructureShape) { + val serializationSymbol = ctx.symbolProvider.toSymbol(shape) + + val serializerSymbol = buildSymbol { + val deserializerName = "${serializationSymbol.name}Serializer" + definitionFile = "$deserializerName.kt" + name = deserializerName + namespace = ctx.settings.pkg.serde + reference(serializationSymbol, SymbolReference.ContextOption.DECLARE) + } + + ctx.delegator.useSymbolWriter(serializerSymbol) { writer -> + val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) + val bindings = resolver.responseBindings(shape) + val serializerResultSymbol = getHttpSerializerResultSymbol(protocolName) + val defaultResponse = getHttpSerializerDefaultResponse(protocolName) + writer.withBlock("internal class #T {", "}", serializerSymbol) { + writer.openBlock( + "public fun serialize(context: #T, input: #T): #T {", + RuntimeTypes.Core.ExecutionContext, + serializationSymbol, + serializerResultSymbol, + ) + .write("var response: #T = $defaultResponse", serializerResultSymbol) + .call { + renderExceptionSerializeBody(ctx, serializationSymbol, bindings, writer) + } + .write("return response") + .closeBlock("}") + } + } + } + /** * Calls the operation body serializer function and binds the results to `builder.body`. * By default if no members are bound to the body this function renders nothing. @@ -263,23 +339,57 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { */ protected open fun renderSerializeHttpBody(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) - if (!resolver.hasHttpBody(op)) return + if (ctx.settings.build.generateServiceProject) { + if (!resolver.hasHttpResponseBody(op)) return + } else { + if (!resolver.hasHttpRequestBody(op)) return + } // payload member(s) - val requestBindings = resolver.requestBindings(op) - val httpPayload = requestBindings.firstOrNull { it.location == HttpBinding.Location.PAYLOAD } + val bindings = if (ctx.settings.build.generateServiceProject) { + resolver.responseBindings(op) + } else { + resolver.requestBindings(op) + } + val httpPayload = bindings.firstOrNull { it.location == HttpBinding.Location.PAYLOAD } if (httpPayload != null) { renderExplicitHttpPayloadSerializer(ctx, httpPayload, writer) } else { - val documentMembers = requestBindings.filterDocumentBoundMembers() + val documentMembers = bindings.filterDocumentBoundMembers() // Unbound document members that should be serialized into the document format for the protocol. // delegate to the generate operation body serializer function val sdg = structuredDataSerializer(ctx) val opBodySerializerFn = sdg.operationSerializer(ctx, op, documentMembers) writer.write("val payload = #T(context, input)", opBodySerializerFn) - writer.write("builder.body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = payload.decodeToString()") + } else { + writer.write("builder.body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + } + } + if (!ctx.settings.build.generateServiceProject) { + renderContentTypeHeader(ctx, op, writer, resolver) } - renderContentTypeHeader(ctx, op, writer, resolver) + } + + /** + * Calls the operation body serializer function and binds the results to `builder.body`. + * By default if no members are bound to the body this function renders nothing. + * If there is a payload to render it should be bound to `builder.body` when this function returns + */ + protected open fun renderExceptionSerializeBody( + ctx: ProtocolGenerator.GenerationContext, + deserializationSymbol: Symbol, + bindings: List, + writer: KotlinWriter, + ) { + val documentMembers = bindings.filterDocumentBoundMembers() + // Unbound document members that should be serialized into the document format for the protocol. + // delegate to the generate operation body serializer function + val sdg = structuredDataSerializer(ctx) + val exceptionBodySerializerFn = sdg.errorSerializer(ctx, deserializationSymbol.shape as StructureShape, documentMembers) + writer.write("val payload = #T(context, input)", exceptionBodySerializerFn) + writer.write("response = payload") } protected open fun renderContentTypeHeader( @@ -303,24 +413,28 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { ) { val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) val httpTrait = resolver.httpTrait(op) - val requestBindings = resolver.requestBindings(op) - val pathBindings = requestBindings.filter { it.location == HttpBinding.Location.LABEL } + val bindings = if (ctx.settings.build.generateServiceProject) { + resolver.responseBindings(op) + } else { + resolver.requestBindings(op) + } + val pathBindings = bindings.filter { it.location == HttpBinding.Location.LABEL } if (pathBindings.isNotEmpty()) { // One of the few places we generate client side validation // httpLabel bindings must be non-null httpTrait.uri.segments.filter { it.isLabel || it.isGreedyLabel }.forEach { segment -> - val binding = pathBindings.find { + val bindings = pathBindings.find { it.memberName == segment.content } ?: throw CodegenException("failed to find corresponding member for httpLabel `${segment.content}`") - val memberSymbol = ctx.symbolProvider.toSymbol(binding.member) + val memberSymbol = ctx.symbolProvider.toSymbol(bindings.member) if (memberSymbol.isNullable) { - writer.write("""requireNotNull(input.#1L) { "#1L is bound to the URI and must not be null" }""", binding.member.defaultName()) + writer.write("""requireNotNull(input.#1L) { "#1L is bound to the URI and must not be null" }""", bindings.member.defaultName()) } // check length trait if applicable - renderNonBlankGuard(ctx, binding.member, writer) + renderNonBlankGuard(ctx, bindings.member, writer) } if (httpTrait.uri.segments.isNotEmpty()) { @@ -328,37 +442,37 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { httpTrait.uri.segments.forEach { segment -> if (segment.isLabel || segment.isGreedyLabel) { // spec dictates member name and label name MUST be the same - val binding = pathBindings.find { binding -> + val bindings = pathBindings.find { binding -> binding.memberName == segment.content } ?: throw CodegenException("failed to find corresponding member for httpLabel `${segment.content}") // shape must be string, number, boolean, or timestamp - val targetShape = ctx.model.expectShape(binding.member.target) - val memberSymbol = ctx.symbolProvider.toSymbol(binding.member) + val targetShape = ctx.model.expectShape(bindings.member.target) + val memberSymbol = ctx.symbolProvider.toSymbol(bindings.member) val identifier = when { targetShape.isTimestampShape -> { addImport(RuntimeTypes.Core.TimestampFormat) val tsFormat = resolver.determineTimestampFormat( - binding.member, + bindings.member, HttpBinding.Location.LABEL, defaultTimestampFormat, ) val nullCheck = if (memberSymbol.isNullable) "?" else "" val tsLabel = formatInstant( - "input.${binding.member.defaultName()}$nullCheck", + "input.${bindings.member.defaultName()}$nullCheck", tsFormat, forceString = true, ) tsLabel } - targetShape.isStringEnumShape -> "input.${binding.member.defaultName()}.value" - targetShape.isIntEnumShape -> "input.${binding.member.defaultName()}.value.toString()" + targetShape.isStringEnumShape -> "input.${bindings.member.defaultName()}.value" + targetShape.isIntEnumShape -> "input.${bindings.member.defaultName()}.value.toString()" - targetShape.isStringShape -> "input.${binding.member.defaultName()}" + targetShape.isStringShape -> "input.${bindings.member.defaultName()}" - else -> "input.${binding.member.defaultName()}.toString()" + else -> "input.${bindings.member.defaultName()}.toString()" } val encodeFn = @@ -397,17 +511,17 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { private fun renderQueryParameters( ctx: ProtocolGenerator.GenerationContext, httpTrait: HttpTrait, - requestBindings: List, + bindings: List, writer: KotlinWriter, ) { // literals in the URI val queryLiterals = httpTrait.uri.queryLiterals // shape bindings - val queryBindings = requestBindings.filter { it.location == HttpBinding.Location.QUERY } + val queryBindings = bindings.filter { it.location == HttpBinding.Location.QUERY } // maps bound via httpQueryParams trait - val queryMapBindings = requestBindings.filter { it.location == HttpBinding.Location.QUERY_PARAMS } + val queryMapBindings = bindings.filter { it.location == HttpBinding.Location.QUERY_PARAMS } if (queryBindings.isEmpty() && queryLiterals.isEmpty() && queryMapBindings.isEmpty()) return @@ -493,7 +607,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { if (isBinaryStream) { writer.write("builder.body = input.#L.#T()", memberName, RuntimeTypes.Http.toHttpBody) } else { - writer.write("builder.body = #T.fromBytes(input.#L)", RuntimeTypes.Http.HttpBody, memberName) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = input.#L.decodeToString()", memberName) + } else { + writer.write("builder.body = #T.fromBytes(input.#L)", RuntimeTypes.Http.HttpBody, memberName) + } } } @@ -503,29 +621,46 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } else { memberName } - writer.write("builder.body = #T.fromBytes(input.#L.#T())", RuntimeTypes.Http.HttpBody, contents, KotlinTypes.Text.encodeToByteArray) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = input.#L", contents) + } else { + writer.write("builder.body = #T.fromBytes(input.#L.#T())", RuntimeTypes.Http.HttpBody, contents, KotlinTypes.Text.encodeToByteArray) + } } ShapeType.ENUM -> - writer.write( - "builder.body = #T.fromBytes(input.#L.value.#T())", - RuntimeTypes.Http.HttpBody, - memberName, - KotlinTypes.Text.encodeToByteArray, - ) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = input.#L.value.toString()", memberName) + } else { + writer.write( + "builder.body = #T.fromBytes(input.#L.value.#T())", + RuntimeTypes.Http.HttpBody, + memberName, + KotlinTypes.Text.encodeToByteArray, + ) + } + ShapeType.INT_ENUM -> - writer.write( - "builder.body = #T.fromBytes(input.#L.value.toString().#T())", - RuntimeTypes.Http.HttpBody, - memberName, - KotlinTypes.Text.encodeToByteArray, - ) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = input.#L.value.toString()", memberName) + } else { + writer.write( + "builder.body = #T.fromBytes(input.#L.value.toString().#T())", + RuntimeTypes.Http.HttpBody, + memberName, + KotlinTypes.Text.encodeToByteArray, + ) + } ShapeType.STRUCTURE, ShapeType.UNION, ShapeType.DOCUMENT -> { val sdg = structuredDataSerializer(ctx) val payloadSerializerFn = sdg.payloadSerializer(ctx, binding.member) writer.write("val payload = #T(input.#L)", payloadSerializerFn, memberName) - writer.write("builder.body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = payload.decodeToString()") + } else { + writer.write("builder.body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + } } else -> @@ -538,12 +673,16 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { * Generate request deserializer (HttpDeserialize) for an operation */ private fun generateOperationDeserializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape) { - if (!op.output.isPresent) { + val deserializationBindings = if (ctx.settings.build.generateServiceProject) { + op.input + } else { + op.output + } + if (!deserializationBindings.isPresent) { return } - - val outputShape = ctx.model.expectShape(op.output.get()) - val outputSymbol = ctx.symbolProvider.toSymbol(outputShape) + val deserializationShape = ctx.model.expectShape(deserializationBindings.get()) + val deserializationSymbol = ctx.symbolProvider.toSymbol(deserializationShape) // operation output shapes could be re-used across one or more operations. The protocol details may // be different though (e.g. uri/method). We need to generate a serializer/deserializer per/operation @@ -554,29 +693,44 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { name = op.deserializerName() namespace = ctx.settings.pkg.serde - reference(outputSymbol, SymbolReference.ContextOption.DECLARE) + reference(deserializationSymbol, SymbolReference.ContextOption.DECLARE) } val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) - val responseBindings = resolver.responseBindings(op) + val bindings = if (ctx.settings.build.generateServiceProject) { + resolver.requestBindings(op) + } else { + resolver.responseBindings(op) + } val serdeMeta = httpDeserializerInfo(ctx, op) ctx.delegator.useSymbolWriter(deserializerSymbol) { writer -> - writer - .write("") - .openBlock( - "internal class #T: #T.#L<#T> {", - deserializerSymbol, - RuntimeTypes.HttpClient.Operation.HttpDeserializer, - serdeMeta.variantName, - outputSymbol, - ) - .write("") - .call { - renderHttpDeserialize(ctx, outputSymbol, responseBindings, serdeMeta, op, writer) - } - .closeBlock("}") + when (ctx.settings.build.generateServiceProject) { + true -> + writer + .write("") + .openBlock( + "internal class #T {", + deserializerSymbol, + ) + .write("") + .call { renderServiceHttpDeserialize(ctx, deserializationSymbol, bindings, serdeMeta, op, writer) } + .closeBlock("}") + false -> + writer + .write("") + .openBlock( + "internal class #T: #T.#L<#T> {", + deserializerSymbol, + RuntimeTypes.HttpClient.Operation.HttpDeserializer, + serdeMeta.variantName, + deserializationSymbol, + ) + .write("") + .call { renderHttpDeserialize(ctx, deserializationSymbol, bindings, serdeMeta, op, writer) } + .closeBlock("}") + } } } @@ -599,7 +753,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { * Generate HttpDeserialize for a modeled error (exception) */ private fun generateExceptionDeserializer(ctx: ProtocolGenerator.GenerationContext, shape: StructureShape) { - val outputSymbol = ctx.symbolProvider.toSymbol(shape) + val deserializationSymbol = ctx.symbolProvider.toSymbol(shape) val exceptionDeserializerSymbols = setOf( RuntimeTypes.Core.ExecutionContext, RuntimeTypes.Http.Response.HttpResponse, @@ -611,11 +765,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { ) val deserializerSymbol = buildSymbol { - val deserializerName = "${outputSymbol.name}Deserializer" + val deserializerName = "${deserializationSymbol.name}Deserializer" definitionFile = "$deserializerName.kt" name = deserializerName namespace = ctx.settings.pkg.serde - reference(outputSymbol, SymbolReference.ContextOption.DECLARE) + reference(deserializationSymbol, SymbolReference.ContextOption.DECLARE) } // exception deserializers are never streaming @@ -623,23 +777,25 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { ctx.delegator.useSymbolWriter(deserializerSymbol) { writer -> val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) - val responseBindings = resolver.responseBindings(shape) + val bindings = if (ctx.settings.build.generateServiceProject) { + resolver.requestBindings(shape) + } else { + resolver.responseBindings(shape) + } writer .addImport(exceptionDeserializerSymbols) .write("") - .openBlock("internal class #T: #T.NonStreaming<#T> {", deserializerSymbol, RuntimeTypes.HttpClient.Operation.HttpDeserializer, outputSymbol) + .openBlock("internal class #T: #T.NonStreaming<#T> {", deserializerSymbol, RuntimeTypes.HttpClient.Operation.HttpDeserializer, deserializationSymbol) .write("") - .call { - renderHttpDeserialize(ctx, outputSymbol, responseBindings, serdeMeta, null, writer) - } + .call { renderHttpDeserialize(ctx, deserializationSymbol, bindings, serdeMeta, null, writer) } .closeBlock("}") } } private fun renderHttpDeserialize( ctx: ProtocolGenerator.GenerationContext, - outputSymbol: Symbol, - responseBindings: List, + deserializationSymbol: Symbol, + bindings: List, serdeMeta: HttpSerdeMeta, // this method is shared between operation and exception deserialization. In the case of operations this MUST be set op: OperationShape?, @@ -651,7 +807,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { "override suspend fun deserialize(context: #T, call: #T): #T {", RuntimeTypes.Core.ExecutionContext, RuntimeTypes.Http.HttpCall, - outputSymbol, + deserializationSymbol, ) } else { writer @@ -660,22 +816,22 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { RuntimeTypes.Core.ExecutionContext, RuntimeTypes.Http.HttpCall, KotlinTypes.ByteArray, - outputSymbol, + deserializationSymbol, ) } writer.write("val response = call.response") .call { - if (outputSymbol.shape?.isError == false && op != null) { + if (deserializationSymbol.shape?.isError == false && op != null) { // handle operation errors renderIsHttpError(ctx, op, writer) } } - .write("val builder = #T.Builder()", outputSymbol) + .write("val builder = #T.Builder()", deserializationSymbol) .write("") .call { // headers - val headerBindings = responseBindings + val headerBindings = bindings .filter { it.location == HttpBinding.Location.HEADER } .sortedBy { it.memberName } @@ -683,7 +839,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { // prefix headers // spec: "Only a single structure member can be bound to httpPrefixHeaders" - responseBindings.firstOrNull { it.location == HttpBinding.Location.PREFIX_HEADERS } + bindings.firstOrNull { it.location == HttpBinding.Location.PREFIX_HEADERS } ?.let { renderDeserializePrefixHeaders(ctx, it, writer) } @@ -693,11 +849,70 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { if (op != null && op.isOutputEventStream(ctx.model)) { deserializeViaEventStream(ctx, op, writer) } else { - deserializeViaPayload(ctx, outputSymbol, responseBindings, serdeMeta, op, writer) + deserializeViaPayload(ctx, deserializationSymbol, bindings, serdeMeta, op, writer) } } .call { - responseBindings.firstOrNull { it.location == HttpBinding.Location.RESPONSE_CODE } + bindings.firstOrNull { it.location == HttpBinding.Location.RESPONSE_CODE } + ?.let { + renderDeserializeResponseCode(ctx, it, writer) + } + } + // Render client side error correction for `@required` members. + // NOTE: nested members bound via the document/payload will be handled by the deserializer for the relevant + // content type. All other members (e.g. bound via REST semantics) will be corrected here. + .write("builder.correctErrors()") + .write("return builder.build()") + .closeBlock("}") + } + + private fun renderServiceHttpDeserialize( + ctx: ProtocolGenerator.GenerationContext, + deserializationSymbol: Symbol, + bindings: List, + serdeMeta: HttpSerdeMeta, + // this method is shared between operation and exception deserialization. In the case of operations this MUST be set + op: OperationShape?, + writer: KotlinWriter, + ) { + writer + .openBlock( + "public fun deserialize(context: #T, call: #T, payload: #T?): #T {", + RuntimeTypes.Core.ExecutionContext, + RuntimeTypes.KtorServerCore.ApplicationCallClass, + KotlinTypes.ByteArray, + deserializationSymbol, + ) + + writer.write("val request = call.request") + .write("val builder = #T.Builder()", deserializationSymbol) + .write("") + .call { + // headers + val headerBindings = bindings + .filter { it.location == HttpBinding.Location.HEADER } + .sortedBy { it.memberName } + + renderDeserializeHeaders(ctx, headerBindings, writer) + + // prefix headers + // spec: "Only a single structure member can be bound to httpPrefixHeaders" + bindings.firstOrNull { it.location == HttpBinding.Location.PREFIX_HEADERS } + ?.let { + renderDeserializePrefixHeaders(ctx, it, writer) + } + } + .write("") + .call { + // TODO: will never enter this block. event stream is not in the scope of service generation for now. + if (op != null && op.isOutputEventStream(ctx.model)) { + deserializeViaEventStream(ctx, op, writer) + } else { + deserializeViaPayload(ctx, deserializationSymbol, bindings, serdeMeta, op, writer) + } + } + .call { + bindings.firstOrNull { it.location == HttpBinding.Location.RESPONSE_CODE } ?.let { renderDeserializeResponseCode(ctx, it, writer) } @@ -715,20 +930,20 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { */ private fun deserializeViaPayload( ctx: ProtocolGenerator.GenerationContext, - outputSymbol: Symbol, - responseBindings: List, + deserializationSymbol: Symbol, + bindings: List, serdeMeta: HttpSerdeMeta, // this method is shared between operation and exception deserialization. In the case of operations this MUST be set op: OperationShape?, writer: KotlinWriter, ) { // payload member(s) - val httpPayload = responseBindings.firstOrNull { it.location == HttpBinding.Location.PAYLOAD } + val httpPayload = bindings.firstOrNull { it.location == HttpBinding.Location.PAYLOAD } if (httpPayload != null) { renderExplicitHttpPayloadDeserializer(ctx, httpPayload, writer) } else { // Unbound document members that should be deserialized from the document format for the protocol. - val documentMembers = responseBindings + val documentMembers = bindings .filter { it.location == HttpBinding.Location.DOCUMENT } .sortedBy { it.memberName } .map { it.member } @@ -741,7 +956,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { sdg.operationDeserializer(ctx, op, documentMembers) } else { // error - sdg.errorDeserializer(ctx, outputSymbol.shape as StructureShape, documentMembers) + sdg.errorDeserializer(ctx, deserializationSymbol.shape as StructureShape, documentMembers) } if (!serdeMeta.isStreaming) { @@ -770,7 +985,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { val memberTarget = ctx.model.expectShape(binding.member.target) check(memberTarget.type == ShapeType.INTEGER) { "Unexpected target type in response code deserialization: ${memberTarget.id} (${memberTarget.type})" } - writer.write("builder.#L = response.status.value", memberName) } @@ -794,21 +1008,25 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } else { "" } - + val message = if (ctx.settings.build.generateServiceProject) { + "request" + } else { + "response" + } when (memberTarget) { is NumberShape -> { if (memberTarget is IntEnumShape) { val enumSymbol = ctx.symbolProvider.toSymbol(memberTarget) writer.addImport(enumSymbol) writer.write( - "builder.#L = response.headers[#S]?.let { #T.fromValue(it.toInt()) }", + "builder.#L = $message.headers[#S]?.let { #T.fromValue(it.toInt()) }", memberName, headerName, enumSymbol, ) } else { writer.write( - "builder.#L = response.headers[#S]?.#L$defaultValuePostfix", + "builder.#L = $message.headers[#S]?.#L$defaultValuePostfix", memberName, headerName, stringToNumber(memberTarget), @@ -817,13 +1035,13 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } is BooleanShape -> { writer.write( - "builder.#L = response.headers[#S]?.toBoolean()$defaultValuePostfix", + "builder.#L = $message.headers[#S]?.toBoolean()$defaultValuePostfix", memberName, headerName, ) } is BlobShape -> { - writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Text.Encoding.decodeBase64) + writer.write("builder.#L = $message.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Text.Encoding.decodeBase64) } is StringShape -> { when { @@ -831,17 +1049,17 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { val enumSymbol = ctx.symbolProvider.toSymbol(memberTarget) writer.addImport(enumSymbol) writer.write( - "builder.#L = response.headers[#S]?.let { #T.fromValue(it) }", + "builder.#L = $message.headers[#S]?.let { #T.fromValue(it) }", memberName, headerName, enumSymbol, ) } memberTarget.hasTrait() -> { - writer.write("builder.#L = response.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Text.Encoding.decodeBase64) + writer.write("builder.#L = $message.headers[#S]?.#T()", memberName, headerName, RuntimeTypes.Core.Text.Encoding.decodeBase64) } else -> { - writer.write("builder.#L = response.headers[#S]", memberName, headerName) + writer.write("builder.#L = $message.headers[#S]", memberName, headerName) } } } @@ -852,7 +1070,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { defaultTimestampFormat, ) writer.write( - "builder.#L = response.headers[#S]?.let { #L }", + "builder.#L = $message.headers[#S]?.let { #L }", memberName, headerName, writer.parseInstantExpr("it", tsFormat), @@ -912,7 +1130,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { writer .addImport(splitFn, KotlinDependency.HTTP, subpackage = "util") - .write("builder.#L = response.headers.getAll(#S)?.flatMap(::$splitFn)$mapFn", memberName, headerName) + .write("builder.#L = $message.headers.getAll(#S)?.flatMap(::$splitFn)$mapFn", memberName, headerName) } else -> throw CodegenException("unknown deserialization: header binding: $hdrBinding; member: `$memberName`") } @@ -941,7 +1159,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { "" } - writer.write("val $keyCollName = response.headers.names()$filter") + val message = if (ctx.settings.build.generateServiceProject) { + "request" + } else { + "response" + } + writer.write("val $keyCollName = $message.headers.names()$filter") writer.openBlock("if ($keyCollName.isNotEmpty()) {") .write("val map = mutableMapOf()", targetValueSymbol) .openBlock("for (hdrKey in $keyCollName) {") @@ -952,7 +1175,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { else -> throw CodegenException("invalid httpPrefixHeaders usage on ${binding.member}") } // get()/getAll() returns String? or List?, this shouldn't ever trigger the continue though... - writer.write("val el = response.headers$getFn ?: continue") + writer.write("val el = $message.headers$getFn ?: continue") if (prefix?.isNotEmpty() == true) { writer.write("val key = hdrKey.removePrefix(#S)", prefix) writer.write("map[key] = el") @@ -977,7 +1200,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { val memberName = binding.member.defaultName() val target = ctx.model.expectShape(binding.member.target) val targetSymbol = ctx.symbolProvider.toSymbol(target) - + val message = if (ctx.settings.build.generateServiceProject) { + "request" + } else { + "response" + } // NOTE: we don't need serde metadata to know what to do here. Everything is non-streaming except streaming // blob payloads. when (target.type) { @@ -1003,7 +1230,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { ShapeType.BLOB -> { val isBinaryStream = target.hasTrait() if (isBinaryStream) { - writer.write("builder.#L = response.body.#T()", memberName, RuntimeTypes.Http.toByteStream) + writer.write("builder.#L = $message.body.#T()", memberName, RuntimeTypes.Http.toByteStream) } else { writer.write("builder.#L = payload", memberName) } @@ -1107,8 +1334,25 @@ private data class HttpSerdeMeta(val isStreaming: Boolean) { } private fun httpDeserializerInfo(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): HttpSerdeMeta { - val isStreaming = ctx.model.expectShape(op.output.get()).hasStreamingMember(ctx.model) || + val deserializationTarget = if (ctx.settings.build.generateServiceProject) { + op.input + } else { + op.output + } + val isStreaming = ctx.model.expectShape(deserializationTarget.get()).hasStreamingMember(ctx.model) || op.isOutputEventStream(ctx.model) return HttpSerdeMeta(isStreaming) } + +private fun getHttpSerializerResultSymbol(protocolName: String) = when (protocolName) { + "smithyRpcv2cbor" -> KotlinTypes.ByteArray + "awsRestjson1" -> KotlinTypes.String + else -> error("service project accepts only Cbor or JSON protocol") +} + +private fun getHttpSerializerDefaultResponse(protocolName: String) = when (protocolName) { + "smithyRpcv2cbor" -> "ByteArray(0)" + "awsRestjson1" -> "\"\"" + else -> error("service project accepts only Cbor or JSON protocol") +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingResolver.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingResolver.kt index f6cfad41b9..7e2345f836 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingResolver.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingResolver.kt @@ -53,10 +53,10 @@ interface HttpBindingResolver { /** * Return request bindings for an operation. - * @param operationShape [OperationShape] for which to retrieve bindings + * @param shape [Shape] for which to retrieve bindings * @return all found http request bindings */ - fun requestBindings(operationShape: OperationShape): List + fun requestBindings(shape: Shape): List /** * Return response bindings for an operation. @@ -90,11 +90,19 @@ interface HttpBindingResolver { /** * @return true if the operation contains request data bound to the PAYLOAD or DOCUMENT locations */ -fun HttpBindingResolver.hasHttpBody(operationShape: OperationShape): Boolean = +fun HttpBindingResolver.hasHttpRequestBody(operationShape: OperationShape): Boolean = requestBindings(operationShape).any { it.location == HttpBinding.Location.PAYLOAD || it.location == HttpBinding.Location.DOCUMENT } +/** + * @return true if the operation contains request data bound to the PAYLOAD or DOCUMENT locations + */ +fun HttpBindingResolver.hasHttpResponseBody(operationShape: OperationShape): Boolean = + responseBindings(operationShape).any { + it.location == HttpBinding.Location.PAYLOAD || it.location == HttpBinding.Location.DOCUMENT + } + /** * Protocol content type mappings */ @@ -141,10 +149,12 @@ class HttpTraitResolver( override fun httpTrait(operationShape: OperationShape): HttpTrait = operationShape.expectTrait() - override fun requestBindings(operationShape: OperationShape): List = bindingIndex - .getRequestBindings(operationShape) - .values - .map { HttpBindingDescriptor(it) } + override fun requestBindings(shape: Shape): List = when (shape) { + is OperationShape, + is StructureShape, + -> bindingIndex.getRequestBindings(shape.toShapeId()).values.map { HttpBindingDescriptor(it) } + else -> error { "Unimplemented resolving bindings for ${shape.javaClass.canonicalName}" } + } override fun responseBindings(shape: Shape): List = when (shape) { is OperationShape, diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborParserGenerator.kt index ac0518edd7..ba2fb9105e 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborParserGenerator.kt @@ -24,12 +24,17 @@ class CborParserGenerator( op: OperationShape, members: List, ): Symbol { - val outputSymbol = ctx.symbolProvider.toSymbol(ctx.model.expectShape(op.outputShape)) + val deserializationShape = if (ctx.settings.build.generateServiceProject) { + op.inputShape + } else { + op.outputShape + } + val deserializationsymbol = ctx.symbolProvider.toSymbol(ctx.model.expectShape(deserializationShape)) return op.bodyDeserializer(ctx.settings) { writer -> addNestedDocumentDeserializers(ctx, op, writer) val fnName = op.bodyDeserializerName() - writer.withBlock("private fun #L(builder: #T.Builder, payload: ByteArray) {", "}", fnName, outputSymbol) { + writer.withBlock("private fun #L(builder: #T.Builder, payload: ByteArray) {", "}", fnName, deserializationsymbol) { call { renderDeserializeOperationBody(ctx, op, members, writer) } } } @@ -91,9 +96,13 @@ class CborParserGenerator( writer: KotlinWriter, ) { writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeCbor.CborDeserializer) - - val shape = ctx.model.expectShape(op.output.get()) - renderDeserializerBody(ctx, shape, documentMembers, writer) + val deserializationTarget = if (ctx.settings.build.generateServiceProject) { + op.input + } else { + op.output + } + val deserializationShape = ctx.model.expectShape(deserializationTarget.get()) + renderDeserializerBody(ctx, deserializationShape, documentMembers, writer) } private fun renderDeserializerBody( @@ -103,7 +112,6 @@ class CborParserGenerator( writer: KotlinWriter, ) { descriptorGenerator(ctx, shape, members, writer).render() - if (shape.isUnionShape) { val name = ctx.symbolProvider.toSymbol(shape).name DeserializeUnionGenerator(ctx, name, members, writer, TimestampFormatTrait.Format.EPOCH_SECONDS).render() diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt index 545a484a90..a4467e10e3 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.SymbolReference import software.amazon.smithy.kotlin.codegen.core.KotlinWriter import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.knowledge.SerdeIndex import software.amazon.smithy.kotlin.codegen.model.targetOrSelf import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator @@ -20,12 +21,28 @@ class CborSerializerGenerator( private val protocolGenerator: ProtocolGenerator, ) : StructuredDataSerializerGenerator { override fun operationSerializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, members: List): Symbol { - val input = op.input.get().let { ctx.model.expectShape(it) } - val symbol = ctx.symbolProvider.toSymbol(input) + val serializationTarget = if (ctx.settings.build.generateServiceProject) { + op.output + } else { + op.input + } + val serializationShape = serializationTarget.get().let { ctx.model.expectShape(it) } + val serializationSymbol = ctx.symbolProvider.toSymbol(serializationShape) + val serializerResultSymbol = when { + ctx.settings.build.generateServiceProject -> KotlinTypes.ByteArray + else -> RuntimeTypes.Http.HttpBody + } return op.bodySerializer(ctx.settings) { writer -> addNestedDocumentSerializers(ctx, op, writer) - writer.withBlock("private fun #L(context: #T, input: #T): #T {", "}", op.bodySerializerName(), RuntimeTypes.Core.ExecutionContext, symbol, RuntimeTypes.Http.HttpBody) { + writer.withBlock( + "private fun #L(context: #T, input: #T): #T {", + "}", + op.bodySerializerName(), + RuntimeTypes.Core.ExecutionContext, + serializationSymbol, + serializerResultSymbol, + ) { call { renderSerializeOperationBody(ctx, op, members, writer) } @@ -39,10 +56,19 @@ class CborSerializerGenerator( documentMembers: List, writer: KotlinWriter, ) { - val shape = ctx.model.expectShape(op.input.get()) + val serializationTarget = if (ctx.settings.build.generateServiceProject) { + op.output + } else { + op.input + } + val serializationshape = ctx.model.expectShape(serializationTarget.get()) writer.write("val serializer = #T()", RuntimeTypes.Serde.SerdeCbor.CborSerializer) - renderSerializerBody(ctx, shape, documentMembers, writer) - writer.write("return serializer.toHttpBody()") + renderSerializerBody(ctx, serializationshape, documentMembers, writer) + if (ctx.settings.build.generateServiceProject) { + writer.write("return serializer.toByteArray()") + } else { + writer.write("return serializer.toHttpBody()") + } } private fun renderSerializerBody( @@ -112,4 +138,24 @@ class CborSerializerGenerator( } } } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + val symbol = ctx.symbolProvider.toSymbol(errorShape) + + return symbol.errorSerializer(ctx.settings) { writer -> + addNestedDocumentSerializers(ctx, errorShape, writer) + val fnName = symbol.errorSerializerName() + writer.openBlock("private fun #L(context: #T, input: #T): ByteArray {", fnName, RuntimeTypes.Core.ExecutionContext, symbol) + .write("val serializer = #T()", RuntimeTypes.Serde.SerdeCbor.CborSerializer) + .call { + renderSerializerBody(ctx, errorShape, members, writer) + } + .write("return serializer.toByteArray()") + .closeBlock("}") + } + } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt index 4970e55f91..a98bb369e6 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt @@ -36,7 +36,12 @@ open class JsonParserGenerator( ) override fun operationDeserializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, members: List): Symbol { - val outputSymbol = op.output.get().let { ctx.symbolProvider.toSymbol(ctx.model.expectShape(it)) } + val deserializationTarget = if (ctx.settings.build.generateServiceProject) { + op.input + } else { + op.output + } + val outputSymbol = deserializationTarget.get().let { ctx.symbolProvider.toSymbol(ctx.model.expectShape(it)) } return op.bodyDeserializer(ctx.settings) { writer -> addNestedDocumentDeserializers(ctx, op, writer) val fnName = op.bodyDeserializerName() @@ -74,8 +79,13 @@ open class JsonParserGenerator( documentMembers: List, writer: KotlinWriter, ) { + val deserializationTarget = if (ctx.settings.build.generateServiceProject) { + op.input + } else { + op.output + } writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeJson.JsonDeserializer) - val shape = ctx.model.expectShape(op.output.get()) + val shape = ctx.model.expectShape(deserializationTarget.get()) renderDeserializerBody(ctx, shape, documentMembers, writer) } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonSerializerGenerator.kt index 6366fe8e04..62a78064b3 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonSerializerGenerator.kt @@ -26,8 +26,13 @@ open class JsonSerializerGenerator( open val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS override fun operationSerializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, members: List): Symbol { - val input = ctx.model.expectShape(op.input.get()) - val symbol = ctx.symbolProvider.toSymbol(input) + val serializationTarget = if (ctx.settings.build.generateServiceProject) { + op.output + } else { + op.input + } + val shape = ctx.model.expectShape(serializationTarget.get()) + val symbol = ctx.symbolProvider.toSymbol(shape) return op.bodySerializer(ctx.settings) { writer -> addNestedDocumentSerializers(ctx, op, writer) @@ -61,7 +66,12 @@ open class JsonSerializerGenerator( documentMembers: List, writer: KotlinWriter, ) { - val shape = ctx.model.expectShape(op.input.get()) + val serializationTarget = if (ctx.settings.build.generateServiceProject) { + op.output + } else { + op.input + } + val shape = ctx.model.expectShape(serializationTarget.get()) writer.write("val serializer = #T()", RuntimeTypes.Serde.SerdeJson.JsonSerializer) renderSerializerBody(ctx, shape, documentMembers, writer) writer.write("return serializer.toByteArray()") @@ -118,4 +128,24 @@ open class JsonSerializerGenerator( } } } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + val symbol = ctx.symbolProvider.toSymbol(errorShape) + + return symbol.errorSerializer(ctx.settings) { writer -> + addNestedDocumentSerializers(ctx, errorShape, writer) + val fnName = symbol.errorSerializerName() + writer.openBlock("private fun #L(context: #T, input: #T): String {", fnName, RuntimeTypes.Core.ExecutionContext, symbol) + .write("val serializer = #T()", RuntimeTypes.Serde.SerdeJson.JsonSerializer) + .call { + renderSerializerBody(ctx, errorShape, members, writer) + } + .write("return serializer.toByteArray().decodeToString()") + .closeBlock("}") + } + } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt index 071d8ba364..2083b83524 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt @@ -149,6 +149,26 @@ fun Symbol.errorDeserializer(settings: KotlinSettings, block: SymbolRenderer): S renderBy = block } +/** + * Get the serializer name for an error shape + */ +fun Symbol.errorSerializerName(): String = "serialize" + StringUtils.capitalize(this.name) + "Error" + +/** + * Get the function responsible for serializing members bound to the payload of an error shape as [Symbol] and + * register [block] * which will be invoked to actually render the function (signature and implementation) + */ +fun Symbol.errorSerializer(settings: KotlinSettings, block: SymbolRenderer): Symbol = buildSymbol { + name = errorSerializerName() + namespace = settings.pkg.serde + val symbol = this@errorSerializer + // place it in the same file as the exception deserializer, e.g. for HTTP protocols this will be in + // same file as HttpDeserialize + definitionFile = "${symbol.name}Serializer.kt" + reference(symbol, SymbolReference.ContextOption.DECLARE) + renderBy = block +} + /** * Get the function responsible for deserializing the specific shape as a standalone payload */ diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/StructuredDataSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/StructuredDataSerializerGenerator.kt index c3e49c128d..7d08a03b86 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/StructuredDataSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/StructuredDataSerializerGenerator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerato import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape /** * Responsible for rendering serialization of structured data (e.g. json, yaml, xml). @@ -56,4 +57,27 @@ interface StructuredDataSerializerGenerator { shape: Shape, members: Collection? = null, ): Symbol + + /** + * Render function responsible for serializing members bound to the payload for the given error shape. + * + * Because only a subset of fields of an operation error may be bound to the payload a builder is given + * as an argument. + * + * ``` + * fun serializeFooError(builder: FooError.Builder, payload: ByteArray) { + * ... + * } + * ``` + * + * Implementations are expected to instantiate an appropriate serializer for the protocol and serialize + * the error shape from the payload using the builder passed in. + * + * @param ctx the protocol generator context + * @param errorShape the error shape to render deserialize for + * @param members the members of the error shape that are bound to the payload. Not all members are + * bound to the document, some may be bound to e.g. headers, status code, etc + * @return the generated symbol which should be a function matching the signature expected for the protocol + */ + fun errorSerializer(ctx: ProtocolGenerator.GenerationContext, errorShape: StructureShape, members: List): Symbol } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerializerGenerator.kt index 1f77136e15..0ff77e2e85 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerializerGenerator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingConte import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.model.traits.XmlAttributeTrait @@ -141,4 +142,12 @@ open class XmlSerializerGenerator( } } } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + TODO("Used for service-codegen. Not yet implemented") + } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/README.md b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/README.md new file mode 100644 index 0000000000..18832f7e6a --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/README.md @@ -0,0 +1,43 @@ +# Smithy Kotlin Service Codegen (SKSC) + +## Overview + +This project generate **service-side code** from Smithy models, producing **complete service stubs**, including routing, serialization/deserialization, authentication, and validation, so developers can focus entirely on implementing business logic. + +While Ktor is the default backend, the architecture is framework-agnostic, allowing future support for other server frameworks. + + +## Getting Started + +- Get an [introduction to Smithy](https://smithy.io/2.0/index.html) +- Follow [Smithy's quickstart guide](https://smithy.io/2.0/quickstart.html) +- See the [Guide](docs/GettingStarted.md) to learn how to use SKSC to generate service. +- See a [Summary of Service Support](docs/FEATURES.md) to learn which features are supported + + +## Development + +### Module Structure + +- `constraints` – directory that contains the constraints validation generator. + - `ConstraintsGenerator.kt` - main generator for constraints. + - `ConstraintUtilsGenerator` - generator for constraint utilities. + - For each constraint trait, there is a dedicated file. +- `ktor` – directory that stores all features generators specific to Ktor. + - `ktorStubGenerator.kt` – main generator for ktor framework service stub generator. +- `ServiceStubConfiguration.kt` – configuration file for the service stub generator. +- `ServiceStubGenerator.kt` – abstract service stub generator file. +- `ServiceTypes.kt` – file that includes service component symbols. +- `utils.kt` – utilities file. + +### Testing + +The **service code generation tests** are located in `tests/codegen/service-codegen-tests`. These end-to-end tests generate the service, launch the server, send HTTP requests to validate functionality, and then shut down the service once testing is complete. This process typically takes around 2 minutes. To run tests specifically for SKSC, use the following command: +```bash + ./gradlew :tests:codegen:service-codegen-tests:test +``` + +## Feedback + +You can provide feedback or report a bug by submitting an [issue](https://github.com/smithy-lang/smithy-kotlin/issues/new/choose). +This is the preferred mechanism to give feedback so that other users can engage in the conversation, +1 issues, etc. \ No newline at end of file diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubConfigurations.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubConfigurations.kt new file mode 100644 index 0000000000..012f557174 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubConfigurations.kt @@ -0,0 +1,98 @@ +package software.amazon.smithy.kotlin.codegen.service + +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.kotlin.codegen.core.GenerationContext +import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator +import software.amazon.smithy.kotlin.codegen.service.ktor.KtorStubGenerator +import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ShapeType +import software.amazon.smithy.model.traits.HttpPayloadTrait +import software.amazon.smithy.model.traits.MediaTypeTrait +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait + +/** + * Enumeration of supported media types used by the generated service. + * + * These values define how payloads are encoded/decoded: + * - `CBOR` → Concise Binary Object Representation + * - `JSON` → JSON encoding + * - `PLAIN_TEXT` → Text/plain + * - `OCTET_STREAM` → Binary data + * - `ANY` → Fallback for arbitrary content types + */ +enum class MediaType(val value: String) { + CBOR("CBOR"), + JSON("JSON"), + PLAIN_TEXT("PLAIN_TEXT"), + OCTET_STREAM("OctetStream"), + ANY("ANY"), + ; + + override fun toString(): String = value + + companion object { + fun fromValue(value: String): MediaType = MediaType + .entries + .firstOrNull { it.name.equals(value.uppercase(), ignoreCase = true) } + ?: throw IllegalArgumentException("$value is not a validContentType value, expected one of ${MediaType.entries}") + + fun fromServiceShape(ctx: GenerationContext, shape: ServiceShape, targetShapeId: ShapeId): MediaType { + return when { + shape.hasTrait(Rpcv2CborTrait.ID) -> CBOR + shape.hasTrait(RestJson1Trait.ID) -> { + val targetShape = ctx.model.expectShape(targetShapeId) + for (memberShape in targetShape.allMembers.values) { + if (!memberShape.hasTrait(HttpPayloadTrait.ID)) continue + val memberType = ctx.model.expectShape(memberShape.target).type + when (memberType) { + ShapeType.STRING -> return PLAIN_TEXT + ShapeType.BLOB -> return OCTET_STREAM + ShapeType.DOCUMENT, + ShapeType.STRUCTURE, + ShapeType.UNION, + -> return JSON + else -> { + if (memberShape.hasTrait(MediaTypeTrait.ID)) return ANY + } + } + } + return JSON + } + + else -> throw IllegalArgumentException("Cannot find supported MediaType for the service") + } + } + } +} + +/** + * Enumeration of supported service frameworks for generated stubs. + * + * Currently only supports: + * - `KTOR`: Generates Ktor-based service stubs + */ +enum class ServiceFramework(val value: String) { + KTOR("ktor"), + ; + + override fun toString(): String = value + + companion object { + fun fromValue(value: String): ServiceFramework = when (value.lowercase()) { + "ktor" -> KTOR + else -> throw IllegalArgumentException("$value is not a valid ServerFramework value, expected $KTOR") + } + } + + internal fun getServiceFrameworkGenerator( + ctx: GenerationContext, + delegator: KotlinDelegator, + fileManifest: FileManifest, + ): AbstractStubGenerator { + when (this) { + KTOR -> return KtorStubGenerator(ctx, delegator, fileManifest) + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubGenerator.kt new file mode 100644 index 0000000000..ef3a46a7be --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubGenerator.kt @@ -0,0 +1,271 @@ +package software.amazon.smithy.kotlin.codegen.service + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.kotlin.codegen.core.GenerationContext +import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.defaultName +import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.core.withInlineBlock +import software.amazon.smithy.model.knowledge.TopDownIndex + +/** + * Interface representing a generator that produces service stubs (boilerplate code) for a Smithy service. + */ +internal interface ServiceStubGenerator { + /** + * Render the stub code into the target files. + */ + fun render() +} + +/** + * Abstract base class for generating service stubs. + * + * Provides a framework for generating common service artifacts such as: + * - Configuration (`ServiceFrameworkConfig.kt`) + * - Framework bootstrap (`ServiceFramework.kt`) + * - Plugins, utils, authentication, validators + * - Operation handlers and routing + * - Main launcher (`Main.kt`) + * + * Concrete subclasses must implement abstract methods for framework-specific + * code (e.g., Ktor). + */ +internal abstract class AbstractStubGenerator( + val ctx: GenerationContext, + val delegator: KotlinDelegator, + val fileManifest: FileManifest, +) : ServiceStubGenerator { + + val serviceShape = ctx.settings.getService(ctx.model) + val operations = TopDownIndex.of(ctx.model) + .getContainedOperations(serviceShape) + .sortedBy { it.defaultName() } + + val pkgName = ctx.settings.pkg.name + + /** + * Render all service stub files by invoking the component renderers. + * This acts as the main entrypoint for code generation. + */ + final override fun render() { + renderServiceFrameworkConfig() + renderServiceFramework() + renderPlugins() + renderUtils() + renderAuthModule() + renderConstraintValidators() + renderPerOperationHandlers() + renderRouting() + renderMainFile() + } + + /** + * Generate the service configuration file (`ServiceFrameworkConfig.kt`). + * + * Defines enums for: + * - `LogLevel`: Logging verbosity levels + * - `ServiceEngine`: Available server engines (Netty, CIO, Jetty) + * + * Provides a singleton `ServiceFrameworkConfig` object that stores runtime + * settings such as port, engine, region, timeouts, and log level. + */ + protected fun renderServiceFrameworkConfig() { + delegator.useFileWriter("ServiceFrameworkConfig.kt", "${ctx.settings.pkg.name}.config") { writer -> + writer.withBlock("internal enum class LogLevel(val value: String) {", "}") { + write("INFO(#S),", "INFO") + write("WARN(#S),", "WARN") + write("DEBUG(#S),", "DEBUG") + write("ERROR(#S),", "ERROR") + write("TRACE(#S),", "TRACE") + write("OFF(#S),", "OFF") + write(";") + write("") + write("override fun toString(): String = value") + write("") + withBlock("companion object {", "}") { + withBlock("fun fromValue(value: String): #T = when (value.uppercase()) {", "}", ServiceTypes(pkgName).logLevel) { + write("INFO.value -> INFO") + write("WARN.value -> WARN") + write("DEBUG.value -> DEBUG") + write("ERROR.value -> ERROR") + write("TRACE.value -> TRACE") + write("OFF.value -> OFF") + write("else -> throw IllegalArgumentException(#S)", "Unknown LogLevel value: \$value") + } + } + } + writer.write("") + + writer.withBlock("internal enum class ServiceEngine(val value: String) {", "}") { + write("NETTY_ENGINE(#S),", "netty") + write("CIO_ENGINE(#S),", "cio") + write("JETTY_JAKARTA_ENGINE(#S),", "jetty-jakarta") + write(";") + write("") + write("override fun toString(): String = value") + write("") + withBlock("companion object {", "}") { + withBlock("fun fromValue(value: String): #T {", "}", ServiceTypes(pkgName).serviceEngine) { + write( + "return #T.entries.firstOrNull { it.value.equals(value.lowercase(), ignoreCase = true) } ?: throw IllegalArgumentException(#S)", + ServiceTypes(pkgName).serviceEngine, + "\$value is not a validContentType value, expected one of \${ServiceEngine.entries}", + ) + } + } + write("") + withBlock("fun toEngineFactory(): #T<*, *> {", "}", RuntimeTypes.KtorServerCore.ApplicationEngineFactory) { + withBlock("return when(this) {", "}") { + write("NETTY_ENGINE -> #T as #T<*, *>", RuntimeTypes.KtorServerNetty.Netty, RuntimeTypes.KtorServerCore.ApplicationEngineFactory) + write("CIO_ENGINE -> #T as #T<*, *>", RuntimeTypes.KtorServerCio.CIO, RuntimeTypes.KtorServerCore.ApplicationEngineFactory) + write("JETTY_JAKARTA_ENGINE -> #T as #T<*, *>", RuntimeTypes.KtorServerJettyJakarta.Jetty, RuntimeTypes.KtorServerCore.ApplicationEngineFactory) + } + } + } + writer.write("") + + writer.withBlock("internal object ServiceFrameworkConfig {", "}") { + write("private var backing: Data? = null") + write("") + withBlock("private data class Data(", ")") { + write("val port: Int,") + write("val engine: #T,", ServiceTypes(pkgName).serviceEngine) + write("val region: String,") + write("val requestBodyLimit: Long,") + write("val requestReadTimeoutSeconds: Int,") + write("val responseWriteTimeoutSeconds: Int,") + write("val closeGracePeriodMillis: Long,") + write("val closeTimeoutMillis: Long,") + write("val logLevel: #T,", ServiceTypes(pkgName).logLevel) + } + write("") + write("val port: Int get() = backing?.port ?: notInitialised(#S)", "port") + write("val engine: #T get() = backing?.engine ?: notInitialised(#S)", ServiceTypes(pkgName).serviceEngine, "engine") + write("val region: String get() = backing?.region ?: notInitialised(#S)", "region") + write("val requestBodyLimit: Long get() = backing?.requestBodyLimit ?: notInitialised(#S)", "requestBodyLimit") + write("val requestReadTimeoutSeconds: Int get() = backing?.requestReadTimeoutSeconds ?: notInitialised(#S)", "requestReadTimeoutSeconds") + write("val responseWriteTimeoutSeconds: Int get() = backing?.responseWriteTimeoutSeconds ?: notInitialised(#S)", "responseWriteTimeoutSeconds") + write("val closeGracePeriodMillis: Long get() = backing?.closeGracePeriodMillis ?: notInitialised(#S)", "closeGracePeriodMillis") + write("val closeTimeoutMillis: Long get() = backing?.closeTimeoutMillis ?: notInitialised(#S)", "closeTimeoutMillis") + write("val logLevel: #T get() = backing?.logLevel ?: notInitialised(#S)", ServiceTypes(pkgName).logLevel, "logLevel") + write("") + withInlineBlock("fun init(", ")") { + write("port: Int,") + write("engine: #T,", ServiceTypes(pkgName).serviceEngine) + write("region: String,") + write("requestBodyLimit: Long,") + write("requestReadTimeoutSeconds: Int,") + write("responseWriteTimeoutSeconds: Int,") + write("closeGracePeriodMillis: Long,") + write("closeTimeoutMillis: Long,") + write("logLevel: #T,", ServiceTypes(pkgName).logLevel) + } + withBlock("{", "}") { + write("check(backing == null) { #S }", "ServiceFrameworkConfig has already been initialised") + write("backing = Data(port, engine, region, requestBodyLimit, requestReadTimeoutSeconds, responseWriteTimeoutSeconds, closeGracePeriodMillis, closeTimeoutMillis, logLevel)") + } + write("") + withBlock("private fun notInitialised(prop: String): Nothing {", "}") { + write("error(#S)", "ServiceFrameworkConfig.\$prop accessed before init()") + } + } + } + } + + /** + * Generate the service framework interface and bootstrap (`ServiceFramework.kt`). + * + * Declares a common `ServiceFramework` interface with lifecycle methods and + * delegates framework-specific implementation details to subclasses. + */ + protected fun renderServiceFramework() { + delegator.useFileWriter("ServiceFramework.kt", "${ctx.settings.pkg.name}.framework") { writer -> + + writer.withBlock("internal interface ServiceFramework: #T {", "}", RuntimeTypes.Core.IO.Closeable) { + write("// start the service and begin accepting connections") + write("public fun start()") + } + .write("") + + renderServerFrameworkImplementation(writer) + } + } + + /** Render the specific server framework implementation (e.g., Ktor). */ + protected abstract fun renderServerFrameworkImplementation(writer: KotlinWriter) + + /** Generate service plugins such as content-type guards, error handlers, etc. */ + protected abstract fun renderPlugins() + + /** Generate supporting utility classes and functions. */ + protected abstract fun renderUtils() + + /** Generate authentication module interfaces and installers (e.g., bearer auth, SigV4, SigV4A). */ + protected abstract fun renderAuthModule() + + /** Generate request-level constraint validators for Smithy model constraints. */ + protected abstract fun renderConstraintValidators() + + /** Generate a request handler for each Smithy operation. */ + protected abstract fun renderPerOperationHandlers() + + /** Generate the route table that maps Smithy operations to runtime endpoints. */ + protected abstract fun renderRouting() + + /** + * Generate the top-level `Main.kt` launcher file. + * + * This file provides the `main()` entrypoint: + * - Parses command-line arguments + * - Applies defaults for configuration values + * - Initializes the `ServiceFrameworkConfig` + * - Starts the appropriate service framework + */ + protected fun renderMainFile() { + val portName = "port" + val engineFactoryName = "engineFactory" + val regionName = "region" + val requestBodyLimitName = "requestBodyLimit" + val requestReadTimeoutSecondsName = "requestReadTimeoutSeconds" + val responseWriteTimeoutSecondsName = "responseWriteTimeoutSeconds" + val closeGracePeriodMillisName = "closeGracePeriodMillis" + val closeTimeoutMillisName = "closeTimeoutMillis" + val logLevelName = "logLevel" + delegator.useFileWriter("Main.kt", ctx.settings.pkg.name) { writer -> + + writer.withBlock("public fun main(args: Array): Unit {", "}") { + write("val argMap: Map = args.asList().chunked(2).associate { (k, v) -> k.removePrefix(#S) to v }", "--") + write("") + write("val defaultPort = 8080") + write("val defaultEngine = #T.NETTY_ENGINE.value", ServiceTypes(pkgName).serviceEngine) + write("val defaultRegion = #S", "us-east-1") + write("val defaultRequestBodyLimit = 10L * 1024 * 1024") + write("val defaultRequestReadTimeoutSeconds = 30") + write("val defaultResponseWriteTimeoutSeconds = 30") + write("val defaultCloseGracePeriodMillis = 1_000L") + write("val defaultCloseTimeoutMillis = 5_000L") + write("val defaultLogLevel = #T.INFO.value", ServiceTypes(pkgName).logLevel) + write("") + withBlock("#T.init(", ")", ServiceTypes(pkgName).serviceFrameworkConfig) { + write("port = argMap[#S]?.toInt() ?: defaultPort, ", portName) + write("engine = #T.fromValue(argMap[#S] ?: defaultEngine), ", ServiceTypes(pkgName).serviceEngine, engineFactoryName) + write("region = argMap[#S]?.toString() ?: defaultRegion, ", regionName) + write("requestBodyLimit = argMap[#S]?.toLong() ?: defaultRequestBodyLimit, ", requestBodyLimitName) + write("requestReadTimeoutSeconds = argMap[#S]?.toInt() ?: defaultRequestReadTimeoutSeconds, ", requestReadTimeoutSecondsName) + write("responseWriteTimeoutSeconds = argMap[#S]?.toInt() ?: defaultResponseWriteTimeoutSeconds, ", responseWriteTimeoutSecondsName) + write("closeGracePeriodMillis = argMap[#S]?.toLong() ?: defaultCloseGracePeriodMillis, ", closeGracePeriodMillisName) + write("closeTimeoutMillis = argMap[#S]?.toLong() ?: defaultCloseTimeoutMillis, ", closeTimeoutMillisName) + write("logLevel = #T.fromValue(argMap[#S] ?: defaultLogLevel), ", ServiceTypes(pkgName).logLevel, logLevelName) + } + write("") + when (ctx.settings.serviceStub.framework) { + ServiceFramework.KTOR -> write("val service = #T()", ServiceTypes(pkgName).ktorServiceFramework) + } + write("service.start()") + } + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt new file mode 100644 index 0000000000..74385533cf --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt @@ -0,0 +1,85 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.kotlin.codegen.service + +import software.amazon.smithy.kotlin.codegen.model.buildSymbol + +class ServiceTypes(val pkgName: String) { + val logLevel = buildSymbol { + name = "LogLevel" + namespace = "$pkgName.config" + } + + val serviceEngine = buildSymbol { + name = "ServiceEngine" + namespace = "$pkgName.config" + } + + val serviceFrameworkConfig = buildSymbol { + name = "ServiceFrameworkConfig" + namespace = "$pkgName.config" + } + + val ktorServiceFramework = buildSymbol { + name = "KtorServiceFramework" + namespace = "$pkgName.framework" + } + + val module = buildSymbol { + name = "module" + namespace = "$pkgName.framework" + } + + val configureErrorHandling = buildSymbol { + name = "configureErrorHandling" + namespace = "$pkgName.plugins" + } + + val configureRouting = buildSymbol { + name = "configureRouting" + namespace = pkgName + } + + val configureLogging = buildSymbol { + name = "configureLogging" + namespace = "$pkgName.utils" + } + + val configureAuthentication = buildSymbol { + name = "configureAuthentication" + namespace = "$pkgName.auth" + } + + val errorEnvelope = buildSymbol { + name = "ErrorEnvelope" + namespace = "$pkgName.plugins" + } + + val contentTypeGuard = buildSymbol { + name = "ContentTypeGuard" + namespace = "$pkgName.plugins" + } + + val acceptTypeGuard = buildSymbol { + name = "AcceptTypeGuard" + namespace = "$pkgName.plugins" + } + + val sizeOf = buildSymbol { + name = "sizeOf" + namespace = "$pkgName.constraints" + } + + val hasAllUniqueElements = buildSymbol { + name = "hasAllUniqueElements" + namespace = "$pkgName.constraints" + } + + val responseHandledKey = buildSymbol { + name = "ResponseHandledKey" + namespace = "$pkgName.plugins" + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/AbstractConstraintTraitGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/AbstractConstraintTraitGenerator.kt new file mode 100644 index 0000000000..fc9c0319c0 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/AbstractConstraintTraitGenerator.kt @@ -0,0 +1,5 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +internal abstract class AbstractConstraintTraitGenerator { + abstract fun render() +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/ConstraintGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/ConstraintGenerator.kt new file mode 100644 index 0000000000..cde979575d --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/ConstraintGenerator.kt @@ -0,0 +1,104 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +import software.amazon.smithy.kotlin.codegen.core.GenerationContext +import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.RequiredTrait +import kotlin.collections.iterator + +/** + * Generates validation code for request constraints on Smithy operation inputs. + * + * For a given [operation], this generator traverses the input structure and: + * - Recursively inspects members of structures and lists. + * - Applies trait-based validations (e.g., required, length, range). + * - Generates Kotlin validation functions that check constraints at runtime. + * + * Output is written into a `RequestConstraints.kt` file in the generated `constraints` package. + */ +internal class ConstraintGenerator( + val ctx: GenerationContext, + val operation: OperationShape, + val delegator: KotlinDelegator, +) { + val inputShape = ctx.model.expectShape(operation.input.get()) as StructureShape + val inputMembers = inputShape.allMembers + + val opName = operation.id.name + val pkgName = ctx.settings.pkg.name + + /** + * Entry point for emitting validation code for the operation’s request type. + * Delegates to [renderRequestConstraintsValidation]. + */ + fun render() { + renderRequestConstraintsValidation() + } + + /** + * Recursively generates validation code for a given [memberShape]. + * + * - If the target is a list, iterates over elements and validates them. + * - If the target is a structure, recursively validates its members. + * - For each trait (on the member or its target), invokes the matching trait generator. + * - `@required` traits are always enforced. + * - Other traits are wrapped in a null check before validation. + */ + private fun generateConstraintValidations(prefix: String, memberShape: MemberShape, writer: KotlinWriter) { + val targetShape = ctx.model.expectShape(memberShape.target) + + val memberName = memberShape.memberName + val memberAndTargetTraits = memberShape.allTraits + targetShape.allTraits + when { + targetShape.isListShape -> + for (member in targetShape.allMembers) { + val newMemberPrefix = "${targetShape.id.name}".replaceFirstChar { it.lowercase() } + writer.withBlock("if ($prefix$memberName != null) {", "}") { + withBlock("for ($newMemberPrefix${member.key} in $prefix$memberName ?: listOf()) {", "}") { + call { generateConstraintValidations(newMemberPrefix, member.value, writer) } + } + } + } + targetShape.isStructureShape -> + for (member in targetShape.allMembers) { + val newMemberPrefix = "$prefix$memberName?." + generateConstraintValidations(newMemberPrefix, member.value, writer) + } + } + for (memberTrait in memberAndTargetTraits.values) { + val traitGenerator = getTraitGeneratorFromTrait(prefix, memberName, memberTrait, pkgName, writer) + traitGenerator?.apply { + if (memberTrait !is RequiredTrait) { + writer.withBlock("if ($prefix$memberName != null) {", "}") { + render() + } + } else { + render() + } + } + } + } + + /** + * Writes the top-level validation function for the operation’s input type. + * + * Inside, it calls [generateConstraintValidations] for each input member, + * ensuring all modeled constraints are enforced. + */ + private fun renderRequestConstraintsValidation() { + delegator.useFileWriter("${opName}RequestConstraints.kt", "$pkgName.constraints") { writer -> + val inputShape = ctx.model.expectShape(operation.input.get()) + val inputSymbol = ctx.symbolProvider.toSymbol(inputShape) + + writer.withBlock("public fun check${opName}RequestConstraint(data: #T) {", "}", inputSymbol) { + for (memberShape in inputMembers.values) { + generateConstraintValidations("data.", memberShape, writer) + } + } + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/ConstraintUtilsGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/ConstraintUtilsGenerator.kt new file mode 100644 index 0000000000..b93f690287 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/ConstraintUtilsGenerator.kt @@ -0,0 +1,104 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +import software.amazon.smithy.kotlin.codegen.core.GenerationContext +import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.withBlock + +/** + * Generates utility functions to support constraint validation for Smithy models. + * + * This generator emits reusable helpers that can be called from operation-specific + * validation code (e.g., generated by [ConstraintGenerator]). These helpers enforce + * common traits like `@length` and `@uniqueItems`. + * + * Output is written into a `utils.kt` file under the generated `constraints` package. + */ +internal class ConstraintUtilsGenerator( + val ctx: GenerationContext, + val delegator: KotlinDelegator, +) { + val pkgName = ctx.settings.pkg.name + + fun render() { + delegator.useFileWriter("utils.kt", "$pkgName.constraints") { writer -> + renderLengthTraitUtils(writer) + + writer.write("") + renderUniqueItemsTraitUtils(writer) + } + } + + /** + * Generates the `sizeOf()` function. + * + * This utility computes a generalized "size" for multiple types: + * - Collections, arrays, maps → `size` + * - Strings → Unicode code point count + * - Byte arrays → length + * + * Any unsupported type will throw an `IllegalArgumentException`. + */ + private fun renderLengthTraitUtils(writer: KotlinWriter) { + writer.withBlock("internal fun sizeOf(value: Any?): Long = when (value) {", "}") { + write("is Collection<*> -> value.size.toLong()") + write("is Array<*> -> value.size.toLong()") + write("is Map<*, *> -> value.size.toLong()") + write("is String -> value.codePointCount(0, value.length).toLong()") + write("is ByteArray -> value.size.toLong()") + withBlock("else -> {", "}") { + write("val typeName = value?.javaClass?.simpleName ?: #S", "null") + write("throw IllegalArgumentException( #S )", "sizeOf does not support \${typeName} type") + } + } + } + + /** + * Generates the `hasAllUniqueElements()` function. + * + * This utility checks if a list contains only unique elements, where uniqueness + * is defined by deep structural equality: + * - Primitive wrappers (String, Boolean, Number, Instant) → compared by value + * - Byte arrays → compared by contents + * - Lists → recursively compared element by element + * - Maps → recursively compared entries by key/value + */ + private fun renderUniqueItemsTraitUtils(writer: KotlinWriter) { + writer.withBlock("internal fun hasAllUniqueElements(elements: List): Boolean {", "}") { + withBlock("class Wrapped(private val v: Any?) {", "}") { + withBlock("override fun equals(other: Any?): Boolean {", "}") { + write("if (other !is Wrapped) return false") + write("if (v?.javaClass != other.v?.javaClass) return false") + withBlock("return when (v) {", "}") { + write("null -> true") + write("is String,") + write("is Boolean,") + write("is java.time.Instant,") + write("is Number -> v == other.v") + write("is ByteArray -> v.contentEquals(other.v as ByteArray)") + withBlock("is List<*> -> {", "}") { + write("val o = other.v as List<*>") + write("v.size == o.size && v.indices.all { i -> Wrapped(v[i]) == Wrapped(o[i]) }") + } + withBlock("is Map<*, *> -> {", "}") { + write("val o = other.v as Map<*, *>") + write("v.size == o.size && v.all { (k, value) -> o.containsKey(k) && Wrapped(value) == Wrapped(o[k]) }") + } + write("else -> v == other.v") + } + } + withBlock("override fun hashCode(): Int = when (v) {", "}") { + write("null -> 0") + write("is ByteArray -> v.contentHashCode()") + write("is List<*> -> v.fold(1) { acc, e -> 31 * acc + Wrapped(e).hashCode() }") + write("is Map<*, *> -> v.entries.fold(1) { acc, (k, e) -> 31 * acc + Wrapped(k).hashCode() xor Wrapped(e).hashCode() }") + write("else -> v.hashCode()") + } + } + write("") + write("val seen = HashSet(elements.size)") + write("for (e in elements) if (!seen.add(Wrapped(e))) return false") + write("return true") + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/LengthConstraintGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/LengthConstraintGenerator.kt new file mode 100644 index 0000000000..04cef3ea9b --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/LengthConstraintGenerator.kt @@ -0,0 +1,21 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.service.ServiceTypes +import software.amazon.smithy.model.traits.LengthTrait + +internal class LengthConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: LengthTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() { + override fun render() { + val min = trait.min.orElse(null) + val max = trait.max.orElse(null) + val member = "$memberPrefix$memberName" + + if (max != null && min != null) { + writer.write("require(#T($member) in $min..$max) { #S }", ServiceTypes(pkgName).sizeOf, "The size of `$memberName` must be between $min and $max (inclusive)") + } else if (max != null) { + writer.write("require(#T($member) <= $max) { #S }", ServiceTypes(pkgName).sizeOf, "The size of `$memberName` must be less than or equal to $max") + } else { + writer.write("require(#T($member) >= $min) { #S }", ServiceTypes(pkgName).sizeOf, "The size of `$memberName` must be greater than or equal to $min") + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/PatternConstraintGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/PatternConstraintGenerator.kt new file mode 100644 index 0000000000..9c02278190 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/PatternConstraintGenerator.kt @@ -0,0 +1,12 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.model.traits.PatternTrait + +internal class PatternConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: PatternTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() { + override fun render() { + val member = "$memberPrefix$memberName" + + writer.write("require(Regex(#S).containsMatchIn($member)) { #S }", trait.pattern.toString(), "Value `\${$member}` does not match required pattern: `${trait.pattern}`") + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/RangeConstraintGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/RangeConstraintGenerator.kt new file mode 100644 index 0000000000..9c2eeeb279 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/RangeConstraintGenerator.kt @@ -0,0 +1,20 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.model.traits.RangeTrait + +internal class RangeConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: RangeTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() { + override fun render() { + val min = trait.min.orElse(null) + val max = trait.max.orElse(null) + val member = "$memberPrefix$memberName" + + if (max != null && min != null) { + writer.write("require($member in $min..$max) { #S }", "`$memberName` must be between $min and $max (inclusive)") + } else if (max != null) { + writer.write("require($member <= $max) { #S }", "`$memberName` must be less than or equal to $max") + } else { + writer.write("require($member >= $min) { #S }", "`$memberName` must be greater than or equal to $min") + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/RequiredConstraintGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/RequiredConstraintGenerator.kt new file mode 100644 index 0000000000..a0fc814a94 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/RequiredConstraintGenerator.kt @@ -0,0 +1,11 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.model.traits.RequiredTrait + +internal class RequiredConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: RequiredTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() { + override fun render() { + val member = "$memberPrefix$memberName" + writer.write("require($member != null) { #S }", "`$memberName` must be provided") + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/UniqueItemsConstraintGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/UniqueItemsConstraintGenerator.kt new file mode 100644 index 0000000000..644b2bc99c --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/UniqueItemsConstraintGenerator.kt @@ -0,0 +1,12 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.service.ServiceTypes +import software.amazon.smithy.model.traits.UniqueItemsTrait + +internal class UniqueItemsConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: UniqueItemsTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() { + override fun render() { + val member = "$memberPrefix$memberName" + writer.write("require(#T($member)) { #S }", ServiceTypes(pkgName).hasAllUniqueElements, "`$memberName` must contain only unique items, duplicate values are not allowed") + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/utils.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/utils.kt new file mode 100644 index 0000000000..82985ac9fa --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/constraints/utils.kt @@ -0,0 +1,24 @@ +package software.amazon.smithy.kotlin.codegen.service.constraints + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.model.traits.LengthTrait +import software.amazon.smithy.model.traits.PatternTrait +import software.amazon.smithy.model.traits.RangeTrait +import software.amazon.smithy.model.traits.RequiredTrait +import software.amazon.smithy.model.traits.Trait +import software.amazon.smithy.model.traits.UniqueItemsTrait + +internal fun getTraitGeneratorFromTrait( + memberPrefix: String, + memberName: String, + trait: Trait, + pkgName: String, + writer: KotlinWriter, +): AbstractConstraintTraitGenerator? = when (trait) { + is LengthTrait -> LengthConstraintGenerator(memberPrefix, memberName, trait, pkgName, writer) + is PatternTrait -> PatternConstraintGenerator(memberPrefix, memberName, trait, pkgName, writer) + is RangeTrait -> RangeConstraintGenerator(memberPrefix, memberName, trait, pkgName, writer) + is UniqueItemsTrait -> UniqueItemsConstraintGenerator(memberPrefix, memberName, trait, pkgName, writer) + is RequiredTrait -> RequiredConstraintGenerator(memberPrefix, memberName, trait, pkgName, writer) + else -> null +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/docs/FEATURES.md b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/docs/FEATURES.md new file mode 100644 index 0000000000..03784e96b6 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/docs/FEATURES.md @@ -0,0 +1,59 @@ +# Summary + +--- + +### Features Support + +| **Features** | **Description** | +|-----------------------------------|-------------------------------------------------------------------------------------------------| +| Service Framework | Abstracted service framework interface and base implementation with Ktor as the default backend | +| CBOR Protocol | Support for CBOR serialization / deserialization and CBOR protocol traits | +| Json Protocol | Support for Json serialization / deserialization and Json protocol traits | +| Routing | Per-operation routing generation with Ktor DSL; ties to handler and validation | +| Error Handler | Unified exception handling logic mapped to HTTP status codes and support for error trait | +| Authentication (bearer) | Bearer token authentication middleware with model-driven configuration | +| Authentication (SigV4 and SigV4A) | SigV4 and SigV4A authentication middleware with model-driven configuration | +| Logging | Structured logging setup | +| Constraints Checker | Validation logic generated from Smithy traits and invoked pre-handler | +| Unit Test | Covers serialization/deserialization, routing, validation, and integration tests | + +### Smithy Protocol Traits Support + +| **Traits** | **CBOR Protocol** | **Json Protocol** | +|--------------------------|-------------------|-------------------| +| http | Yes | Yes | +| httpError | Yes | Yes | +| httpHeader | Not supported | Yes | +| httpPrefixHeader | Not supported | Yes | +| httpLabel | Not supported | Yes | +| httpQuery | Not supported | Yes | +| httpQueryParams | Not supported | Yes | +| httpPayload | Not supported | Yes | +| jsonName | Not supported | Yes | +| timestampFormat | Not supported | Yes | +| httpChecksumRequired | Not supported | Not implemented yet | +| requestCompression | Not implemented yet | Not implemented yet | + +### Constraint Traits Support + +| **Traits** | **CBOR Protocol** | **Json Protocol** | +|-----------------|------------------------------|------------------------------| +| required | Yes | Yes | +| length | Yes | Yes | +| pattern | Yes | Yes | +| private | Yes (handled by Smithy) | Yes (handled by Smithy) | +| range | Yes | Yes | +| uniqueItems | Yes | Yes | +| idRef | Not implemented yet | Not implemented yet | + + +### Future Features + +| Feature | Description | +|-----------------------------------|-------------------------------------------------------------------------------------------------| +| Additional Protocols | XML, Ec2Query, AWSQuery protocols | +| Middleware / Interceptors | Cross-cutting logic support (e.g., metrics, headers, rate limiting) via middleware architecture | +| API Versioning | Built-in support for versioned APIs to maintain backward compatibility | +| gRPC / WebSocket Protocol Support | High-performance binary RPC and real-time bidirectional communication | +| Metrics & Tracing | Observability support with metrics, logs, and distributed tracing for debugging and monitoring | +| Caching Middleware | Per-route or global cache support to improve response times and reduce backend load | diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/docs/GettingStarted.md b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/docs/GettingStarted.md new file mode 100644 index 0000000000..73a7cdb348 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/docs/GettingStarted.md @@ -0,0 +1,181 @@ +# Getting Started + +### Step 1: Build & Publish Codegen to Local Maven +First, in **this repository**, build and publish the code generator locally: +```bash + ./gradlew :codegen:smithy-kotlin-codegen:build + ./gradlew publishToMavenLocal +``` + +### Step 2: Create a New Kotlin Project +Now, create a **new Kotlin project** where you will use the Smithy Kotlin service code generator. You can find a full example demo project [here](../../../../../../../../../../../../examples/service-codegen) + +From this point forward, **all steps apply to the new Kotlin project** you just created. + + +### Step 3: Configure `build.gradle.kts` in the New Project + +```kotlin +plugins { + alias(libs.plugins.kotlin.jvm) + id("software.amazon.smithy.gradle.smithy-jar") version "1.3.0" // check for latest version + application +} + +repositories { + mavenLocal() + mavenCentral() +} + +val codegenVersion = "0.35.2-SNAPSHOT" +val smithyVersion = "1.60.2" + +dependencies { + smithyBuild("software.amazon.smithy.kotlin:smithy-kotlin-codegen:$codegenVersion") + implementation("software.amazon.smithy.kotlin:smithy-aws-kotlin-codegen:$codegenVersion") + implementation("software.amazon.smithy:smithy-model:$smithyVersion") + implementation("software.amazon.smithy:smithy-build:$smithyVersion") + implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") + ... +} +``` + + +### Step 4: Create `smithy-build.json` in the New Project +This is an example of smithy-build.json. +```json +{ + "version": "1.0", + "outputDirectory": "build/generated-src", + "plugins": { + "kotlin-codegen": { + "service": "com.demo#DemoService", + "package": { + "name": "com.demo.server", + "version": "1.0.0" + }, + "build": { + "rootProject": true, + "generateServiceProject": true, + "optInAnnotations": [ + "aws.smithy.kotlin.runtime.InternalApi", + "kotlinx.serialization.ExperimentalSerializationApi" + ] + }, + "serviceStub": { + "framework": "ktor" + } + } + } +} +``` + +**Notes:** +- The most important fields are: + - **`outputDirectory`** — defines where the generated service code will be placed in your new project. + - **`service`** — must match your Smithy model’s `#`. + - **`serviceStub.framework`** — defines the server framework for generated code. Currently only `"ktor"` is supported. + +### Step 5: Define Your Smithy Model in the New Project + +Create a `model` directory and add your `.smithy` files. +Example `model/greeter.smithy`: + +```smithy +$version: "2.0" +namespace com.demo + +use aws.protocols#restJson1 +use smithy.api#httpBearerAuth + +@restJson1 +@httpBearerAuth +service DemoService { + version: "1.0.0" + operations: [ + SayHello + ] +} + +@http(method: "POST", uri: "/greet", code: 201) +operation SayHello { + input: SayHelloInput + output: SayHelloOutput + errors: [ + CustomError + ] +} + +@input +structure SayHelloInput { + @required + @length(min: 3, max: 10) + name: String + @httpHeader("X-User-ID") + id: Integer +} + +@output +structure SayHelloOutput { + greeting: String +} + +@error("server") +@httpError(500) +structure CustomError { + msg: String + @httpHeader("X-User-error") + err: String +} +``` + +### Step 6: Generate the Service in the New Project + +Run: +```bash + gradle build +``` + +⚠️ Running gradle build will delete the previous build output before creating a new one. + +If you want to prevent accidentally losing previous build, use the provided scripts instead: + +You can find script for Linux / macOS [here](../../../../../../../../../../../../examples/service-codegen/build.sh): +```bash + chmod +x build.sh + ./build.sh +``` + +You can find script for Windows [here](../../../../../../../../../../../../examples/service-codegen/build.bat): +```bash + icacls build.bat /grant %USERNAME%:RX + .\build.bat +``` + +If you want to clean previously generated code: +```bash + gradle clean +``` + +### Step 7: Run the Generated Service + +The generated service will be in the directory specified in `smithy-build.json` (`outputDirectory`). +You can start it by running: +```bash + gradle run +``` +By default, it listens on port **8080**. + +### Step 8: Adjust Service Configuration + +You can override runtime settings (such as port or HTTP engine) using command-line arguments: +```bash + gradle run --args="port 8000 engineFactory cio" +``` +You can find all available settings [here](https://github.com/smithy-lang/smithy-kotlin/blob/16bd523e2ccd6177dcc662466107189b013a818d/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubGenerator.kt#L179C1-L186C38) + +--- + +## Notes +- **Business Logic**: Implement your own logic in the generated operation handler interfaces. +- **Configuration**: Adjust port, engine, auth, and other settings via `ServiceFrameworkConfig` or CLI args. diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Authentication.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Authentication.kt new file mode 100644 index 0000000000..5cac41a4b1 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Authentication.kt @@ -0,0 +1,68 @@ +package software.amazon.smithy.kotlin.codegen.service.ktor + +import software.amazon.smithy.aws.traits.auth.SigV4ATrait +import software.amazon.smithy.aws.traits.auth.SigV4Trait +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.model.getTrait +import software.amazon.smithy.kotlin.codegen.service.ServiceTypes + +/** + * Writes Ktor-based authentication support classes and configuration + * for a generated service. + * + * This generates three files: + * 1. UserPrincipal.kt → Represents the authenticated user. + * 2. Validation.kt → Provides bearer token validation logic. + * 3. Authentication.kt → Configures authentication providers in Ktor. + */ +internal fun KtorStubGenerator.writeAuthentication() { + delegator.useFileWriter("UserPrincipal.kt", "$pkgName.auth") { writer -> + writer.withBlock("public data class UserPrincipal(", ")") { + write("val user: String") + } + } + + delegator.useFileWriter("Validation.kt", "$pkgName.auth") { writer -> + + writer.withBlock("internal object BearerValidation {", "}") { + withBlock("public fun bearerValidation(token: String): UserPrincipal? {", "}") { + write("// TODO: implement me:") + write("// Validate the provided bearer token and return a UserPrincipal if valid.") + write("// Return a UserPrincipal with user information (e.g., user id, roles) if valid,") + write("// or return null if the token is invalid or expired.") + write("if (true) return UserPrincipal(#S) else return null", "Authenticated User") + } + } + } + + delegator.useFileWriter("Authentication.kt", "$pkgName.auth") { writer -> + writer.withBlock("internal fun #T.configureAuthentication() {", "}", RuntimeTypes.KtorServerCore.Application) { + write("") + withBlock( + "#T(#T) {", + "}", + RuntimeTypes.KtorServerCore.install, + RuntimeTypes.KtorServerAuth.Authentication, + ) { + withBlock("#T(#S) {", "}", RuntimeTypes.KtorServerAuth.bearer, "auth-bearer") { + write("realm = #S", "Access to API") + write("authenticate { cred -> BearerValidation.bearerValidation(cred.token) }") + } + withBlock("sigV4(name = #S) {", "}", "aws-sigv4") { + write("region = #T.region", ServiceTypes(pkgName).serviceFrameworkConfig) + serviceShape.getTrait()?.let { + write("service = #S", it.name) + } + } + withBlock("sigV4A(name = #S) {", "}", "aws-sigv4a") { + write("region = #T.region", ServiceTypes(pkgName).serviceFrameworkConfig) + serviceShape.getTrait()?.let { + write("service = #S", it.name) + } + } + write("provider(#S) { authenticate { ctx -> ctx.principal(Unit) } }", "no-auth") + } + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/AuthenticationAWS.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/AuthenticationAWS.kt new file mode 100644 index 0000000000..e174f19c6b --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/AuthenticationAWS.kt @@ -0,0 +1,493 @@ +package software.amazon.smithy.kotlin.codegen.service.ktor + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.closeAndOpenBlock +import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.core.withInlineBlock +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes + +/** + * Writes AWS-specific authentication support for Ktor service stubs. + * + * This generates the following files: + * - AWSValidation.kt → Stub stores for SigV4 credentials and SigV4A public keys. + * - AWSSigV4.kt → Ktor authentication provider and verifier for AWS Signature V4 (HMAC). + * - AWSSigV4A.kt → Ktor authentication provider and verifier for AWS Signature V4A (ECDSA). + */ +internal fun KtorStubGenerator.writeAWSAuthentication() { + delegator.useFileWriter("AWSValidation.kt", "$pkgName.auth") { writer -> + writer.withBlock("internal object SigV4CredentialStore {", "}") { + write("private val table: Map = mapOf()", RuntimeTypes.Auth.Credentials.AwsCredentials.Credentials) + withBlock("internal fun get(accessKeyId: String): #T? {", "}", RuntimeTypes.Auth.Credentials.AwsCredentials.Credentials) { + write("// TODO: implement me:") + write("// Look up the credentials associated with this accessKeyId.") + write("// Return a Credentials object like this:") + write("// Credentials(") + write("// accessKeyId = #S,", "") + write("// secretAccessKey = #S,", "") + write("// sessionToken = #S,", "") + write("// )") + write("return table[accessKeyId]") + } + } + writer.write("") + writer.withBlock("internal object SigV4aPublicKeyStore {", "}") { + write("private val table: Map = mapOf()") + write("") + withBlock("internal fun get(accessKeyId: String): java.security.PublicKey? {", "}") { + write("// TODO: implement me:") + write("// Look up the public key associated with this accessKeyId.") + write("// Example if loading from bytes:") + write("// val spec = X509EncodedKeySpec(keyBytes)") + write("// val kf = KeyFactory.getInstance(\"EC\")") + write("// return kf.generatePublic(spec)") + write("// Return the java.security.PublicKey that should be used to verify SigV4A ECDSA signatures.") + write("return table[accessKeyId]") + } + } + } + + delegator.useFileWriter("AWSSigV4.kt", "$pkgName.auth") { writer -> + writer.withInlineBlock("internal fun #T.sigV4(", ")", RuntimeTypes.KtorServerAuth.AuthenticationConfig) { + write("name: String = #S,", "aws-sigv4") + write("configure: SigV4AuthProvider.Configuration.() -> Unit = {}") + } + .withBlock("{", "}") { + write("val provider = SigV4AuthProvider(SigV4AuthProvider.Configuration(name).apply(configure))") + write("register(provider)") + } + .write("") + + writer.withBlock("internal class SigV4AuthProvider(config: Configuration) : #T(config) {", "}", RuntimeTypes.KtorServerAuth.AuthenticationProvider) { + withBlock("internal class Configuration(name: String?) : #T.Config(name) {", "}", RuntimeTypes.KtorServerAuth.AuthenticationProvider) { + write("var region: String = #S", "us-east-1") + write("var service: String = #S", "execute-api") + write("var clockSkew: #T = 5.#T", KotlinTypes.Time.Duration, KotlinTypes.Time.minutes) + } + write("") + write("private val region = (config as Configuration).region") + write("private val service = config.service") + write("private val skew = config.clockSkew") + write("") + withBlock("override suspend fun onAuthenticate(context: #T) {", "}", RuntimeTypes.KtorServerAuth.AuthenticationContext) { + write("val creds = verifySigV4(context.call, region, service, skew)") + withInlineBlock("if (creds == null) {", "}") { + withBlock("context.challenge(#S, #T.InvalidCredentials) { challenge, call ->", "}", "AWS4-HMAC-SHA256", RuntimeTypes.KtorServerAuth.AuthenticationFailedCause) { + write("call.#T(#T.Unauthorized, #S)", RuntimeTypes.KtorServerRouting.responseResponse, RuntimeTypes.KtorServerHttp.HttpStatusCode, "Unauthorized") + write("challenge.complete()") + } + } + withBlock(" else {", "}") { + write("context.principal(UserPrincipal(creds.accessKeyId))") + } + } + } + .write("") + + writer.withInlineBlock("public suspend fun verifySigV4(", ")") { + write("call: #T,", RuntimeTypes.KtorServerCore.ApplicationCallClass) + write("region: String,") + write("service: String,") + write("maxClockSkew: #T", KotlinTypes.Time.Duration) + } + .withBlock(": #T? {", "}", RuntimeTypes.Auth.Credentials.AwsCredentials.Credentials) { + retrieveAuthInformation(writer, "AWS4-HMAC-SHA256") + write("") + write("val scope = credential.substringAfter(#S, missingDelimiterValue = #S)", "/", "") + write("val parts = scope.split(#S)", "/") + write("if (parts.size != 4) return null") + write("val (yyyyMMdd, scopeRegion, scopeService, term) = parts") + write("if (scopeRegion != region || scopeService != service || term != #S) return null", "aws4_request") + write("if (!Regex(#S).matches(yyyyMMdd)) return null", "^\\d{8}$") + write("") + authDateValidation(writer) + write("") + write("val creds = SigV4CredentialStore.get(accessKeyId) ?: return null") + write("") + write("val secTokenHeaderName = #S", "x-amz-security-token") + write("val secToken = call.request.headers[secTokenHeaderName]") + withBlock("if (creds.sessionToken != null) {", "}") { + write("if (secToken == null || secToken != creds.sessionToken) return null") + write("if (secTokenHeaderName !in signedHeaders) return null") + } + write("") + write("val contentSha256 = call.request.headers[#S]", "x-amz-content-sha256") + write("val isUnsigned = contentSha256 == #S", "UNSIGNED-PAYLOAD") + write("") + createHttpRequestBuilder(writer) + write("") + validateSigV4(writer) + write("") + write("return if (expectedSig == signatureHex) creds else null") + } + } + + delegator.useFileWriter("AWSSigV4A.kt", "$pkgName.auth") { writer -> + writer.withInlineBlock("internal fun #T.sigV4A(", ")", RuntimeTypes.KtorServerAuth.AuthenticationConfig) { + write("name: String = #S,", "aws-sigv4a") + write("configure: SigV4AAuthProvider.Configuration.() -> Unit = {}") + } + .withBlock("{", "}") { + write("val provider = SigV4AAuthProvider(SigV4AAuthProvider.Configuration(name).apply(configure))") + write("register(provider)") + } + .write("") + + writer.withBlock("internal class SigV4AAuthProvider(config: Configuration) : #T(config) {", "}", RuntimeTypes.KtorServerAuth.AuthenticationProvider) { + withBlock("internal class Configuration(name: String?) : #T.Config(name) {", "}", RuntimeTypes.KtorServerAuth.AuthenticationProvider) { + write("var region: String = #S", "us-east-1") + write("var service: String = #S", "execute-api") + write("var clockSkew: #T = 5.#T", KotlinTypes.Time.Duration, KotlinTypes.Time.minutes) + } + write("") + write("private val region = (config as Configuration).region") + write("private val service = config.service") + write("private val skew = config.clockSkew") + write("") + withBlock("override suspend fun onAuthenticate(context: #T) {", "}", RuntimeTypes.KtorServerAuth.AuthenticationContext) { + write("val creds = verifySigV4A(context.call, region, service, skew)") + withInlineBlock("if (creds == null) {", "}") { + withBlock("context.challenge(#S, #T.InvalidCredentials) { challenge, call ->", "}", "AWS4-HMAC-SHA256", RuntimeTypes.KtorServerAuth.AuthenticationFailedCause) { + write("call.#T(#T.Unauthorized, #S)", RuntimeTypes.KtorServerRouting.responseResponse, RuntimeTypes.KtorServerHttp.HttpStatusCode, "Unauthorized") + write("challenge.complete()") + } + } + withBlock(" else {", "}") { + write("context.principal(UserPrincipal(creds.accessKeyId))") + } + } + } + .write("") + + writer.withInlineBlock("public suspend fun verifySigV4A(", ")") { + write("call: #T,", RuntimeTypes.KtorServerCore.ApplicationCallClass) + write("region: String,") + write("service: String,") + write("maxClockSkew: #T", KotlinTypes.Time.Duration) + } + .withBlock(": #T? {", "}", RuntimeTypes.Auth.Credentials.AwsCredentials.Credentials) { + retrieveAuthInformation(writer, "AWS4-ECDSA-P256-SHA256") + write("") + write("val scope = credential.substringAfter(#S, missingDelimiterValue = #S)", "/", "") + write("val parts = scope.split(#S)", "/") + write("if (parts.size != 3) return null") + write("val (yyyyMMdd, scopeService, term) = parts") + write("if (scopeService != service || term != #S) return null", "aws4_request") + write("if (!Regex(#S).matches(yyyyMMdd)) return null", "^\\d{8}$") + write("") + write("val regionSetHeaderName = #S", "x-amz-region-set") + write("val rawRegionSet = call.request.headers[regionSetHeaderName] ?: return null") + write("if (regionSetHeaderName !in signedHeaders) return null") + write("") + write("val regionSet: List = rawRegionSet.split(',').map { it.trim().lowercase() }.filter { it.isNotEmpty() }.ifEmpty { return null }") + write("") + withBlock("fun matchesRegion(pattern: String, value: String): Boolean {", "}") { + write("if (pattern == #S) return true", "*") + write("val normalized = pattern.trim().replace(Regex(#S), #S)", "\\*+", "*") + write("val sb = StringBuilder(#S)", "^") + write("val parts = normalized.split(#S)", "*") + withBlock("parts.forEachIndexed { i, part ->", "}") { + write("sb.append(Regex.escape(part))") + write("if (i < parts.lastIndex) sb.append(#S)", "[^-]+") + } + write("sb.append(#S)", "$") + write("return Regex(sb.toString(), RegexOption.IGNORE_CASE).matches(value.trim())") + } + write("if (regionSet.none { matchesRegion(it, region.lowercase()) }) return null") + write("") + authDateValidation(writer) + write("") + write("val creds = SigV4CredentialStore.get(accessKeyId) ?: return null") + write("") + write("val secTokenHeaderName = #S", "x-amz-security-token") + write("val secToken = call.request.headers[secTokenHeaderName]") + withBlock("if (creds.sessionToken != null) {", "}") { + write("if (secToken == null || secToken != creds.sessionToken) return null") + write("if (secTokenHeaderName !in signedHeaders) return null") + } + write("") + write("val contentSha256 = call.request.headers[#S]", "x-amz-content-sha256") + write("val isUnsigned = contentSha256 == #S", "UNSIGNED-PAYLOAD") + write("") + createCanonicalRequest(writer) + write("") + validateSigV4A(writer) + write("") + write("return if (ok) creds else null") + } + .write("") + renderHelperFunctions(writer) + } +} + +/** + * Extracts and parses AWS authentication header information + * (Credential, SignedHeaders, Signature) from a request. + */ +private fun retrieveAuthInformation(writer: KotlinWriter, algorithm: String) { + writer.write("val authHeader = call.request.#T(#T.Authorization) ?: return null", RuntimeTypes.KtorServerRouting.requestHeader, RuntimeTypes.KtorServerHttp.HttpHeaders) + .write("if (!authHeader.startsWith(#S, ignoreCase = true)) return null", algorithm) + .write("") + .write("fun part(name: String) = authHeader.substringAfter(#S).substringBefore(#S).trim()", "\$name=", ",") + .write("") + .write("val credential = part(#S) // accessKeyId/scope", "Credential") + .write("val signedHeadersStr = part(#S)", "SignedHeaders") + .write("val signatureHex = part(#S)", "Signature") + .write("") + .write("val signedHeaders: Set = signedHeadersStr.split(';').map { it.trim().lowercase() }.toSet()") + .write("if (#S !in signedHeaders) return null", "host") + .write("if (!signedHeaders.any { it == #S || it == #S }) return null", "x-amz-date", "date") + .write("val accessKeyId = credential.substringBefore(#S).takeIf { it.matches(Regex(#S)) } ?: return null", "/", "^[A-Z0-9]{16,128}$") +} + +/** + * Validates signing date against request scope date and clock skew. + */ +private fun authDateValidation(writer: KotlinWriter) { + writer.write("val rawXAmzDate = call.request.#T(#S)", RuntimeTypes.KtorServerRouting.requestHeader, "X-Amz-Date") + .write("val rawHttpDate = call.request.#T(#T.Date)", RuntimeTypes.KtorServerRouting.requestHeader, RuntimeTypes.KtorServerHttp.HttpHeaders) + .withBlock("val signingInstant: #T = when {", "}", RuntimeTypes.Core.Instant) { + write("rawXAmzDate != null -> { try { #T.fromIso8601(rawXAmzDate) } catch (_: Exception) { return null } }", RuntimeTypes.Core.Instant) + write("rawHttpDate != null -> { try { #T.fromRfc5322(rawHttpDate) } catch (_: Exception) { return null } }", RuntimeTypes.Core.Instant) + write("else -> return null") + } + .write("val scopeDate = signingInstant.format(#T.ISO_8601_CONDENSED_DATE)", RuntimeTypes.Core.TimestampFormat) + .write("if (scopeDate != yyyyMMdd) return null") + .write("") + .write("val now = #T.now()", RuntimeTypes.Core.Instant) + .write("if (signingInstant < now - maxClockSkew || signingInstant > now + maxClockSkew) return null") +} + +/** + * Builds a full HttpRequestBuilder object from the Ktor request, + * used for SigV4 canonical request signing. + */ +private fun createHttpRequestBuilder(writer: KotlinWriter) { + writer.write("val origin = call.request.local") + .write("val payload: ByteArray = call.#T()", RuntimeTypes.KtorServerRouting.requestReceive) + .write("") + .withBlock("val requestBuilder: #T = #T().apply {", "}", RuntimeTypes.Http.Request.HttpRequestBuilder, RuntimeTypes.Http.Request.HttpRequestBuilder) { + write("method = #T.parse(call.request.#T.value)", RuntimeTypes.Http.HttpMethod, RuntimeTypes.KtorServerRouting.requestHttpMethod) + write("") + write("val protoHeader = call.request.headers[#S] ?: origin.scheme", "X-Forwarded-Proto") + write("val isHttps = (protoHeader.equals(#S, ignoreCase = true))", "https") + write("val hostHeader = call.request.headers[#S] ?: call.request.headers[#S] ?: return null", "X-Forwarded-Host", "Host") + write("val hostOnly: String") + write("val portValue: Int?") + withBlock("hostHeader.split(':', limit = 2).let {", "}") { + write("hostOnly = it[0]") + write("portValue = it.getOrNull(1)?.toIntOrNull()") + } + withBlock("#T {", "}", RuntimeTypes.Http.Request.url) { + write("scheme = if (isHttps) #T.HTTPS else #T.HTTP", RuntimeTypes.Core.Net.Scheme, RuntimeTypes.Core.Net.Scheme) + write("host = #T.parse(hostOnly)", RuntimeTypes.Core.Net.Host) + write("if (portValue != null) port = portValue") + withBlock("path {", "}") { + write("decoded = call.request.#T()", RuntimeTypes.KtorServerRouting.requestPath) + } + withBlock("parameters {", "}") { + withBlock("decodedParameters {", "}") { + write("call.request.queryParameters.forEach { key, values -> values.forEach { v -> add(key, v) } }") + } + } + } + + write("") + + withBlock("for (name in call.request.headers.names()) {", "}") { + write("val lowerName = name.lowercase()") + withBlock("if (lowerName != #T.Authorization.lowercase() && lowerName in signedHeaders) {", "}", RuntimeTypes.KtorServerHttp.HttpHeaders) { + write("call.request.headers.getAll(name)?.forEach { value -> headers.append(name, value) }") + } + } + + write("body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + } +} + +/** + * Builds a canonical request string for SigV4A verification. + */ +private fun createCanonicalRequest(writer: KotlinWriter) { + writer.write("val origin = call.request.local") + .write("val payload: ByteArray = call.#T()", RuntimeTypes.KtorServerRouting.requestReceive) + .write("") + .write("val protoHeader = call.request.headers[#S] ?: origin.scheme", "X-Forwarded-Proto") + .write("val isHttps = (protoHeader.equals(#S, ignoreCase = true))", "https") + .write("val hostHeader = call.request.headers[#S] ?: call.request.headers[#S] ?: return null", "X-Forwarded-Host", "Host") + .write("val hostOnly: String") + .write("val portValue: Int?") + .withBlock("hostHeader.split(':', limit = 2).let {", "}") { + write("hostOnly = it[0]") + write("portValue = it.getOrNull(1)?.toIntOrNull()") + } + .write("") + .write("val canonicalUri = encodeCanonicalPath(call.request.#T())", RuntimeTypes.KtorServerRouting.requestPath) + .write("val canonicalQuery = buildCanonicalQuery(#T.build { call.request.queryParameters.forEach { k, vs -> vs.forEach { v -> append(k, v) } } })", RuntimeTypes.KtorServerHttp.Parameters) + .write("") + .withBlock("val filteredHeaders = #T().apply {", "}.build()", RuntimeTypes.KtorServerHttp.HeadersBuilder) { + withBlock("for (name in call.request.headers.names()) {", "}") { + write("val ln = name.lowercase()") + withBlock("if (ln != HttpHeaders.Authorization.lowercase() && ln in signedHeaders) {", "}") { + write("call.request.headers.getAll(name)?.forEach { v -> append(name, v) }") + } + } + withBlock("if (!names().any { it.equals(#S, ignoreCase = true) }) {", "}", "X-Amz-Region-Set") { + write("append(#S, rawRegionSet)", "X-Amz-Region-Set") + } + withBlock("if (!names().any { it.equals(#S, ignoreCase = true) }) {", "}", "Host") { + write("val defaultPort = (isHttps && portValue == 443) || (!isHttps && portValue == 80)") + write("append(#S, if (portValue != null && !defaultPort) #S else hostOnly) ", "Host", "\$hostOnly:\$portValue") + } + } + .withBlock("val (canonicalHeaders, signedHeaderList) = run {", "}") { + write("val map = mutableMapOf>()") + withBlock("filteredHeaders.names().forEach { name ->", "}") { + write("val ln = name.lowercase()") + withBlock("if (ln in signedHeaders) {", "}") { + write("val values = filteredHeaders.getAll(name).orEmpty()") + write(" .map { it.trim().replace(Regex(#S), #S) }", "\\s+", " ") + write("map.getOrPut(ln) { mutableListOf() }.addAll(values)") + } + } + withBlock("if (#S !in map) {", "}", "x-amz-region-set") { + write("map[#S] = mutableListOf(rawRegionSet)", "x-amz-region-set") + } + withBlock("val canon = map.toSortedMap().entries.joinToString(#S, postfix = #S) { entry ->", "}", "\n", "\n") { + write("val key = entry.key") + write("val vs = entry.value") + write("\"\$key:\${vs.joinToString(#S)}\"", ",") + } + + write("val signedList = map.keys.sorted().joinToString(#S)", ";") + write("canon to signedList") + } + .write("val payloadHash = if (isUnsigned) #S else { sha256Hex(payload) }", "UNSIGNED-PAYLOAD") + .withBlock("val canonicalRequest = buildString {", "}") { + write("append(call.request.#T.value.uppercase()).append('\\n')", RuntimeTypes.KtorServerRouting.requestHttpMethod) + write("append(canonicalUri).append('\\n')") + write("append(canonicalQuery).append('\\n')") + write("append(canonicalHeaders)") + write("append('\\n')") // empty line before SignedHeaders list + write("append(signedHeaderList).append('\\n')") + write("append(payloadHash)") + } +} + +/** + * Performs AWS SigV4 signature validation against expected HMAC. + */ +private fun validateSigV4(writer: KotlinWriter) { + writer.withBlock("val signer = #T(", ")", RuntimeTypes.Auth.HttpAuthAws.AwsHttpSigner) { + withBlock("#T.Config().apply {", "}", RuntimeTypes.Auth.HttpAuthAws.AwsHttpSigner) { + write("this.signer = #T", RuntimeTypes.Auth.Signing.AwsSigningStandard.DefaultAwsSigner) + write("this.service = service") + write("this.isUnsignedPayload = isUnsigned") + } + } + .withBlock("val attrs = #T {", "}", RuntimeTypes.Core.Collections.attributesOf) { + write("#T to creds", RuntimeTypes.Auth.Credentials.AwsCredentials.Credentials) + write("#T.SigningRegion to region", RuntimeTypes.Auth.Signing.AwsSigningCommon.AwsSigningAttributes) + write("#T.SigningDate to signingInstant", RuntimeTypes.Auth.Signing.AwsSigningCommon.AwsSigningAttributes) + withBlock("if (isUnsigned) {", "}") { + write( + "#T.HashSpecification to #T.UnsignedPayload", + RuntimeTypes.Auth.Signing.AwsSigningCommon.AwsSigningAttributes, + RuntimeTypes.Auth.Signing.AwsSigningCommon.HashSpecification, + ) + write( + "#T.SignedBodyHeader to #T.NONE", + RuntimeTypes.Auth.Signing.AwsSigningCommon.AwsSigningAttributes, + RuntimeTypes.Auth.Signing.AwsSigningCommon.AwsSignedBodyHeader, + ) + } + } + .write( + "signer.sign(#T(requestBuilder, creds, attrs))", + RuntimeTypes.Auth.HttpAuthAws.SignHttpRequest, + ) + .write( + "val expectedAuth = requestBuilder.headers.getAll(#T.Authorization)?.firstOrNull() ?: return null", + RuntimeTypes.KtorServerHttp.HttpHeaders, + ) + .write("val expectedSig = expectedAuth.substringAfter(#S).trim()", "Signature=") +} + +/** + * Performs AWS SigV4A signature validation against ECDSA public key. + */ +private fun validateSigV4A(writer: KotlinWriter) { + writer.write("val crHashHex = sha256Hex(canonicalRequest.toByteArray())") + .withBlock("val stringToSign = buildString {", "}") { + write("append(#S).append('\\n')", "AWS4-ECDSA-P256-SHA256") + write("append(signingInstant.format(#T.ISO_8601_CONDENSED)).append('\\n')", RuntimeTypes.Core.TimestampFormat) + write("append(#S).append('\\n')", "\$yyyyMMdd/\$service/aws4_request") + write("append(crHashHex)") + } + .write("val publicKey: java.security.PublicKey = SigV4aPublicKeyStore.get(accessKeyId) ?: return null") + .write("val sigDer = signatureHex.chunked(2).map { it.toInt(16).toByte() }.toByteArray()") + .write("val verifier = java.security.Signature.getInstance(#S)", "SHA256withECDSA") + .write("verifier.initVerify(publicKey)") + .write("verifier.update(stringToSign.toByteArray())") + .write("val ok = verifier.verify(sigDer)") +} + +/** + * Writes common helper functions used by SigV4 and SigV4A verification, + * such as hashing, encoding, and canonical path/query building. + */ +private fun renderHelperFunctions(writer: KotlinWriter) { + writer.withBlock("private fun sha256Hex(bytes: ByteArray): String {", "}") { + write("return java.security.MessageDigest.getInstance(#S).digest(bytes).joinToString(#S) { #S.format(it) }", "SHA-256", "", "%02x") + } + .write("") + + writer.withBlock("private val UNRESERVED: BooleanArray = BooleanArray(128).apply {", "}") { + write("for (c in 'A'..'Z') this[c.code] = true") + write("for (c in 'a'..'z') this[c.code] = true") + write("for (c in '0'..'9') this[c.code] = true") + write("this['-'.code] = true; this['_'.code] = true; this['.'.code] = true; this['~'.code] = true\n") + } + .write("") + + writer.withBlock("private fun rfc3986Encode(bytes: ByteArray): String {", "}") { + write("val out = StringBuilder(bytes.size * 3)") + withBlock("for (b in bytes) {", "}") { + write("val i = b.toInt() and 0xFF") + withInlineBlock("if (i < 128 && UNRESERVED[i]) {", "}") { + write("out.append(i.toChar())") + } + withBlock(" else {", "}") { + write("out.append('%')") + write("val hi = #S[(i ushr 4) and 0xF]", "0123456789ABCDEF") + write("val lo = #S[i and 0xF]", "0123456789ABCDEF") + write("out.append(hi).append(lo)") + } + } + write("return out.toString()") + } + .write("") + writer.write("private fun encodeString(s: String): String = rfc3986Encode(s.toByteArray(Charsets.UTF_8))") + .write("") + + writer.withBlock("private fun encodeCanonicalPath(rawPath: String): String {", "}") { + write("val p = if (rawPath.isEmpty()) #S else rawPath", "/") + write("return p.split('/').joinToString(#S) { seg -> if (seg.isEmpty()) #S else encodeString(seg) }", "/", "") + } + .write("") + + writer.withBlock("private fun buildCanonicalQuery(params: #T): String {", "}", RuntimeTypes.KtorServerHttp.Parameters) { + withBlock("val pairs = buildList {", "}") { + withBlock("params.names().sorted().forEach { name ->", "}") { + write("val values = params.getAll(name)") + openBlock("if (values == null || values.isEmpty()) {") + write("add(#S)", "\${encodeString(name)}=") + closeAndOpenBlock("} else {") + write("values.sorted().forEach { v -> add(#S) }", "\${encodeString(name)}=\${encodeString(v)}") + closeBlock("}") + } + } + write("return pairs.joinToString(#S)", "&") + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/KtorStubGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/KtorStubGenerator.kt new file mode 100644 index 0000000000..e578020e81 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/KtorStubGenerator.kt @@ -0,0 +1,67 @@ +package software.amazon.smithy.kotlin.codegen.service.ktor + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.kotlin.codegen.core.GenerationContext +import software.amazon.smithy.kotlin.codegen.core.InlineCodeWriterFormatter +import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.service.AbstractStubGenerator +import software.amazon.smithy.kotlin.codegen.service.constraints.ConstraintGenerator +import software.amazon.smithy.kotlin.codegen.service.constraints.ConstraintUtilsGenerator +import software.amazon.smithy.utils.AbstractCodeWriter + +class LoggingWriter(parent: LoggingWriter? = null) : AbstractCodeWriter() { + init { + trimBlankLines(parent?.trimBlankLines ?: 1) + trimTrailingSpaces(parent?.trimTrailingSpaces ?: true) + indentText = parent?.indentText ?: " " + expressionStart = parent?.expressionStart ?: '#' + putFormatter('W', InlineCodeWriterFormatter(::LoggingWriter)) + } +} + +/** + * Stub generator for Ktor-based services. + * + * Implements [AbstractStubGenerator] for the Ktor runtime, generating: + * - Framework implementation + * - Utilities + * - Authentication modules + * - Constraint validators + * - Routing tables + * - Plugins + * - Operation handlers + */ +internal class KtorStubGenerator( + ctx: GenerationContext, + delegator: KotlinDelegator, + fileManifest: FileManifest, +) : AbstractStubGenerator(ctx, delegator, fileManifest) { + + /** Generate the Ktor server framework implementation. */ + override fun renderServerFrameworkImplementation(writer: KotlinWriter) = writeServerFrameworkImplementation(writer) + + /** Generate utility classes and helpers. */ + override fun renderUtils() = writeUtils() + + /** Generate authentication modules (AWS auth + bearer/no-auth). */ + override fun renderAuthModule() { + writeAWSAuthentication() + writeAuthentication() + } + + /** Generate request constraint validators for all operations. */ + override fun renderConstraintValidators() { + ConstraintUtilsGenerator(ctx, delegator).render() + operations.forEach { operation -> ConstraintGenerator(ctx, operation, delegator).render() } + } + + /** Generate routing file mapping Smithy operations to Ktor routes. */ + override fun renderRouting() = writeRouting() + + /** Generate plugin configurations (e.g., error handlers, guards). */ + override fun renderPlugins() = writePlugins() + + /** Generate stub handler files for each Smithy operation. */ + override fun renderPerOperationHandlers() = writePerOperationHandlers() +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/OperationHandlers.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/OperationHandlers.kt new file mode 100644 index 0000000000..b9be8a925b --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/OperationHandlers.kt @@ -0,0 +1,51 @@ +package software.amazon.smithy.kotlin.codegen.service.ktor + +import software.amazon.smithy.kotlin.codegen.core.withBlock + +/** + * Generates stub handler files for each Smithy operation. + * + * For every operation, this creates an `OperationNameOperation.kt` file + * under the `operations` package containing: + * - A `handleRequest()` function + * - TODO implementation guidance for constructing response objects + * - Documentation of available custom errors + * + * Each generated handler accepts the operation's input type and returns + * the operation's output type. + */ +internal fun KtorStubGenerator.writePerOperationHandlers() { + operations.forEach { shape -> + val inputShape = ctx.model.expectShape(shape.input.get()) + val inputSymbol = ctx.symbolProvider.toSymbol(inputShape) + + val outputShape = ctx.model.expectShape(shape.output.get()) + val outputSymbol = ctx.symbolProvider.toSymbol(outputShape) + + val name = shape.id.name + + delegator.useFileWriter("${name}Operation.kt", "$pkgName.operations") { writer -> + + writer.withBlock("public fun handle${name}Request(req: #T): #T {", "}", inputSymbol, outputSymbol) { + write("// TODO: implement me") + write("// To build a #T object:", outputSymbol) + write("// 1. Use`#T.Builder()`", outputSymbol) + write("// 2. Set fields like `#T.variable = ...`", outputSymbol) + write("// 3. Return the built object using `return #T.build()`", outputSymbol) + write("//") + val errorSymbolNames: List = shape.errors.map { errorShapeId -> + val errorShape = ctx.model.expectShape(errorShapeId) + ctx.symbolProvider.toSymbol(errorShape).name + } + write("// You may also throw custom errors if needed.") + write("// Custom errors can be created using the same builder pattern.") + if (errorSymbolNames.isNotEmpty()) { + write("// Available errors : ${errorSymbolNames.joinToString(", ")}") + } else { + write("// There are no available errors for this operation.") + } + write("return #T.Builder().build()", outputSymbol) + } + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Plugins.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Plugins.kt new file mode 100644 index 0000000000..d39b519ee0 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Plugins.kt @@ -0,0 +1,293 @@ +package software.amazon.smithy.kotlin.codegen.service.ktor + +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.core.withInlineBlock +import software.amazon.smithy.kotlin.codegen.service.ServiceTypes + +/** + * Entry point for writing plugin modules (error handling, content-type guard, accept-type guard). + * + * Generates three files under the `plugins` package: + * - `ErrorHandler.kt`: common error envelope and exception-to-response mapping + * - `ContentTypeGuard.kt`: validates request `Content-Type` header + * - `AcceptTypeGuard.kt`: validates request `Accept` header + */ +internal fun KtorStubGenerator.writePlugins() { + renderErrorHandler() + renderContentTypeGuard() + renderAcceptTypeGuard() +} + +/** + * Generates `ErrorHandler.kt`, which contains: + * - `ErrorEnvelope` exception wrapper for standard error responses + * - JSON and CBOR serialization for error payloads + * - Extension to respond with an `ErrorEnvelope` + * - `configureErrorHandling()` that installs a `StatusPages` plugin + * mapping HTTP status codes and exceptions → structured error responses + */ +private fun KtorStubGenerator.renderErrorHandler() { + delegator.useFileWriter("ErrorHandler.kt", "$pkgName.plugins") { writer -> + writer.write("internal val ResponseHandledKey = #T(#S)", RuntimeTypes.KtorServerUtils.AttributeKey, "ResponseHandled") + .write("") + writer.write("@#T", RuntimeTypes.KotlinxCborSerde.Serializable) + .write("private data class ErrorPayload(val code: Int, val message: String)") + .write("") + .withInlineBlock("internal class ErrorEnvelope(", ")") { + write("val code: Int,") + write("val msg: String,") + write("cause: Throwable? = null,") + } + .withBlock(" : RuntimeException(msg, cause) {", "}") { + withBlock("fun toJson(json: #T = #T): String {", "}", RuntimeTypes.KotlinxJsonSerde.Json, RuntimeTypes.KotlinxJsonSerde.Json) { + withInlineBlock("return json.encodeToString(", ")") { + write("ErrorPayload(code, message ?: #S)", "Unknown error") + } + } + withBlock("fun toCbor(cbor: #T = #T { }): ByteArray {", "}", RuntimeTypes.KotlinxCborSerde.Cbor, RuntimeTypes.KotlinxCborSerde.Cbor) { + withInlineBlock("return cbor.#T(", ")", RuntimeTypes.KotlinxCborSerde.encodeToByteArray) { + write("ErrorPayload(code, message ?: #S)", "Unknown error") + } + } + } + .write("") + .withInlineBlock("private suspend fun #T.respondEnvelope(", ")", RuntimeTypes.KtorServerCore.ApplicationCallClass) { + write("envelope: ErrorEnvelope,") + write("status: #T,", RuntimeTypes.KtorServerHttp.HttpStatusCode) + } + .withBlock("{", "}") { + write("val acceptsCbor = request.#T().any { it.value == #S }", RuntimeTypes.KtorServerRouting.requestAcceptItems, "application/cbor") + write("val acceptsJson = request.#T().any { it.value == #S }", RuntimeTypes.KtorServerRouting.requestAcceptItems, "application/json") + write("") + write("val log = #T.getLogger(#S)", RuntimeTypes.KtorLoggingSlf4j.LoggerFactory, pkgName) + write("log.info(#S)", "Route Error Message: \${envelope.msg}") + write("") + withBlock("when {", "}") { + withBlock("acceptsCbor -> {", "}") { + withBlock("#T(", ")", RuntimeTypes.KtorServerRouting.responseRespondBytes) { + write("bytes = envelope.toCbor(),") + write("status = status,") + write("contentType = #T", RuntimeTypes.KtorServerHttp.Cbor) + } + } + withBlock("acceptsJson -> {", "}") { + withBlock("#T(", ")", RuntimeTypes.KtorServerRouting.responseResponseText) { + write("envelope.toJson(),") + write("status = status,") + write("contentType = #T", RuntimeTypes.KtorServerHttp.Json) + } + } + withBlock("else -> {", "}") { + withBlock("#T(", ")", RuntimeTypes.KtorServerRouting.responseResponseText) { + write("envelope.msg,") + write("status = status") + } + } + } + } + .write("") + .withBlock("internal fun #T.configureErrorHandling() {", "}", RuntimeTypes.KtorServerCore.Application) { + write("") + withBlock( + "#T(#T) {", + "}", + RuntimeTypes.KtorServerCore.install, + RuntimeTypes.KtorServerStatusPage.StatusPages, + ) { + withBlock("status(#T.Unauthorized) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) { + write("if (call.attributes.getOrNull(#T) == true) { return@status }", ServiceTypes(pkgName).responseHandledKey) + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) + write("val missing = call.request.headers[#S].isNullOrBlank()", "Authorization") + write("val message = if (missing) #S else #S", "Missing bearer token", "Invalid or expired authentication") + write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )") + } + write("") + withBlock("status(#T.NotFound) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) { + write("if (call.attributes.getOrNull(#T) == true) { return@status }", ServiceTypes(pkgName).responseHandledKey) + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) + write("val message = #S", "Resource not found") + write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )") + } + write("") + withBlock("status(#T.MethodNotAllowed) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) { + write("if (call.attributes.getOrNull(#T) == true) { return@status }", ServiceTypes(pkgName).responseHandledKey) + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) + write("val message = #S", "Method not allowed for this resource") + write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )") + } + write("") + withBlock("#T { call, cause ->", "}", RuntimeTypes.KtorServerStatusPage.exception) { + withBlock("val status = when (cause) {", "}") { + write( + "is ErrorEnvelope -> #T.fromValue(cause.code)", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + write( + "is #T -> #T.BadRequest", + RuntimeTypes.KtorServerCore.BadRequestException, + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + write( + "is #T -> #T.PayloadTooLarge", + RuntimeTypes.KtorServerBodyLimit.PayloadTooLargeException, + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + write("else -> #T.InternalServerError", RuntimeTypes.KtorServerHttp.HttpStatusCode) + } + write("") + + write("val envelope = if (cause is ErrorEnvelope) cause else ErrorEnvelope(status.value, cause.message ?: #S)", "Unexpected error") + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) + write("call.respondEnvelope( envelope, status )") + } + } + } + } +} + +/** + * Generates `ContentTypeGuard.kt`, which installs a route-scoped plugin that: + * - Defines a configurable allow-list of acceptable request `Content-Type`s + * - Rejects unsupported media types with an `ErrorEnvelope` + * - Provides convenience configuration (e.g., `json()`, `cbor()`, `binary()`) + */ +private fun KtorStubGenerator.renderContentTypeGuard() { + delegator.useFileWriter("ContentTypeGuard.kt", "$pkgName.plugins") { writer -> + + writer.withBlock("private fun #T.hasBody(): Boolean {", "}", RuntimeTypes.KtorServerRouting.requestApplicationRequest) { + write( + "return (#T()?.let { it > 0 } == true) || headers.contains(#T.TransferEncoding)", + RuntimeTypes.KtorServerRouting.requestContentLength, + RuntimeTypes.KtorServerHttp.HttpHeaders, + ) + } + writer.withBlock("public class ContentTypeGuardConfig {", "}") { + write("public var allow: List<#T> = emptyList()", RuntimeTypes.KtorServerHttp.ContentType) + write("") + withBlock("public fun any(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Any) + } + write("") + withBlock("public fun json(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Json) + } + write("") + withBlock("public fun cbor(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Cbor) + } + write("") + withBlock("public fun text(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.PlainText) + } + write("") + withBlock("public fun binary(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.OctetStream) + } + } + .write("") + + writer.withInlineBlock( + "public val ContentTypeGuard: #T = #T(", + ")", + RuntimeTypes.KtorServerCore.ApplicationRouteScopedPlugin, + RuntimeTypes.KtorServerCore.ApplicationCreateRouteScopedPlugin, + ) { + write("name = #S,", "ContentTypeGuard") + write("createConfiguration = ::ContentTypeGuardConfig,") + } + .withBlock("{", "}") { + write("val allowed: List<#T> = pluginConfig.allow", RuntimeTypes.KtorServerHttp.ContentType) + write("require(allowed.isNotEmpty()) { #S }", "ContentTypeGuard installed with empty allow-list.") + write("") + withBlock("onCall { call ->", "}") { + write("if (!call.request.hasBody()) return@onCall") + write("val incoming = call.request.#T()", RuntimeTypes.KtorServerRouting.requestContentType) + withBlock("if (incoming == #T.Any || allowed.none { incoming.match(it) }) {", "}", RuntimeTypes.KtorServerHttp.ContentType) { + withBlock("throw #T(", ")", ServiceTypes(pkgName).errorEnvelope) { + write("#T.UnsupportedMediaType.value, ", RuntimeTypes.KtorServerHttp.HttpStatusCode) + write("#S", "Not acceptable Content‑Type found: '\${incoming}'. Accepted content types: \${allowed.joinToString()}") + } + } + } + } + } +} + +/** + * Generates `AcceptTypeGuard.kt`, which installs a route-scoped plugin that: + * - Defines a configurable allow-list of acceptable `Accept` header values + * - Rejects unsupported response types with an `ErrorEnvelope` + * - Provides convenience configuration (e.g., `json()`, `cbor()`, `text()`) + */ +private fun KtorStubGenerator.renderAcceptTypeGuard() { + delegator.useFileWriter("AcceptTypeGuard.kt", "${ctx.settings.pkg.name}.plugins") { writer -> + + writer.withBlock( + "private fun #T.acceptedContentTypes(): List<#T> {", + "}", + RuntimeTypes.KtorServerRouting.requestApplicationRequest, + RuntimeTypes.KtorServerHttp.ContentType, + ) { + write("val raw = headers[#T.Accept] ?: return emptyList()", RuntimeTypes.KtorServerHttp.HttpHeaders) + write( + "return #T(raw).mapNotNull { it.value?.let(#T::parse) }", + RuntimeTypes.KtorServerHttp.parseAndSortHeader, + RuntimeTypes.KtorServerHttp.ContentType, + ) + } + + writer.withBlock("public class AcceptTypeGuardConfig {", "}") { + write("public var allow: List<#T> = emptyList()", RuntimeTypes.KtorServerHttp.ContentType) + write("") + withBlock("public fun any(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Any) + } + write("") + withBlock("public fun json(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Json) + } + write("") + withBlock("public fun cbor(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Cbor) + } + write("") + withBlock("public fun text(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.PlainText) + } + write("") + withBlock("public fun binary(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.OctetStream) + } + } + .write("") + + writer.withInlineBlock( + "public val AcceptTypeGuard: #T = #T(", + ")", + RuntimeTypes.KtorServerCore.ApplicationRouteScopedPlugin, + RuntimeTypes.KtorServerCore.ApplicationCreateRouteScopedPlugin, + ) { + write("name = #S,", "AcceptTypeGuard") + write("createConfiguration = ::AcceptTypeGuardConfig,") + } + .withBlock("{", "}") { + write("val allowed: List<#T> = pluginConfig.allow", RuntimeTypes.KtorServerHttp.ContentType) + write("require(allowed.isNotEmpty()) { #S }", "AcceptTypeGuard installed with empty allow-list.") + write("") + withBlock("onCall { call ->", "}") { + write("val accepted = call.request.acceptedContentTypes()") + write("if (accepted.isEmpty()) return@onCall") + write("") + write("val isOk = accepted.any { candidate -> allowed.any { candidate.match(it) } }") + + withBlock("if (!isOk) {", "}") { + withBlock("throw #T(", ")", ServiceTypes(pkgName).errorEnvelope) { + write("#T.NotAcceptable.value, ", RuntimeTypes.KtorServerHttp.HttpStatusCode) + write("#S", "Not acceptable Accept type found: '\${accepted}'. Accepted types: \${allowed.joinToString()}") + } + } + } + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Routing.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Routing.kt new file mode 100644 index 0000000000..2a09b38cf0 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Routing.kt @@ -0,0 +1,470 @@ +package software.amazon.smithy.kotlin.codegen.service.ktor + +import software.amazon.smithy.aws.traits.auth.SigV4ATrait +import software.amazon.smithy.aws.traits.auth.SigV4Trait +import software.amazon.smithy.codegen.core.SymbolReference +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.core.withInlineBlock +import software.amazon.smithy.kotlin.codegen.model.buildSymbol +import software.amazon.smithy.kotlin.codegen.model.getTrait +import software.amazon.smithy.kotlin.codegen.rendering.serde.deserializerName +import software.amazon.smithy.kotlin.codegen.rendering.serde.serializerName +import software.amazon.smithy.kotlin.codegen.service.MediaType +import software.amazon.smithy.kotlin.codegen.service.MediaType.ANY +import software.amazon.smithy.kotlin.codegen.service.MediaType.JSON +import software.amazon.smithy.kotlin.codegen.service.MediaType.OCTET_STREAM +import software.amazon.smithy.kotlin.codegen.service.MediaType.PLAIN_TEXT +import software.amazon.smithy.kotlin.codegen.service.ServiceTypes +import software.amazon.smithy.kotlin.codegen.service.renderCastingPrimitiveFromShapeType +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.traits.AuthTrait +import software.amazon.smithy.model.traits.HttpBearerAuthTrait +import software.amazon.smithy.model.traits.HttpErrorTrait +import software.amazon.smithy.model.traits.HttpHeaderTrait +import software.amazon.smithy.model.traits.HttpLabelTrait +import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait +import software.amazon.smithy.model.traits.HttpQueryParamsTrait +import software.amazon.smithy.model.traits.HttpQueryTrait +import software.amazon.smithy.model.traits.HttpTrait +import software.amazon.smithy.model.traits.MediaTypeTrait +import software.amazon.smithy.model.traits.TimestampFormatTrait + +/** + * Generates Ktor server routing for all operations in the service model. + * + * - Creates `Routing.kt` file. + * - Installs appropriate content-type and accept-type guards. + * - Handles request deserialization, validation, business logic invocation, + * response serialization, and error handling. + */ +internal fun KtorStubGenerator.writeRouting() { + delegator.useFileWriter("Routing.kt", pkgName) { writer -> + operations.forEach { shape -> + writer.addImport("$pkgName.constraints", "check${shape.id.name}RequestConstraint") + writer.addImport("$pkgName.operations", "handle${shape.id.name}Request") + } + + writer.withBlock("internal fun #T.configureRouting(): Unit {", "}", RuntimeTypes.KtorServerCore.Application) { + withBlock("#T {", "}", RuntimeTypes.KtorServerRouting.routing) { + withBlock("#T(#S) {", "}", RuntimeTypes.KtorServerRouting.get, "/") { + write(" #T.#T(#S)", RuntimeTypes.KtorServerCore.applicationCall, RuntimeTypes.KtorServerRouting.responseResponseText, "hello world") + } + operations.filter { it.hasTrait(HttpTrait.ID) } + .forEach { shape -> + val inputShape = ctx.model.expectShape(shape.input.get()) + val inputSymbol = ctx.symbolProvider.toSymbol(inputShape) + + val outputShape = ctx.model.expectShape(shape.output.get()) + val outputSymbol = ctx.symbolProvider.toSymbol(outputShape) + + val serializerSymbol = buildSymbol { + definitionFile = "${shape.serializerName()}.kt" + name = shape.serializerName() + namespace = ctx.settings.pkg.serde + reference(inputSymbol, SymbolReference.ContextOption.DECLARE) + } + val deserializerSymbol = buildSymbol { + definitionFile = "${shape.deserializerName()}.kt" + name = shape.deserializerName() + namespace = ctx.settings.pkg.serde + reference(outputSymbol, SymbolReference.ContextOption.DECLARE) + } + + val httpTrait = shape.getTrait()!! + + val uri = httpTrait.uri + val successCode = httpTrait.code + val method = when (httpTrait.method) { + "GET" -> RuntimeTypes.KtorServerRouting.get + "POST" -> RuntimeTypes.KtorServerRouting.post + "PUT" -> RuntimeTypes.KtorServerRouting.put + "PATCH" -> RuntimeTypes.KtorServerRouting.patch + "DELETE" -> RuntimeTypes.KtorServerRouting.delete + "HEAD" -> RuntimeTypes.KtorServerRouting.head + "OPTIONS" -> RuntimeTypes.KtorServerRouting.options + else -> error("Unsupported http trait ${httpTrait.method}") + } + val contentType = MediaType.fromServiceShape(ctx, serviceShape, shape.input.get()) + val contentTypeGuard = when (contentType) { + MediaType.CBOR -> "cbor()" + JSON -> "json()" + PLAIN_TEXT -> "text()" + OCTET_STREAM -> "binary()" + ANY -> "any()" + } + + val acceptType = MediaType.fromServiceShape(ctx, serviceShape, shape.output.get()) + val acceptTypeGuard = when (acceptType) { + MediaType.CBOR -> "cbor()" + JSON -> "json()" + PLAIN_TEXT -> "text()" + OCTET_STREAM -> "binary()" + ANY -> "any()" + } + + withBlock("#T (#S) {", "}", RuntimeTypes.KtorServerRouting.route, uri) { + write("#T(#T) { $contentTypeGuard }", RuntimeTypes.KtorServerCore.install, ServiceTypes(pkgName).contentTypeGuard) + write("#T(#T) { $acceptTypeGuard }", RuntimeTypes.KtorServerCore.install, ServiceTypes(pkgName).acceptTypeGuard) + withBlock( + "#W", + "}", + { w: KotlinWriter -> renderRoutingAuth(w, shape) }, + ) { + withBlock("#T {", "}", method) { + withInlineBlock("try {", "}") { + write( + "val request = #T.#T()", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.requestReceive, + ) + write("val deserializer = #T()", deserializerSymbol) + withBlock( + "var requestObj = try { deserializer.deserialize(#T(), call, request) } catch (ex: Exception) {", + "}", + RuntimeTypes.Core.ExecutionContext, + ) { + write( + "throw #T(ex?.message ?: #S, ex)", + RuntimeTypes.KtorServerCore.BadRequestException, + "Malformed input data", + ) + } + if (ctx.model.expectShape(shape.input.get()).allMembers.isNotEmpty()) { + withBlock("requestObj = requestObj.copy {", "}") { + call { readHttpLabel(shape, writer) } + call { readHttpQuery(shape, writer) } + } + } + + write( + "try { check${shape.id.name}RequestConstraint(requestObj) } catch (ex: Exception) { throw #T(ex?.message ?: #S, ex) }", + RuntimeTypes.KtorServerCore.BadRequestException, + "Error while validating constraints", + ) + write("val responseObj = handle${shape.id.name}Request(requestObj)") + write("val serializer = #T()", serializerSymbol) + withBlock( + "val response = try { serializer.serialize(#T(), responseObj) } catch (ex: Exception) {", + "}", + RuntimeTypes.Core.ExecutionContext, + ) { + write( + "throw #T(ex?.message ?: #S, ex)", + RuntimeTypes.KtorServerCore.BadRequestException, + "Malformed output data", + ) + } + call { readResponseHttpHeader("responseObj", shape.output.get(), writer) } + call { readResponseHttpPrefixHeader("responseObj", shape.output.get(), writer) } + call { renderResponseCall("response", writer, acceptType, successCode.toString(), shape.output.get()) } + } + withBlock(" catch (t: Throwable) {", "}") { + writeInline("val errorObj: Any? = ") + withBlock("when (t) {", "}") { + shape.errors.forEach { errorShapeId -> + val errorShape = ctx.model.expectShape(errorShapeId) + val errorSymbol = ctx.symbolProvider.toSymbol(errorShape) + write("is #T -> t as #T", errorSymbol, errorSymbol) + } + write("else -> null") + } + write("") + writeInline("val errorResponse: Pair? = ") + withBlock("when (errorObj) {", "}") { + shape.errors.forEach { errorShapeId -> + val errorShape = ctx.model.expectShape(errorShapeId) + val errorSymbol = ctx.symbolProvider.toSymbol(errorShape) + val exceptionSymbol = buildSymbol { + val exceptionName = "${errorSymbol.name}Serializer" + definitionFile = "$errorSymbol.kt" + name = exceptionName + namespace = ctx.settings.pkg.serde + reference(errorSymbol, SymbolReference.ContextOption.DECLARE) + } + write("is #T -> Pair(#T().serialize(#T(), errorObj), ${errorShape.getTrait()?.code})", errorSymbol, exceptionSymbol, RuntimeTypes.Core.ExecutionContext) + } + write("else -> null") + } + write("if (errorResponse == null) throw t") + + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) + withBlock("when (errorObj) {", "}") { + shape.errors.forEach { errorShapeId -> + val errorShape = ctx.model.expectShape(errorShapeId) + val errorSymbol = ctx.symbolProvider.toSymbol(errorShape) + withBlock("is #T -> {", "}", errorSymbol) { + readResponseHttpHeader("errorObj", errorShapeId, writer) + readResponseHttpPrefixHeader("errorObj", errorShapeId, writer) + } + } + write("else -> null") + } + call { renderResponseCall("errorResponse.first", writer, acceptType, "\"\${errorResponse.second}\"", shape.output.get()) } + } + } + } + } + } + } + } + } +} + +/** + * Reads `HttpLabelTrait` annotated members from request URI parameters + * and casts them to appropriate Kotlin types before populating request object. + */ +private fun KtorStubGenerator.readHttpLabel(shape: OperationShape, writer: KotlinWriter) { + val inputShape = ctx.model.expectShape(shape.input.get()) + inputShape.allMembers + .filter { member -> member.value.hasTrait(HttpLabelTrait.ID) } + .forEach { member -> + val memberName = member.key + val memberShape = member.value + + val httpLabelVariableName = "call.parameters[\"$memberName\"]?" + val targetShape = ctx.model.expectShape(memberShape.target) + writer.writeInline("$memberName = ") + .call { + renderCastingPrimitiveFromShapeType( + httpLabelVariableName, + targetShape.type, + writer, + memberShape.getTrait() ?: inputShape.getTrait(), + "Unsupported type ${memberShape.type} for httpLabel", + ) + } + } +} + +/** + * Reads `HttpQueryTrait` and `HttpQueryParamsTrait` annotated members + * from query parameters. Handles both simple and list-valued query params, + * casting them to correct Kotlin types before populating request object. + */ +private fun KtorStubGenerator.readHttpQuery(shape: OperationShape, writer: KotlinWriter) { + val inputShape = ctx.model.expectShape(shape.input.get()) + val httpQueryKeys = mutableSetOf() + inputShape.allMembers + .filter { member -> member.value.hasTrait(HttpQueryTrait.ID) } + .forEach { member -> + val memberName = member.key + val memberShape = member.value + val httpQueryTrait = memberShape.getTrait()!! + val httpQueryVariableName = "call.request.queryParameters[\"${httpQueryTrait.value}\"]?" + val targetShape = ctx.model.expectShape(memberShape.target) + httpQueryKeys.add(httpQueryTrait.value) + writer.writeInline("$memberName = ") + .call { + when { + targetShape.isListShape -> { + val listMemberShape = targetShape.allMembers.values.first() + val listMemberTargetShapeId = ctx.model.expectShape(listMemberShape.target) + val httpQueryListVariableName = "(call.request.queryParameters.getAll(\"${httpQueryTrait.value}\") " + + "?: call.request.queryParameters.getAll(\"${httpQueryTrait.value}[]\") " + + "?: emptyList())" + writer.withBlock("$httpQueryListVariableName.mapNotNull{", "}") { + renderCastingPrimitiveFromShapeType( + "it?", + listMemberTargetShapeId.type, + writer, + listMemberShape.getTrait() ?: targetShape.getTrait(), + "Unsupported type ${memberShape.type} for list in httpLabel", + ) + } + } + else -> renderCastingPrimitiveFromShapeType( + httpQueryVariableName, + targetShape.type, + writer, + memberShape.getTrait() ?: inputShape.getTrait(), + "Unsupported type ${memberShape.type} for httpQuery", + ) + } + } + } + val httpQueryParamsMember = inputShape.allMembers.values.firstOrNull { it.hasTrait(HttpQueryParamsTrait.ID) } + httpQueryParamsMember?.apply { + val httpQueryParamsMemberName = httpQueryParamsMember.memberName + val httpQueryParamsMapShape = ctx.model.expectShape(httpQueryParamsMember.target) as MapShape + val httpQueryParamsMapValueTypeShape = ctx.model.expectShape(httpQueryParamsMapShape.value.target) + val httpQueryKeysLiteral = httpQueryKeys.joinToString(", ") { "\"$it\"" } + writer.withInlineBlock("$httpQueryParamsMemberName = call.request.queryParameters.entries().filter { (key, _) ->", "}") { + write("key !in setOf($httpQueryKeysLiteral)") + } + .withBlock(".associate { (key, values) ->", "}") { + if (httpQueryParamsMapValueTypeShape.isListShape) { + write("key to values!!") + } else { + write("key to values.first()") + } + } + .withBlock(".mapValues { (_, value) ->", "}") { + renderCastingPrimitiveFromShapeType( + "value", + httpQueryParamsMapValueTypeShape.type, + writer, + httpQueryParamsMapValueTypeShape.getTrait() ?: httpQueryParamsMapShape.getTrait(), + "Unsupported type ${httpQueryParamsMapValueTypeShape.type} for httpQuery", + ) + } + } +} + +/** + * Configures authentication for a given operation shape. + * Determines available authentication strategies (Bearer, SigV4, SigV4A) + * at service and operation level and installs them in Ktor's `authenticate` block. + */ +private fun KtorStubGenerator.renderRoutingAuth(w: KotlinWriter, shape: OperationShape) { + val serviceAuthTrait = serviceShape.getTrait() + val hasServiceHttpBearerAuthTrait = serviceShape.hasTrait(HttpBearerAuthTrait.ID) + val hasServiceSigV4AuthTrait = serviceShape.hasTrait(SigV4Trait.ID) + val hasServiceSigV4AAuthTrait = serviceShape.hasTrait(SigV4ATrait.ID) + val authTrait = shape.getTrait() + val hasOperationBearerAuthTrait = authTrait?.valueSet?.contains(HttpBearerAuthTrait.ID) ?: true + val hasOperationSigV4AuthTrait = authTrait?.valueSet?.contains(SigV4Trait.ID) ?: true + val hasOperationSigV4AAuthTrait = authTrait?.valueSet?.contains(SigV4ATrait.ID) ?: true + + val availableAuthTraitOrderedSet = authTrait?.valueSet ?: serviceAuthTrait?.valueSet ?: setOf(HttpBearerAuthTrait.ID, SigV4Trait.ID, SigV4ATrait.ID) + + val authList = mutableListOf() + availableAuthTraitOrderedSet.forEach { authTraitId -> + when (authTraitId) { + HttpBearerAuthTrait.ID -> if (hasServiceHttpBearerAuthTrait && hasOperationBearerAuthTrait) authList.add("auth-bearer") + SigV4Trait.ID -> if (hasServiceSigV4AuthTrait && hasOperationSigV4AuthTrait) authList.add("aws-sigv4") + SigV4ATrait.ID -> if (hasServiceSigV4AAuthTrait && hasOperationSigV4AAuthTrait) authList.add("aws-sigv4a") + } + } + authList.ifEmpty { authList.add("no-auth") } + + w.write( + "#T(#L, strategy = #T.FirstSuccessful) {", + RuntimeTypes.KtorServerAuth.authenticate, + authList.joinToString(", ") { "\"$it\"" }, + RuntimeTypes.KtorServerAuth.AuthenticationStrategy, + ) +} + +/** + * Reads and appends HTTP headers from response object fields annotated + * with `HttpHeaderTrait` to the Ktor response. + */ +private fun KtorStubGenerator.readResponseHttpHeader(dataName: String, shapeId: ShapeId, writer: KotlinWriter) { + val shape = ctx.model.expectShape(shapeId) + shape.allMembers + .filter { member -> member.value.hasTrait(HttpHeaderTrait.ID) } + .forEach { member -> + val headerName = member.value.getTrait()!!.value + val memberName = member.key + writer.write("call.response.headers.append(#S, $dataName.$memberName.toString())", headerName) + } +} + +/** + * Reads and appends HTTP prefix headers from response object fields annotated + * with `HttpPrefixHeadersTrait`. Dynamically appends prefixed headers with suffix values. + */ +private fun KtorStubGenerator.readResponseHttpPrefixHeader(dataName: String, shapeId: ShapeId, writer: KotlinWriter) { + val shape = ctx.model.expectShape(shapeId) + shape.allMembers + .filter { member -> member.value.hasTrait(HttpPrefixHeadersTrait.ID) } + .forEach { member -> + val prefixHeaderName = member.value.getTrait()!!.value + val memberName = member.key + writer.withBlock("for ((suffixHeader, headerValue) in $dataName?.$memberName ?: mapOf()) {", "}") { + writer.write("call.response.headers.append(#S, headerValue.toString())", "$prefixHeaderName\${suffixHeader}") + } + } +} + +/** + * Writes the Ktor call to send the response back to the client. + * + * - Selects correct responder (`respondBytes` or `responseText`) based on content type. + * - Sets appropriate content type and HTTP status code. + * - Supports CBOR, JSON, text, binary, and dynamic media types. + */ +private fun KtorStubGenerator.renderResponseCall( + responseName: String, + w: KotlinWriter, + acceptType: MediaType, + successCode: String, + outputShapeId: ShapeId, +) { + when (acceptType) { + MediaType.CBOR -> w.withBlock( + "#T.#T(", + ")", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.responseRespondBytes, + ) { + write("bytes = $responseName as ByteArray,") + write("contentType = #T,", RuntimeTypes.KtorServerHttp.Cbor) + write( + "status = #T.fromValue($successCode.toInt()),", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + } + OCTET_STREAM -> w.withBlock( + "#T.#T(", + ")", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.responseRespondBytes, + ) { + write("bytes = $responseName as ByteArray,") + write("contentType = #T,", RuntimeTypes.KtorServerHttp.OctetStream) + write( + "status = #T.fromValue($successCode.toInt()),", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + } + JSON -> w.withBlock( + "#T.#T(", + ")", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.responseResponseText, + ) { + write("text = $responseName as String,") + write("contentType = #T,", RuntimeTypes.KtorServerHttp.Json) + write( + "status = #T.fromValue($successCode.toInt()),", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + } + PLAIN_TEXT -> w.withBlock( + "#T.#T(", + ")", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.responseResponseText, + ) { + write("text = $responseName as String,") + write("contentType = #T,", RuntimeTypes.KtorServerHttp.PlainText) + write( + "status = #T.fromValue($successCode.toInt()),", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + } + ANY -> { + val outputShape = ctx.model.expectShape(outputShapeId) + val mediaTraits = outputShape.allMembers.values.firstNotNullOf { it.getTrait() } + w.withBlock( + "#T.#T(", + ")", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.responseRespondBytes, + ) { + write("bytes = $responseName as ByteArray,") + write("contentType = #S,", mediaTraits.value) + write( + "status = #T.fromValue($successCode.toInt()),", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + } + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/ServerFrameworkImplementation.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/ServerFrameworkImplementation.kt new file mode 100644 index 0000000000..3a008d8601 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/ServerFrameworkImplementation.kt @@ -0,0 +1,73 @@ +package software.amazon.smithy.kotlin.codegen.service.ktor + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.core.withInlineBlock +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes +import software.amazon.smithy.kotlin.codegen.service.ServiceTypes + +/** + * Writes the server framework implementation for the generated Ktor service. + * + * - Defines the `Application.module()` entry point for configuring the Ktor pipeline + * (logging, body size limits, double receive, error handling, authentication, routing). + * - Generates a concrete `KtorServiceFramework` implementation of `ServiceFramework` + * that manages lifecycle of the Ktor embedded server engine. + */ +internal fun KtorStubGenerator.writeServerFrameworkImplementation(writer: KotlinWriter) { + writer.withBlock("internal fun #T.module(): Unit {", "}", RuntimeTypes.KtorServerCore.Application) { + write("#T()", ServiceTypes(pkgName).configureLogging) + withBlock("#T(#T) {", "}", RuntimeTypes.KtorServerCore.install, RuntimeTypes.KtorServerBodyLimit.RequestBodyLimit) { + write("bodyLimit { #T.requestBodyLimit }", ServiceTypes(pkgName).serviceFrameworkConfig) + } + write("#T(#T)", RuntimeTypes.KtorServerCore.install, RuntimeTypes.KtorServerDoubleReceive.DoubleReceive) + write("#T()", ServiceTypes(pkgName).configureErrorHandling) + write("#T()", ServiceTypes(pkgName).configureAuthentication) + write("#T()", ServiceTypes(pkgName).configureRouting) + } + .write("") + writer.withBlock("internal class KtorServiceFramework() : ServiceFramework {", "}") { + write("private var engine: #T<*, *>? = null", RuntimeTypes.KtorServerCore.EmbeddedServerType) + write("") + write("private val engineFactory = #T.engine.toEngineFactory()", ServiceTypes(pkgName).serviceFrameworkConfig) + + write("") + withBlock("override fun start() {", "}") { + withInlineBlock("engine = #T(", ")", RuntimeTypes.KtorServerCore.embeddedServer) { + write("engineFactory,") + withBlock("configure = {", "}") { + withBlock("#T {", "}", RuntimeTypes.KtorServerCore.connector) { + write("host = #S", "0.0.0.0") + write("port = #T.port", ServiceTypes(pkgName).serviceFrameworkConfig) + } + withBlock("when (this) {", "}") { + withBlock("is #T -> {", "}", RuntimeTypes.KtorServerNetty.Configuration) { + write("requestReadTimeoutSeconds = #T.requestReadTimeoutSeconds", ServiceTypes(pkgName).serviceFrameworkConfig) + write("responseWriteTimeoutSeconds = #T.responseWriteTimeoutSeconds", ServiceTypes(pkgName).serviceFrameworkConfig) + } + + withBlock("is #T -> {", "}", RuntimeTypes.KtorServerCio.Configuration) { + write("connectionIdleTimeoutSeconds = #T.requestReadTimeoutSeconds", ServiceTypes(pkgName).serviceFrameworkConfig) + } + + withBlock("is #T -> {", "}", RuntimeTypes.KtorServerJettyJakarta.Configuration) { + write( + "idleTimeout = #T.requestReadTimeoutSeconds.#T", + ServiceTypes(pkgName).serviceFrameworkConfig, + KotlinTypes.Time.seconds, + ) + } + } + } + } + write("{ #T() }", ServiceTypes(pkgName).module) + write("engine?.apply { start(wait = true) }") + } + write("") + withBlock("final override fun close() {", "}") { + write("engine?.stop(#T.closeGracePeriodMillis, #T.closeTimeoutMillis)", ServiceTypes(pkgName).serviceFrameworkConfig, ServiceTypes(pkgName).serviceFrameworkConfig) + write("engine = null") + } + } +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Utils.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Utils.kt new file mode 100644 index 0000000000..57a68f19f3 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ktor/Utils.kt @@ -0,0 +1,97 @@ +package software.amazon.smithy.kotlin.codegen.service.ktor + +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.service.ServiceTypes + +/** + * Entry point for generating utility files required by the Ktor server stub. + * + * Currently delegates to [renderLogging] to generate logging configuration. + */ +internal fun KtorStubGenerator.writeUtils() { + renderLogging() +} + +/** + * Generates logging configuration for the generated Ktor service. + * + * - Creates a `Logging.kt` file with `configureLogging()` extension function for + * the Ktor `Application`, setting up SLF4J/Logback integration. + * - Maps service log levels to SLF4J/Logback levels. + * - Configures Ktor's `CallLogging` plugin to log HTTP method, URI, and status. + * - Registers lifecycle log messages (`starting`, `started`, `stopping`, `stopped`). + * - Writes a default `logback.xml` file into `src/main/resources` to ensure + * console logging is available out of the box. + */ +private fun KtorStubGenerator.renderLogging() { + delegator.useFileWriter("Logging.kt", "$pkgName.utils") { writer -> + + writer.withBlock("internal fun #T.configureLogging() {", "}", RuntimeTypes.KtorServerCore.Application) { + withBlock( + "val slf4jLevel: #T? = when (#T.logLevel) {", + "}", + RuntimeTypes.KtorLoggingSlf4j.Level, + ServiceTypes(pkgName).serviceFrameworkConfig, + ) { + write("#T.INFO -> #T.INFO", ServiceTypes(pkgName).logLevel, RuntimeTypes.KtorLoggingSlf4j.Level) + write("#T.TRACE -> #T.TRACE", ServiceTypes(pkgName).logLevel, RuntimeTypes.KtorLoggingSlf4j.Level) + write("#T.DEBUG -> #T.DEBUG", ServiceTypes(pkgName).logLevel, RuntimeTypes.KtorLoggingSlf4j.Level) + write("#T.WARN -> #T.WARN", ServiceTypes(pkgName).logLevel, RuntimeTypes.KtorLoggingSlf4j.Level) + write("#T.ERROR -> #T.ERROR", ServiceTypes(pkgName).logLevel, RuntimeTypes.KtorLoggingSlf4j.Level) + write("#T.OFF -> null", ServiceTypes(pkgName).logLevel) + } + write("") + write("val logbackLevel = slf4jLevel?.let { #T.valueOf(it.name) } ?: #T.OFF", RuntimeTypes.KtorLoggingLogback.Level, RuntimeTypes.KtorLoggingLogback.Level) + write("") + write( + "(#T.getILoggerFactory() as #T).getLogger(#T).level = logbackLevel", + RuntimeTypes.KtorLoggingSlf4j.LoggerFactory, + RuntimeTypes.KtorLoggingLogback.LoggerContext, + RuntimeTypes.KtorLoggingSlf4j.ROOT_LOGGER_NAME, + ) + write("") + withBlock("if (slf4jLevel != null) {", "}") { + withBlock("#T(#T) {", "}", RuntimeTypes.KtorServerCore.install, RuntimeTypes.KtorServerLogging.CallLogging) { + write("level = slf4jLevel") + withBlock("format { call ->", "}") { + write("val status = call.response.status()") + write("\"\${call.request.#T.value} \${call.request.#T} → \$status\"", RuntimeTypes.KtorServerRouting.requestHttpMethod, RuntimeTypes.KtorServerRouting.requestUri) + } + } + } + write("val log = #T.getLogger(#S)", RuntimeTypes.KtorLoggingSlf4j.LoggerFactory, ctx.settings.pkg.name) + + withBlock("monitor.subscribe(#T) {", "}", RuntimeTypes.KtorServerCore.ApplicationStarting) { + write("log.info(#S)", "Server is starting...") + } + + withBlock("monitor.subscribe(#T) {", "}", RuntimeTypes.KtorServerCore.ApplicationStarted) { + write("log.info(#S)", "Server started – ready to accept requests.") + } + + withBlock("monitor.subscribe(#T) {", "}", RuntimeTypes.KtorServerCore.ApplicationStopping) { + write("log.warn(#S)", "Server is stopping – waiting for in-flight requests...") + } + + withBlock("monitor.subscribe(#T) {", "}", RuntimeTypes.KtorServerCore.ApplicationStopped) { + write("log.info(#S)", "Server stopped cleanly.") + } + } + } + val loggingWriter = LoggingWriter() + loggingWriter.withBlock("", "") { + withBlock("", "", "STDOUT", "ch.qos.logback.core.ConsoleAppender") { + withBlock("", "") { + withBlock("", "") { + write("%d{yyyy-MM-dd'T'HH:mm:ss.SSSXXX} %-5level %logger{36} - %msg%n") + } + } + } + withBlock("", "") { + write("", "STDOUT") + } + } + val contents = loggingWriter.toString() + fileManifest.writeFile("src/main/resources/logback.xml", contents) +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/utils.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/utils.kt new file mode 100644 index 0000000000..22e40ee7a7 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/utils.kt @@ -0,0 +1,43 @@ +package software.amazon.smithy.kotlin.codegen.service + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.model.shapes.ShapeType +import software.amazon.smithy.model.traits.TimestampFormatTrait + +/** + * Renders Kotlin code that casts a variable to the correct primitive type + * based on its Smithy [ShapeType]. + */ +fun renderCastingPrimitiveFromShapeType( + variable: String, + type: ShapeType, + writer: KotlinWriter, + timestampFormatTrait: TimestampFormatTrait? = null, + errorMessage: String? = null, +) { + when (type) { + ShapeType.BLOB -> writer.write("$variable.toByteArray()") + ShapeType.STRING -> writer.write("$variable.toString()") + ShapeType.BYTE -> writer.write("$variable.toByte()") + ShapeType.INTEGER -> writer.write("$variable.toInt()") + ShapeType.SHORT -> writer.write("$variable.toShort()") + ShapeType.LONG -> writer.write("$variable.toLong()") + ShapeType.FLOAT -> writer.write("$variable.toFloat()") + ShapeType.DOUBLE -> writer.write("$variable.toDouble()") + ShapeType.BIG_DECIMAL -> writer.write("$variable.toBigDecimal()") + ShapeType.BIG_INTEGER -> writer.write("$variable.toBigInteger()") + ShapeType.BOOLEAN -> writer.write("$variable.toBoolean()") + ShapeType.TIMESTAMP -> + when (timestampFormatTrait?.format) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> + writer.write("$variable.let{ #T.fromEpochSeconds(it) }", RuntimeTypes.Core.Instant) + TimestampFormatTrait.Format.DATE_TIME -> + writer.write("$variable.let{ #T.fromIso8601(it) }", RuntimeTypes.Core.Instant) + TimestampFormatTrait.Format.HTTP_DATE -> + writer.write("$variable.let{ #T.fromRfc5322(it) }", RuntimeTypes.Core.Instant) + else -> writer.write("$variable.let{ #T.fromEpochSeconds(it) }", RuntimeTypes.Core.Instant) + } + else -> throw IllegalStateException(errorMessage ?: "Unable to render casting primitive for $type") + } +} diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettingsTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettingsTest.kt index 2de132c9fa..7910b25cc1 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettingsTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettingsTest.kt @@ -11,6 +11,7 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.CsvSource +import org.junit.jupiter.params.support.ParameterDeclarations import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.kotlin.codegen.test.TestModelDefault import software.amazon.smithy.kotlin.codegen.test.toSmithyModel @@ -392,7 +393,10 @@ class TestProtocolSelectionArgumentProvider : ArgumentsProvider { private const val NO_CBOR = "awsJson1_0, awsJson1_1, restJson1, restXml, awsQuery, ec2Query" } - override fun provideArguments(context: ExtensionContext?): Stream = Stream.of( + override fun provideArguments( + parameters: ParameterDeclarations?, + context: ExtensionContext?, + ): Stream = Stream.of( Arguments.of( ALL_PROTOCOLS, "rpcv2Cbor, awsJson1_0", diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegatorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegatorTest.kt index 162f32ac92..a5fb28e902 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegatorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDelegatorTest.kt @@ -97,7 +97,9 @@ class KotlinDelegatorTest { Node.parse(configContents).expectObjectNode(), ) val manifest = MockManifest() - val delegator = KotlinDelegator(settings, model, manifest, KotlinSymbolProvider(model, settings)) + val symbolProvider = KotlinSymbolProvider(model, settings) + val ctx = GenerationContext(model, symbolProvider, settings, protocolGenerator = null) + val delegator = KotlinDelegator(ctx, manifest) val generatedSymbol = buildSymbol { name = "Foo" diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ExceptionGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ExceptionGeneratorTest.kt index a2272852fe..6356d3fbf6 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ExceptionGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ExceptionGeneratorTest.kt @@ -241,7 +241,7 @@ class ExceptionGeneratorTest { get() = listOf(SectionWriterBinding(ExceptionBaseClassGenerator.ExceptionBaseClassSection, exceptionSectionWriter)) private val exceptionSectionWriter = SectionWriter { writer, _ -> - val ctx = writer.getContextValue(ExceptionBaseClassGenerator.ExceptionBaseClassSection.CodegenContext) + val ctx = writer.getContextValue(CodegenContext.Key) ServiceExceptionBaseClassGenerator(exceptionBaseClassSymbol).render(ctx, writer) } } diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/DefaultEndpointDiscovererGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/DefaultEndpointDiscovererGeneratorTest.kt new file mode 100644 index 0000000000..d5e0a6e152 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/DefaultEndpointDiscovererGeneratorTest.kt @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.rendering.endpoints.discovery + +import org.junit.jupiter.api.Test +import software.amazon.smithy.build.MockManifest +import software.amazon.smithy.kotlin.codegen.test.formatForTest +import software.amazon.smithy.kotlin.codegen.test.newTestContext +import software.amazon.smithy.kotlin.codegen.test.shouldContainOnlyOnceWithDiff +import software.amazon.smithy.kotlin.codegen.test.toCodegenContext + +class DefaultEndpointDiscovererGeneratorTest { + private val renderedCodegen: String = run { + val model = model() + val testCtx = model.newTestContext() + val delegator = testCtx.generationCtx.delegator + val generator = DefaultEndpointDiscovererGenerator(testCtx.toCodegenContext(), delegator) + generator.render() + + delegator.flushWriters() + val testManifest = delegator.fileManifest as MockManifest + testManifest.expectFileString("/src/main/kotlin/com/test/endpoints/DefaultTestEndpointDiscoverer.kt") + } + + @Test + fun testClass() { + renderedCodegen.shouldContainOnlyOnceWithDiff( + """ + /** + * A class which looks up specific endpoints for Test calls via the `getEndpoints` API. These + * unique endpoints are cached as appropriate to avoid unnecessary latency in subsequent calls. + * @param cache An [ExpiringKeyedCache] implementation used to cache discovered hosts + */ + public class DefaultTestEndpointDiscoverer(public val cache: ExpiringKeyedCache = PeriodicSweepCache(10.minutes)) : TestEndpointDiscoverer { + """.trimIndent(), + ) + } + + @Test + fun testAsEndpointResolver() { + renderedCodegen.shouldContainOnlyOnceWithDiff( + """ + override fun asEndpointResolver(client: TestClient, delegate: EndpointResolver): EndpointResolver = EndpointResolver { request -> + if (client.config.endpointUrl == null) { + val identity = request.identity + require(identity is Credentials) { "Endpoint discovery requires AWS credentials" } + + val cacheKey = DiscoveryParams(client.config.region, identity.accessKeyId) + request.context[DiscoveryParamsKey] = cacheKey + val discoveredHost = cache.get(cacheKey) { discoverHost(client) } + + val originalEndpoint = delegate.resolve(request) + Endpoint( + originalEndpoint.uri.copy { host = discoveredHost }, + originalEndpoint.headers, + originalEndpoint.attributes, + ) + } else { + delegate.resolve(request) + } + } + """.formatForTest(), + ) + } + + @Test + fun testInvalidate() { + renderedCodegen.shouldContainOnlyOnceWithDiff( + """ + override public suspend fun invalidate(context: ExecutionContext) { + context.getOrNull(DiscoveryParamsKey)?.let { cache.invalidate(it) } + } + """.formatForTest(), + ) + } +} diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererGeneratorTest.kt deleted file mode 100644 index 37b9e6a900..0000000000 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererGeneratorTest.kt +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.kotlin.codegen.rendering.endpoints.discovery - -import org.junit.jupiter.api.Test -import software.amazon.smithy.build.MockManifest -import software.amazon.smithy.kotlin.codegen.test.* - -class EndpointDiscovererGeneratorTest { - @Test - fun testClass() { - val actual = render() - - actual.shouldContainOnlyOnceWithDiff( - """ - public class TestEndpointDiscoverer { - private val cache = ReadThroughCache(10.minutes, Clock.System) - """.trimIndent(), - ) - - actual.shouldContainOnlyOnceWithDiff( - """ - } - - private val discoveryParamsKey = AttributeKey("DiscoveryParams") - private data class DiscoveryParams(private val region: String?, private val identity: String) - """.trimIndent(), - ) - } - - @Test - fun testAsEndpointResolver() { - val actual = render() - - actual.shouldContainOnlyOnceWithDiff( - """ - internal fun asEndpointResolver(client: TestClient, delegate: EndpointResolverAdapter) = EndpointResolver { request -> - if (client.config.endpointUrl == null) { - val identity = request.identity - require(identity is Credentials) { "Endpoint discovery requires AWS credentials" } - - val cacheKey = DiscoveryParams(client.config.region, identity.accessKeyId) - request.context[discoveryParamsKey] = cacheKey - val discoveredHost = cache.get(cacheKey) { discoverHost(client) } - - val originalEndpoint = delegate.resolve(request) - Endpoint( - originalEndpoint.uri.copy { host = discoveredHost }, - originalEndpoint.headers, - originalEndpoint.attributes, - ) - } else { - delegate.resolve(request) - } - } - """.formatForTest(), - ) - } - - @Test - fun testDiscoverHost() { - val actual = render() - - actual.shouldContainOnlyOnceWithDiff( - """ - private suspend fun discoverHost(client: TestClient): ExpiringValue = - client.getEndpoints() - .endpoints - ?.map { ep -> ExpiringValue( - Host.parse(ep.address!!), - Instant.now() + ep.cachePeriodInMinutes.minutes, - )} - ?.firstOrNull() - ?: throw EndpointProviderException("Unable to discover any endpoints when invoking getEndpoints!") - """.formatForTest(), - ) - } - - @Test - fun testInvalidate() { - val actual = render() - - actual.shouldContainOnlyOnceWithDiff( - """ - internal suspend fun invalidate(context: ExecutionContext) { - context.getOrNull(discoveryParamsKey)?.let { cache.invalidate(it) } - } - """.formatForTest(), - ) - } - - private fun render(): String { - val model = model() - val testCtx = model.newTestContext() - val delegator = testCtx.generationCtx.delegator - val generator = EndpointDiscovererGenerator(testCtx.toCodegenContext(), delegator) - generator.render() - - delegator.flushWriters() - val testManifest = delegator.fileManifest as MockManifest - return testManifest.expectFileString("/src/main/kotlin/com/test/endpoints/TestEndpointDiscoverer.kt") - } - - private fun model() = - """ - namespace com.test - - use aws.protocols#awsJson1_1 - use aws.api#service - use aws.auth#sigv4 - - @service(sdkId: "test") - @sigv4(name: "test") - @awsJson1_1 - @aws.api#clientEndpointDiscovery( - operation: GetEndpoints, - error: BadEndpointError - ) - service Test { - version: "1.0.0", - operations: [GetEndpoints] - } - - @error("client") - @httpError(421) - structure BadEndpointError { } - - @http(method: "GET", uri: "/endpoints") - operation GetEndpoints { - input: GetEndpointsInput - output: GetEndpointsOutput - } - - @input - structure GetEndpointsInput { } - - @output - structure GetEndpointsOutput { - Endpoints: Endpoints - } - - list Endpoints { - member: Endpoint - } - - structure Endpoint { - Address: String - CachePeriodInMinutes: Long - } - """.toSmithyModel() -} diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererInterfaceGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererInterfaceGeneratorTest.kt new file mode 100644 index 0000000000..998fc8e577 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscovererInterfaceGeneratorTest.kt @@ -0,0 +1,75 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.rendering.endpoints.discovery + +import org.junit.jupiter.api.Test +import software.amazon.smithy.build.MockManifest +import software.amazon.smithy.kotlin.codegen.test.formatForTest +import software.amazon.smithy.kotlin.codegen.test.newTestContext +import software.amazon.smithy.kotlin.codegen.test.shouldContainOnlyOnceWithDiff +import software.amazon.smithy.kotlin.codegen.test.toCodegenContext + +class EndpointDiscovererInterfaceGeneratorTest { + @Test + fun testInterface() { + val actual = render() + + actual.shouldContainOnlyOnceWithDiff( + """ + /** + * Represents the logic for automatically discovering endpoints for Test calls + */ + public interface TestEndpointDiscoverer { + public fun asEndpointResolver(client: TestClient, delegate: EndpointResolver): EndpointResolver + """.trimIndent(), + ) + + actual.shouldContainOnlyOnceWithDiff( + """ + public suspend fun invalidate(context: ExecutionContext) + """.trimIndent(), + ) + + actual.shouldContainOnlyOnceWithDiff( + """ + } + + public data class DiscoveryParams(private val region: String?, private val identity: String) + public val DiscoveryParamsKey: AttributeKey = AttributeKey("DiscoveryParams") + """.trimIndent(), + ) + } + + @Test + fun testDiscoverHost() { + val actual = render() + + actual.shouldContainOnlyOnceWithDiff( + """ + public suspend fun discoverHost(client: TestClient): ExpiringValue = + client.getEndpoints() + .endpoints + ?.map { ep -> ExpiringValue( + Host.parse(ep.address!!), + Instant.now() + ep.cachePeriodInMinutes.minutes, + )} + ?.firstOrNull() + ?: throw EndpointProviderException("Unable to discover any endpoints when invoking getEndpoints!") + """.formatForTest(), + ) + } + + private fun render(): String { + val model = model() + val testCtx = model.newTestContext() + val delegator = testCtx.generationCtx.delegator + val generator = EndpointDiscovererInterfaceGenerator(testCtx.toCodegenContext(), delegator) + generator.render() + + delegator.flushWriters() + val testManifest = delegator.fileManifest as MockManifest + return testManifest.expectFileString("/src/main/kotlin/com/test/endpoints/TestEndpointDiscoverer.kt") + } +} diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryIntegrationTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryIntegrationTest.kt index 48fec68de5..1fdb3f3973 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryIntegrationTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryIntegrationTest.kt @@ -36,7 +36,7 @@ class EndpointDiscoveryIntegrationTest { val contents = writer.toString() if (discoveryRequired) { - val configStr = "public val endpointDiscoverer: TestEndpointDiscoverer = builder.endpointDiscoverer ?: TestEndpointDiscoverer()" + val configStr = "public val endpointDiscoverer: TestEndpointDiscoverer = builder.endpointDiscoverer ?: DefaultTestEndpointDiscoverer()" contents.shouldContainOnlyOnceWithDiff(configStr) val builderStr = """ @@ -52,9 +52,8 @@ class EndpointDiscoveryIntegrationTest { val builderStr = """ /** - * The endpoint discoverer for this client, if applicable. By default, no endpoint - * discovery is provided. To use endpoint discovery, set this to a valid - * [TestEndpointDiscoverer] instance. + * The endpoint discoverer for this client, if applicable. By default, no endpoint discovery is + * provided. To use endpoint discovery, set this to a valid [TestEndpointDiscoverer] instance. */ public var endpointDiscoverer: TestEndpointDiscoverer? = null """.formatForTest(" ") @@ -62,6 +61,23 @@ class EndpointDiscoveryIntegrationTest { } } + @Test + fun testDiscoveredEndpointErrorMiddleware() { + val model = model() + val ctx = model.newTestContext(integrations = listOf(EndpointDiscoveryIntegration())) + val generator = MockHttpProtocolGenerator(model) + generator.generateProtocolClient(ctx.generationCtx) + + ctx.generationCtx.delegator.finalize() + ctx.generationCtx.delegator.flushWriters() + + val actual = ctx.manifest.expectFileString("/src/main/kotlin/com/test/DefaultTestClient.kt") + + val getFooMethod = actual.lines(" override suspend fun getFoo(input: GetFooRequest): GetFooResponse {", " }") + val expectedInterceptor = "config.endpointDiscoverer?.let { op.interceptors.add(DiscoveredEndpointErrorInterceptor(BadEndpointError::class, it::invalidate)) }" + getFooMethod.shouldContainOnlyOnceWithDiff(expectedInterceptor) + } + private fun model(discoveryRequired: Boolean = true) = """ namespace com.test diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryTestUtils.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryTestUtils.kt new file mode 100644 index 0000000000..55d58d6cc4 --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/discovery/EndpointDiscoveryTestUtils.kt @@ -0,0 +1,56 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.kotlin.codegen.rendering.endpoints.discovery + +import software.amazon.smithy.kotlin.codegen.test.toSmithyModel + +fun model() = + // language=smithy + """ + namespace com.test + + use aws.protocols#awsJson1_1 + use aws.api#service + use aws.auth#sigv4 + + @service(sdkId: "test") + @sigv4(name: "test") + @awsJson1_1 + @aws.api#clientEndpointDiscovery( + operation: GetEndpoints, + error: BadEndpointError + ) + service Test { + version: "1.0.0", + operations: [GetEndpoints] + } + + @error("client") + @httpError(421) + structure BadEndpointError { } + + @http(method: "GET", uri: "/endpoints") + operation GetEndpoints { + input: GetEndpointsInput + output: GetEndpointsOutput + } + + @input + structure GetEndpointsInput { } + + @output + structure GetEndpointsOutput { + Endpoints: Endpoints + } + + list Endpoints { + member: Endpoint + } + + structure Endpoint { + Address: String + CachePeriodInMinutes: Long + } + """.toSmithyModel() diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpStringValuesMapSerializerTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpStringValuesMapSerializerTest.kt index 95138ad529..097db4b4c8 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpStringValuesMapSerializerTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpStringValuesMapSerializerTest.kt @@ -53,7 +53,8 @@ class HttpStringValuesMapSerializerTest { @Test fun `it handles primitive header shapes when different mode`() { - val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.HEADER) + val settings = defaultModel.defaultSettings(defaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT) + val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.HEADER, settings) contents.assertBalancedBracesAndParens() val expectedContents = """ @@ -68,7 +69,8 @@ class HttpStringValuesMapSerializerTest { @Test fun `it handles primitive query shapes when different mode`() { - val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.QUERY) + val settings = defaultModel.defaultSettings(defaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT) + val contents = getTestContents(defaultModel, "com.test#PrimitiveShapesOperation", HttpBinding.Location.QUERY, settings) contents.assertBalancedBracesAndParens() val expectedContents = """ @@ -129,7 +131,8 @@ class HttpStringValuesMapSerializerTest { } """.prependNamespaceAndService(operations = listOf("Foo")).toSmithyModel() - val contents = getTestContents(model, "com.test#Foo", HttpBinding.Location.HEADER) + val settings = defaultModel.defaultSettings(defaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT) + val contents = getTestContents(model, "com.test#Foo", HttpBinding.Location.HEADER, settings) contents.assertBalancedBracesAndParens() val expectedContents = """ diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGeneratorTest.kt index 347da007cb..c0eae1f05b 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerializeStructGeneratorTest.kt @@ -72,7 +72,8 @@ class SerializeStructGeneratorTest { } """.trimIndent() - val actual = codegenSerializerForShape(model, "com.test#Foo") + val settings = model.defaultSettings(defaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT) + val actual = codegenSerializerForShape(model, "com.test#Foo", settings = settings) actual.shouldContainOnlyOnceWithDiff(expected) } diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/ServiceWaitersGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/ServiceWaitersGeneratorTest.kt index 49ee90f484..038a2ff7da 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/ServiceWaitersGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/ServiceWaitersGeneratorTest.kt @@ -11,13 +11,10 @@ import software.amazon.smithy.build.MockManifest import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin import software.amazon.smithy.kotlin.codegen.KotlinSettings -import software.amazon.smithy.kotlin.codegen.core.CodegenContext +import software.amazon.smithy.kotlin.codegen.core.GenerationContext import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator -import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration import software.amazon.smithy.kotlin.codegen.loadModelFromResource -import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator import software.amazon.smithy.kotlin.codegen.test.* -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.ShapeId import kotlin.test.Test import kotlin.test.assertEquals @@ -130,16 +127,9 @@ class ServiceWaitersGeneratorTest { val service = model.getShape(ShapeId.from(TestModelDefault.SERVICE_SHAPE_ID)).get().asServiceShape().get() val settings = KotlinSettings(service.id, KotlinSettings.PackageSettings(TestModelDefault.NAMESPACE, TestModelDefault.MODEL_VERSION), sdkId = service.id.name) - val ctx = object : CodegenContext { - override val model: Model = model - override val symbolProvider: SymbolProvider = provider - override val settings: KotlinSettings = settings - override val protocolGenerator: ProtocolGenerator? = null - override val integrations: List = listOf() - } - val manifest = MockManifest() - val delegator = KotlinDelegator(settings, model, manifest, provider) + val ctx = GenerationContext(model, provider, settings, protocolGenerator = null) + val delegator = KotlinDelegator(ctx, manifest) val generator = ServiceWaitersGenerator() generator.writeAdditionalFiles(ctx, delegator) diff --git a/examples/service-codegen/build.bat b/examples/service-codegen/build.bat new file mode 100644 index 0000000000..2141f3f940 --- /dev/null +++ b/examples/service-codegen/build.bat @@ -0,0 +1,13 @@ +@echo off +if exist build ( + set /p choice="The 'build' directory already exists. Removing it will delete previous build artifacts. Continue? (y/n): " + if /i "%choice%"=="y" ( + gradle clean + echo Previous build directory removed. + ) else ( + echo Aborted. + exit /b 1 + ) +) + +gradle build diff --git a/examples/service-codegen/build.gradle.kts b/examples/service-codegen/build.gradle.kts new file mode 100644 index 0000000000..28003a6fef --- /dev/null +++ b/examples/service-codegen/build.gradle.kts @@ -0,0 +1,54 @@ +/* + * This file was generated by the Gradle 'init' task. + * + * This generated file contains a sample Kotlin application project to get you started. + * For more details on building Java & JVM projects, please refer to https://docs.gradle.org/8.14.2/userguide/building_java_projects.html in the Gradle documentation. + */ + +plugins { + // Apply the org.jetbrains.kotlin.jvm Plugin to add support for Kotlin. + alias(libs.plugins.kotlin.jvm) + + id("software.amazon.smithy.gradle.smithy-jar") version "1.3.0" + // Apply the application plugin to add support for building a CLI application in Java. + application +} + +repositories { + // Use Maven Central for resolving dependencies. + mavenCentral() + mavenLocal() +} + +val codegenVersion = "0.35.2-SNAPSHOT" +val smithyVersion = "1.60.2" + +dependencies { + smithyBuild("software.amazon.smithy.kotlin:smithy-kotlin-codegen:$codegenVersion") + implementation("software.amazon.smithy.kotlin:smithy-aws-kotlin-codegen:$codegenVersion") + implementation("software.amazon.smithy:smithy-model:$smithyVersion") + implementation("software.amazon.smithy:smithy-build:$smithyVersion") + implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion") + // Use the Kotlin JUnit 5 integration. + testImplementation("org.jetbrains.kotlin:kotlin-test-junit5") + + // Use the JUnit 5 integration. + testImplementation(libs.junit.jupiter.engine) + + testRuntimeOnly("org.junit.platform:junit-platform-launcher") + + // This dependency is used by the application. + implementation(libs.guava) +} + +// Apply a specific Java toolchain to ease working on different environments. +java { + toolchain { + languageVersion = JavaLanguageVersion.of(21) + } +} + +tasks.named("test") { + // Use JUnit Platform for unit tests. + useJUnitPlatform() +} diff --git a/examples/service-codegen/build.sh b/examples/service-codegen/build.sh new file mode 100644 index 0000000000..dc22c98a9c --- /dev/null +++ b/examples/service-codegen/build.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +if [ -d "build" ]; then + read -p "The 'build' directory already exists. Removing it will delete previous build artifacts. Continue? (y/n): " choice + case "$choice" in + y|Y ) + gradle clean + echo "Previous build directory removed." + ;; + * ) + echo "Aborted." + exit 1 + ;; + esac +fi + +gradle build diff --git a/examples/service-codegen/model/demo.smithy b/examples/service-codegen/model/demo.smithy new file mode 100644 index 0000000000..b887aa4e2c --- /dev/null +++ b/examples/service-codegen/model/demo.smithy @@ -0,0 +1,49 @@ +// model/greeter.smithy +$version: "2.0" + +namespace com.demo + +use aws.protocols#restJson1 +use smithy.api#httpBearerAuth + +@restJson1 +@httpBearerAuth +service DemoService { + version: "1.0.0" + operations: [ + SayHello + ] +} + +@http(method: "POST", uri: "/greet", code: 201) +operation SayHello { + input: SayHelloInput + output: SayHelloOutput + errors: [ + CustomError + ] +} + +@input +structure SayHelloInput { + @required + @length(min: 3, max: 10) + name: String + + @httpHeader("X-User-ID") + id: Integer +} + +@output +structure SayHelloOutput { + greeting: String +} + +@error("server") +@httpError(500) +structure CustomError { + msg: String + + @httpHeader("X-User-error") + err: String +} diff --git a/examples/service-codegen/smithy-build.json b/examples/service-codegen/smithy-build.json new file mode 100644 index 0000000000..7b1948b71c --- /dev/null +++ b/examples/service-codegen/smithy-build.json @@ -0,0 +1,30 @@ +{ + "version": "1.0", + "outputDirectory": "build/generated-src-test", + "plugins": { + "kotlin-codegen": { + "service": "com.demo#DemoService", + + "package": { + "name": "com.demo.server", + "version": "1.0.0" + }, + + "build": { + "rootProject": true, + "generateServiceProject": true, + "optInAnnotations": [ + "aws.smithy.kotlin.runtime.InternalApi", + "kotlinx.serialization.ExperimentalSerializationApi" + ] + }, + + "serviceStub": { + "framework": "ktor" + } + } + } +} + + + diff --git a/gradle.properties b/gradle.properties index aa80c53d95..9bb3537255 100644 --- a/gradle.properties +++ b/gradle.properties @@ -13,10 +13,10 @@ kotlinx.atomicfu.enableNativeIrTransformation=false org.gradle.jvmargs=-Xmx2G -XX:MaxMetaspaceSize=1G # SDK -sdkVersion=1.4.24-SNAPSHOT +sdkVersion=1.5.2-SNAPSHOT # codegen -codegenVersion=0.34.24-SNAPSHOT +codegenVersion=0.35.2-SNAPSHOT # FIXME Remove after Dokka 2.0 Gradle plugin is stable org.jetbrains.dokka.experimental.gradle.pluginMode=V2Enabled diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 94f31d2cbc..edfc411abf 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,33 +1,33 @@ [versions] -kotlin-version = "2.1.0" +kotlin-version = "2.2.0" dokka-version = "2.0.0" aws-kotlin-repo-tools-version = "0.4.32" # libs -coroutines-version = "1.9.0" -atomicfu-version = "0.25.0" -okhttp-version = "5.0.0-alpha.14" +coroutines-version = "1.10.2" +atomicfu-version = "0.29.0" +okhttp-version = "5.1.0" okhttp4-version = "4.12.0" -okio-version = "3.9.1" +okio-version = "3.15.0" otel-version = "1.45.0" slf4j-version = "2.0.16" slf4j-v1x-version = "1.7.36" -crt-kotlin-version = "0.9.1" +crt-kotlin-version = "0.10.0" micrometer-version = "1.14.2" -binary-compatibility-validator-version = "0.16.3" +binary-compatibility-validator-version = "0.18.0" # codegen smithy-version = "1.60.2" # testing -junit-version = "5.10.5" +junit-version = "5.13.2" kotest-version = "5.9.1" kotlin-compile-testing-version = "0.7.0" kotlinx-benchmark-version = "0.4.12" kotlinx-serialization-version = "1.7.3" docker-java-version = "3.4.0" -ktor-version = "3.1.1" +ktor-version = "3.2.2" kaml-version = "0.55.0" jsoup-version = "1.19.1" @@ -88,6 +88,7 @@ kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version. kotest-assertions-core-jvm = { module = "io.kotest:kotest-assertions-core-jvm", version.ref = "kotest-version" } kotlinx-benchmark-runtime = { module = "org.jetbrains.kotlinx:kotlinx-benchmark-runtime", version.ref = "kotlinx-benchmark-version" } kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-serialization-version" } +kotlinx-serialization-cbor = { module = "org.jetbrains.kotlinx:kotlinx-serialization-cbor", version.ref = "kotlinx-serialization-version" } docker-core = { module = "com.github.docker-java:docker-java-core", version.ref = "docker-java-version" } docker-transport-zerodep = { module = "com.github.docker-java:docker-java-transport-zerodep", version.ref = "docker-java-version" } diff --git a/runtime/auth/aws-credentials/api/aws-credentials.api b/runtime/auth/aws-credentials/api/aws-credentials.api index 15df96aff6..429f5eec3b 100644 --- a/runtime/auth/aws-credentials/api/aws-credentials.api +++ b/runtime/auth/aws-credentials/api/aws-credentials.api @@ -18,9 +18,9 @@ public abstract interface class aws/smithy/kotlin/runtime/auth/awscredentials/Cl public abstract interface class aws/smithy/kotlin/runtime/auth/awscredentials/Credentials : aws/smithy/kotlin/runtime/identity/Identity { public static final field Companion Laws/smithy/kotlin/runtime/auth/awscredentials/Credentials$Companion; public abstract fun getAccessKeyId ()Ljava/lang/String; - public abstract fun getProviderName ()Ljava/lang/String; + public fun getProviderName ()Ljava/lang/String; public abstract fun getSecretAccessKey ()Ljava/lang/String; - public abstract fun getSessionToken ()Ljava/lang/String; + public fun getSessionToken ()Ljava/lang/String; } public final class aws/smithy/kotlin/runtime/auth/awscredentials/Credentials$Companion { diff --git a/runtime/auth/aws-signing-common/api/aws-signing-common.api b/runtime/auth/aws-signing-common/api/aws-signing-common.api index c9353509be..046afa8c53 100644 --- a/runtime/auth/aws-signing-common/api/aws-signing-common.api +++ b/runtime/auth/aws-signing-common/api/aws-signing-common.api @@ -199,13 +199,6 @@ public final class aws/smithy/kotlin/runtime/auth/awssigning/PresignerKt { public static final fun presignRequest (Laws/smithy/kotlin/runtime/http/request/HttpRequestBuilder;Laws/smithy/kotlin/runtime/operation/ExecutionContext;Laws/smithy/kotlin/runtime/auth/awscredentials/CredentialsProvider;Laws/smithy/kotlin/runtime/http/operation/EndpointResolver;Laws/smithy/kotlin/runtime/auth/awssigning/AwsSigner;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } -public final class aws/smithy/kotlin/runtime/auth/awssigning/UnsupportedSigningAlgorithmException : aws/smithy/kotlin/runtime/ClientException { - public fun (Ljava/lang/String;Laws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAlgorithm;)V - public fun (Ljava/lang/String;Laws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAlgorithm;Ljava/lang/Throwable;)V - public synthetic fun (Ljava/lang/String;Laws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAlgorithm;Ljava/lang/Throwable;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun getSigningAlgorithm ()Laws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAlgorithm; -} - public final class aws/smithy/kotlin/runtime/auth/awssigning/internal/AwsChunkedUtilKt { public static final field AWS_CHUNKED_THRESHOLD I public static final field CHUNK_SIZE_BYTES I diff --git a/runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningExceptions.kt b/runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningExceptions.kt deleted file mode 100644 index 8ff8c1e35f..0000000000 --- a/runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningExceptions.kt +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package aws.smithy.kotlin.runtime.auth.awssigning - -import aws.smithy.kotlin.runtime.ClientException -import aws.smithy.kotlin.runtime.InternalApi - -/** - * Is thrown when a signing algorithm is not supported by a signer - * - * See: [AwsSigningAlgorithm], [AwsSigner] - * - * @param message The message displayed by the exception - * @param signingAlgorithm The unsupported signing algorithm - * @param cause The cause of the exception - */ -@InternalApi -@Deprecated("This exception is no longer thrown. It will be removed in the next minor version, v1.5.x.") -public class UnsupportedSigningAlgorithmException( - message: String, - public val signingAlgorithm: AwsSigningAlgorithm, - cause: Throwable? = null, -) : ClientException( - message, - cause, -) { - public constructor( - message: String, - signingAlgorithm: AwsSigningAlgorithm, - ) : this ( - message, - signingAlgorithm, - null, - ) -} diff --git a/runtime/auth/aws-signing-default/common/src/aws/smithy/kotlin/runtime/auth/awssigning/SigV4aSignatureCalculator.kt b/runtime/auth/aws-signing-default/common/src/aws/smithy/kotlin/runtime/auth/awssigning/SigV4aSignatureCalculator.kt index 4f4239a2ce..c4fe4f5ca4 100644 --- a/runtime/auth/aws-signing-default/common/src/aws/smithy/kotlin/runtime/auth/awssigning/SigV4aSignatureCalculator.kt +++ b/runtime/auth/aws-signing-default/common/src/aws/smithy/kotlin/runtime/auth/awssigning/SigV4aSignatureCalculator.kt @@ -5,7 +5,7 @@ package aws.smithy.kotlin.runtime.auth.awssigning import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials -import aws.smithy.kotlin.runtime.collections.ReadThroughCache +import aws.smithy.kotlin.runtime.collections.PeriodicSweepCache import aws.smithy.kotlin.runtime.content.BigInteger import aws.smithy.kotlin.runtime.hashing.HashSupplier import aws.smithy.kotlin.runtime.hashing.Sha256 @@ -32,7 +32,7 @@ internal val N_MINUS_TWO = "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9 * @param sha256Provider the [HashSupplier] to use for computing SHA-256 hashes */ internal class SigV4aSignatureCalculator(override val sha256Provider: HashSupplier = ::Sha256) : BaseSigV4SignatureCalculator(AwsSigningAlgorithm.SIGV4_ASYMMETRIC, sha256Provider) { - private val privateKeyCache = ReadThroughCache( + private val privateKeyCache = PeriodicSweepCache( minimumSweepPeriod = 1.hours, // note: Sweeps are effectively a no-op because expiration is [Instant.MAX_VALUE] ) diff --git a/runtime/auth/aws-signing-tests/build.gradle.kts b/runtime/auth/aws-signing-tests/build.gradle.kts index 5810181041..667021303e 100644 --- a/runtime/auth/aws-signing-tests/build.gradle.kts +++ b/runtime/auth/aws-signing-tests/build.gradle.kts @@ -17,7 +17,6 @@ kotlin { api(project(":runtime:auth:http-auth-aws")) implementation(libs.kotlin.test) implementation(libs.kotlinx.coroutines.test) - implementation(libs.junit.jupiter.params) } } @@ -28,6 +27,7 @@ kotlin { implementation(libs.ktor.http.cio) implementation(libs.ktor.utils) implementation(libs.kotlin.test.junit5) + implementation(libs.junit.jupiter.params) implementation(libs.kotlinx.serialization.json) } } diff --git a/runtime/auth/aws-signing-tests/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/tests/SigningSuiteTestBaseJVM.kt b/runtime/auth/aws-signing-tests/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/tests/SigningSuiteTestBaseJVM.kt index b3c4202d26..70d48bfe45 100644 --- a/runtime/auth/aws-signing-tests/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/tests/SigningSuiteTestBaseJVM.kt +++ b/runtime/auth/aws-signing-tests/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/tests/SigningSuiteTestBaseJVM.kt @@ -11,7 +11,9 @@ import aws.smithy.kotlin.runtime.auth.awssigning.* import aws.smithy.kotlin.runtime.collections.Attributes import aws.smithy.kotlin.runtime.collections.ValuesMap import aws.smithy.kotlin.runtime.collections.get -import aws.smithy.kotlin.runtime.http.* +import aws.smithy.kotlin.runtime.http.HttpBody +import aws.smithy.kotlin.runtime.http.HttpMethod +import aws.smithy.kotlin.runtime.http.SdkHttpClient import aws.smithy.kotlin.runtime.http.auth.AwsHttpSigner import aws.smithy.kotlin.runtime.http.auth.SigV4AuthScheme import aws.smithy.kotlin.runtime.http.operation.* @@ -24,9 +26,7 @@ import aws.smithy.kotlin.runtime.net.url.Url import aws.smithy.kotlin.runtime.operation.ExecutionContext import aws.smithy.kotlin.runtime.time.Instant import io.ktor.http.cio.* -import io.ktor.util.* import io.ktor.utils.io.* -import io.ktor.utils.io.core.* import kotlinx.coroutines.runBlocking import kotlinx.io.readByteArray import kotlinx.serialization.json.* @@ -420,8 +420,7 @@ private fun buildOperation( serializeWith = object : HttpSerializer.NonStreaming { override fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder = serialized } - @Suppress("DEPRECATION") - deserializer = IdentityDeserializer + deserializeWith = HttpDeserializer.Identity context { operationName = "testSigningOperation" diff --git a/runtime/auth/http-auth-api/api/http-auth-api.api b/runtime/auth/http-auth-api/api/http-auth-api.api index 491aa48555..2f541fb6ef 100644 --- a/runtime/auth/http-auth-api/api/http-auth-api.api +++ b/runtime/auth/http-auth-api/api/http-auth-api.api @@ -1,7 +1,7 @@ public abstract interface class aws/smithy/kotlin/runtime/http/auth/AuthScheme { public abstract fun getSchemeId-DepwgT4 ()Ljava/lang/String; public abstract fun getSigner ()Laws/smithy/kotlin/runtime/http/auth/HttpSigner; - public abstract fun identityProvider (Laws/smithy/kotlin/runtime/identity/IdentityProviderConfig;)Laws/smithy/kotlin/runtime/identity/IdentityProvider; + public fun identityProvider (Laws/smithy/kotlin/runtime/identity/IdentityProviderConfig;)Laws/smithy/kotlin/runtime/identity/IdentityProvider; } public final class aws/smithy/kotlin/runtime/http/auth/AuthScheme$DefaultImpls { diff --git a/runtime/auth/http-auth-aws/common/test/aws/smithy/kotlin/runtime/http/auth/AwsHttpSignerTestBase.kt b/runtime/auth/http-auth-aws/common/test/aws/smithy/kotlin/runtime/http/auth/AwsHttpSignerTestBase.kt index 7d10b35e14..75be7cbbf2 100644 --- a/runtime/auth/http-auth-aws/common/test/aws/smithy/kotlin/runtime/http/auth/AwsHttpSignerTestBase.kt +++ b/runtime/auth/http-auth-aws/common/test/aws/smithy/kotlin/runtime/http/auth/AwsHttpSignerTestBase.kt @@ -37,13 +37,11 @@ class DefaultAwsHttpSignerTest : AwsHttpSignerTestBase(DefaultAwsSigner) * Basic sanity tests. Signing (including `AwsHttpSigner`) is covered by the more exhaustive * test suite in the `aws-signing-tests` module. */ -@Suppress("HttpUrlsUsage") public abstract class AwsHttpSignerTestBase( private val signer: AwsSigner, ) { private val testCredentials = Credentials("AKID", "SECRET", "SESSION") - @Suppress("DEPRECATION") private fun buildOperation( requestBody: String = "{\"TableName\": \"foo\"}", streaming: Boolean = false, @@ -51,29 +49,11 @@ public abstract class AwsHttpSignerTestBase( unsigned: Boolean = false, ): SdkHttpOperation { val operation: SdkHttpOperation = SdkHttpOperation.build { - serializer = object : HttpSerialize { - override suspend fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder = - HttpRequestBuilder().apply { - method = HttpMethod.POST - url.scheme = Scheme.HTTP - url.host = Host.Domain("demo.us-east-1.amazonaws.com") - url.path.encoded = "/" - headers.append("Host", "demo.us-east-1.amazonaws.com") - headers.appendAll("x-amz-archive-description", listOf("test", "test")) - body = when (streaming) { - true -> { - object : HttpBody.ChannelContent() { - override val contentLength: Long = requestBody.length.toLong() - override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(requestBody.encodeToByteArray()) - override val isOneShot: Boolean = !replayable - } - } - false -> HttpBody.fromBytes(requestBody.encodeToByteArray()) - } - headers.append("Content-Length", body.contentLength?.toString() ?: "0") - } + serializeWith = when (streaming) { + true -> StreamingSerializer(requestBody, replayable) + false -> NonStreamingSerializer(requestBody) } - deserializer = IdentityDeserializer + deserializeWith = HttpDeserializer.Identity operationName = "testSigningOperation" serviceName = "testService" context { @@ -186,3 +166,36 @@ public abstract class AwsHttpSignerTestBase( assertEquals(expectedSig, signed.headers["Authorization"]) } } + +private class NonStreamingSerializer(private val requestBody: String) : HttpSerializer.NonStreaming { + override fun serialize(context: ExecutionContext, input: Unit) = HttpRequestBuilder().apply { + method = HttpMethod.POST + url.scheme = Scheme.HTTP + url.host = Host.Domain("demo.us-east-1.amazonaws.com") + url.path.encoded = "/" + body = HttpBody.fromBytes(requestBody.encodeToByteArray()) + headers.append("Host", "demo.us-east-1.amazonaws.com") + headers.appendAll("x-amz-archive-description", listOf("test", "test")) + headers.append("Content-Length", body.contentLength?.toString() ?: "0") + } +} + +private class StreamingSerializer( + private val requestBody: String, + private val replayable: Boolean, +) : HttpSerializer.Streaming { + override suspend fun serialize(context: ExecutionContext, input: Unit) = HttpRequestBuilder().apply { + method = HttpMethod.POST + url.scheme = Scheme.HTTP + url.host = Host.Domain("demo.us-east-1.amazonaws.com") + url.path.encoded = "/" + body = object : HttpBody.ChannelContent() { + override val contentLength: Long = requestBody.length.toLong() + override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(requestBody.encodeToByteArray()) + override val isOneShot: Boolean = !replayable + } + headers.append("Host", "demo.us-east-1.amazonaws.com") + headers.appendAll("x-amz-archive-description", listOf("test", "test")) + headers.append("Content-Length", body.contentLength?.toString() ?: "0") + } +} diff --git a/runtime/auth/identity-api/api/identity-api.api b/runtime/auth/identity-api/api/identity-api.api index 9b07562934..824dc99385 100644 --- a/runtime/auth/identity-api/api/identity-api.api +++ b/runtime/auth/identity-api/api/identity-api.api @@ -56,6 +56,7 @@ public final class aws/smithy/kotlin/runtime/identity/IdentityAttributesKt { public abstract interface class aws/smithy/kotlin/runtime/identity/IdentityProvider { public abstract fun resolve (Laws/smithy/kotlin/runtime/collections/Attributes;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun resolve$default (Laws/smithy/kotlin/runtime/identity/IdentityProvider;Laws/smithy/kotlin/runtime/collections/Attributes;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } public final class aws/smithy/kotlin/runtime/identity/IdentityProvider$DefaultImpls { diff --git a/runtime/build.gradle.kts b/runtime/build.gradle.kts index 4aa9e06b3b..9da74f02a8 100644 --- a/runtime/build.gradle.kts +++ b/runtime/build.gradle.kts @@ -69,6 +69,7 @@ subprojects { freeCompilerArgs.add("-Xexpect-actual-classes") } } + tasks.withType { compilerOptions { freeCompilerArgs.add("-Xexpect-actual-classes") diff --git a/runtime/observability/telemetry-api/api/telemetry-api.api b/runtime/observability/telemetry-api/api/telemetry-api.api index 470061210b..5e1f6b8684 100644 --- a/runtime/observability/telemetry-api/api/telemetry-api.api +++ b/runtime/observability/telemetry-api/api/telemetry-api.api @@ -176,11 +176,16 @@ public abstract interface class aws/smithy/kotlin/runtime/telemetry/logging/Logg public static final field Companion Laws/smithy/kotlin/runtime/telemetry/logging/Logger$Companion; public abstract fun atLevel (Laws/smithy/kotlin/runtime/telemetry/logging/LogLevel;)Laws/smithy/kotlin/runtime/telemetry/logging/LogRecordBuilder; public abstract fun debug (Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;)V + public static synthetic fun debug$default (Laws/smithy/kotlin/runtime/telemetry/logging/Logger;Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;ILjava/lang/Object;)V public abstract fun error (Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;)V + public static synthetic fun error$default (Laws/smithy/kotlin/runtime/telemetry/logging/Logger;Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;ILjava/lang/Object;)V public abstract fun info (Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;)V + public static synthetic fun info$default (Laws/smithy/kotlin/runtime/telemetry/logging/Logger;Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;ILjava/lang/Object;)V public abstract fun isEnabledFor (Laws/smithy/kotlin/runtime/telemetry/logging/LogLevel;)Z public abstract fun trace (Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;)V + public static synthetic fun trace$default (Laws/smithy/kotlin/runtime/telemetry/logging/Logger;Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;ILjava/lang/Object;)V public abstract fun warn (Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;)V + public static synthetic fun warn$default (Laws/smithy/kotlin/runtime/telemetry/logging/Logger;Ljava/lang/Throwable;Lkotlin/jvm/functions/Function0;ILjava/lang/Object;)V } public final class aws/smithy/kotlin/runtime/telemetry/logging/Logger$Companion { @@ -266,6 +271,7 @@ public abstract class aws/smithy/kotlin/runtime/telemetry/metrics/AbstractUpDown public abstract interface class aws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurement { public abstract fun record (Ljava/lang/Number;Laws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/context/Context;)V + public static synthetic fun record$default (Laws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurement;Ljava/lang/Number;Laws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/context/Context;ILjava/lang/Object;)V } public final class aws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurement$DefaultImpls { @@ -284,6 +290,7 @@ public final class aws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurementH public abstract interface class aws/smithy/kotlin/runtime/telemetry/metrics/Histogram { public static final field Companion Laws/smithy/kotlin/runtime/telemetry/metrics/Histogram$Companion; public abstract fun record (Ljava/lang/Number;Laws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/context/Context;)V + public static synthetic fun record$default (Laws/smithy/kotlin/runtime/telemetry/metrics/Histogram;Ljava/lang/Number;Laws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/context/Context;ILjava/lang/Object;)V } public final class aws/smithy/kotlin/runtime/telemetry/metrics/Histogram$Companion { @@ -307,12 +314,19 @@ public final class aws/smithy/kotlin/runtime/telemetry/metrics/HistogramKt { public abstract interface class aws/smithy/kotlin/runtime/telemetry/metrics/Meter { public static final field Companion Laws/smithy/kotlin/runtime/telemetry/metrics/Meter$Companion; public abstract fun createAsyncUpDownCounter (Ljava/lang/String;Lkotlin/jvm/functions/Function1;Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurementHandle; + public static synthetic fun createAsyncUpDownCounter$default (Laws/smithy/kotlin/runtime/telemetry/metrics/Meter;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurementHandle; public abstract fun createDoubleGauge (Ljava/lang/String;Lkotlin/jvm/functions/Function1;Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurementHandle; + public static synthetic fun createDoubleGauge$default (Laws/smithy/kotlin/runtime/telemetry/metrics/Meter;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurementHandle; public abstract fun createDoubleHistogram (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/telemetry/metrics/Histogram; + public static synthetic fun createDoubleHistogram$default (Laws/smithy/kotlin/runtime/telemetry/metrics/Meter;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/telemetry/metrics/Histogram; public abstract fun createLongGauge (Ljava/lang/String;Lkotlin/jvm/functions/Function1;Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurementHandle; + public static synthetic fun createLongGauge$default (Laws/smithy/kotlin/runtime/telemetry/metrics/Meter;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/telemetry/metrics/AsyncMeasurementHandle; public abstract fun createLongHistogram (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/telemetry/metrics/Histogram; + public static synthetic fun createLongHistogram$default (Laws/smithy/kotlin/runtime/telemetry/metrics/Meter;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/telemetry/metrics/Histogram; public abstract fun createMonotonicCounter (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/telemetry/metrics/MonotonicCounter; + public static synthetic fun createMonotonicCounter$default (Laws/smithy/kotlin/runtime/telemetry/metrics/Meter;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/telemetry/metrics/MonotonicCounter; public abstract fun createUpDownCounter (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/telemetry/metrics/UpDownCounter; + public static synthetic fun createUpDownCounter$default (Laws/smithy/kotlin/runtime/telemetry/metrics/Meter;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/telemetry/metrics/UpDownCounter; } public final class aws/smithy/kotlin/runtime/telemetry/metrics/Meter$Companion { @@ -341,6 +355,7 @@ public final class aws/smithy/kotlin/runtime/telemetry/metrics/MeterProvider$Com public abstract interface class aws/smithy/kotlin/runtime/telemetry/metrics/MonotonicCounter { public static final field Companion Laws/smithy/kotlin/runtime/telemetry/metrics/MonotonicCounter$Companion; public abstract fun add (JLaws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/context/Context;)V + public static synthetic fun add$default (Laws/smithy/kotlin/runtime/telemetry/metrics/MonotonicCounter;JLaws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/context/Context;ILjava/lang/Object;)V } public final class aws/smithy/kotlin/runtime/telemetry/metrics/MonotonicCounter$Companion { @@ -354,6 +369,7 @@ public final class aws/smithy/kotlin/runtime/telemetry/metrics/MonotonicCounter$ public abstract interface class aws/smithy/kotlin/runtime/telemetry/metrics/UpDownCounter { public static final field Companion Laws/smithy/kotlin/runtime/telemetry/metrics/UpDownCounter$Companion; public abstract fun add (JLaws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/context/Context;)V + public static synthetic fun add$default (Laws/smithy/kotlin/runtime/telemetry/metrics/UpDownCounter;JLaws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/context/Context;ILjava/lang/Object;)V } public final class aws/smithy/kotlin/runtime/telemetry/metrics/UpDownCounter$Companion { @@ -427,9 +443,10 @@ public final class aws/smithy/kotlin/runtime/telemetry/trace/SpanStatus : java/l public abstract interface class aws/smithy/kotlin/runtime/telemetry/trace/TraceSpan : aws/smithy/kotlin/runtime/telemetry/context/Scope { public static final field Companion Laws/smithy/kotlin/runtime/telemetry/trace/TraceSpan$Companion; - public abstract fun asContextElement ()Lkotlin/coroutines/CoroutineContext; + public fun asContextElement ()Lkotlin/coroutines/CoroutineContext; public abstract fun close ()V public abstract fun emitEvent (Ljava/lang/String;Laws/smithy/kotlin/runtime/collections/Attributes;)V + public static synthetic fun emitEvent$default (Laws/smithy/kotlin/runtime/telemetry/trace/TraceSpan;Ljava/lang/String;Laws/smithy/kotlin/runtime/collections/Attributes;ILjava/lang/Object;)V public abstract fun getSpanContext ()Laws/smithy/kotlin/runtime/telemetry/trace/SpanContext; public abstract fun mergeAttributes (Laws/smithy/kotlin/runtime/collections/Attributes;)V public abstract fun set (Laws/smithy/kotlin/runtime/collections/AttributeKey;Ljava/lang/Object;)V @@ -468,6 +485,7 @@ public final class aws/smithy/kotlin/runtime/telemetry/trace/TraceSpanExtKt { public abstract interface class aws/smithy/kotlin/runtime/telemetry/trace/Tracer { public static final field Companion Laws/smithy/kotlin/runtime/telemetry/trace/Tracer$Companion; public abstract fun createSpan (Ljava/lang/String;Laws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/trace/SpanKind;Laws/smithy/kotlin/runtime/telemetry/context/Context;)Laws/smithy/kotlin/runtime/telemetry/trace/TraceSpan; + public static synthetic fun createSpan$default (Laws/smithy/kotlin/runtime/telemetry/trace/Tracer;Ljava/lang/String;Laws/smithy/kotlin/runtime/collections/Attributes;Laws/smithy/kotlin/runtime/telemetry/trace/SpanKind;Laws/smithy/kotlin/runtime/telemetry/context/Context;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/telemetry/trace/TraceSpan; } public final class aws/smithy/kotlin/runtime/telemetry/trace/Tracer$Companion { diff --git a/runtime/protocol/aws-json-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/json/AwsJsonProtocolTest.kt b/runtime/protocol/aws-json-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/json/AwsJsonProtocolTest.kt index ac53fc26a8..172d831283 100644 --- a/runtime/protocol/aws-json-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/json/AwsJsonProtocolTest.kt +++ b/runtime/protocol/aws-json-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/json/AwsJsonProtocolTest.kt @@ -6,8 +6,10 @@ package aws.smithy.kotlin.runtime.awsprotocol.json import aws.smithy.kotlin.runtime.collections.get -import aws.smithy.kotlin.runtime.http.* +import aws.smithy.kotlin.runtime.http.HttpBody +import aws.smithy.kotlin.runtime.http.SdkHttpClient import aws.smithy.kotlin.runtime.http.operation.* +import aws.smithy.kotlin.runtime.http.readAll import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder import aws.smithy.kotlin.runtime.http.response.HttpResponse import aws.smithy.kotlin.runtime.httptest.TestEngine @@ -20,10 +22,9 @@ class AwsJsonProtocolTest { @Test fun testSetJsonProtocolHeaders() = runTest { - @Suppress("DEPRECATION") val op = SdkHttpOperation.build { - serializer = UnitSerializer - deserializer = IdentityDeserializer + serializeWith = HttpSerializer.Unit + deserializeWith = HttpDeserializer.Identity operationName = "Bar" serviceName = "Foo" } @@ -42,10 +43,9 @@ class AwsJsonProtocolTest { @Test fun testEmptyBody() = runTest { - @Suppress("DEPRECATION") val op = SdkHttpOperation.build { - serializer = UnitSerializer - deserializer = IdentityDeserializer + serializeWith = HttpSerializer.Unit + deserializeWith = HttpDeserializer.Identity operationName = "Bar" serviceName = "Foo" } @@ -63,14 +63,14 @@ class AwsJsonProtocolTest { fun testDoesNotOverride() = runTest { @Suppress("DEPRECATION") val op = SdkHttpOperation.build { - serializer = object : HttpSerialize { - override suspend fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder = + serializeWith = object : HttpSerializer.NonStreaming { + override fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder = HttpRequestBuilder().apply { headers["Content-Type"] = "application/xml" body = HttpBody.fromBytes("foo".encodeToByteArray()) } } - deserializer = IdentityDeserializer + deserializeWith = HttpDeserializer.Identity operationName = "Bar" serviceName = "Foo" } diff --git a/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api b/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api index 7325432724..ab7db70f5a 100644 --- a/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api +++ b/runtime/protocol/aws-xml-protocols/api/aws-xml-protocols.api @@ -1,10 +1,8 @@ public final class aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializerKt { - public static final fun parseEc2QueryErrorResponse ([BLkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static final fun parseEc2QueryErrorResponseNoSuspend ([B)Laws/smithy/kotlin/runtime/awsprotocol/ErrorDetails; + public static final fun parseEc2QueryErrorResponse ([B)Laws/smithy/kotlin/runtime/awsprotocol/ErrorDetails; } public final class aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializerKt { - public static final fun parseRestXmlErrorResponse ([BLkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static final fun parseRestXmlErrorResponseNoSuspend ([B)Laws/smithy/kotlin/runtime/awsprotocol/ErrorDetails; + public static final fun parseRestXmlErrorResponse ([B)Laws/smithy/kotlin/runtime/awsprotocol/ErrorDetails; } diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt index 63391e1b12..2bd349177c 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializer.kt @@ -7,18 +7,16 @@ package aws.smithy.kotlin.runtime.awsprotocol.xml import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.awsprotocol.ErrorDetails import aws.smithy.kotlin.runtime.serde.getOrDeserializeErr -import aws.smithy.kotlin.runtime.serde.xml.* +import aws.smithy.kotlin.runtime.serde.xml.XmlTagReader +import aws.smithy.kotlin.runtime.serde.xml.data +import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader internal data class Ec2QueryErrorResponse(val errors: List, val requestId: String?) internal data class Ec2QueryError(val code: String?, val message: String?) -@Deprecated("use parseEc2QueryErrorResponseNoSuspend") @InternalApi -public suspend fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails = - parseEc2QueryErrorResponseNoSuspend(payload) - -public fun parseEc2QueryErrorResponseNoSuspend(payload: ByteArray): ErrorDetails { +public fun parseEc2QueryErrorResponse(payload: ByteArray): ErrorDetails { val response = Ec2QueryErrorResponseDeserializer.deserialize(xmlTagReader(payload)) val firstError = response.errors.firstOrNull() return ErrorDetails(firstError?.code, firstError?.message, response.requestId) diff --git a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt index 99925c0dc1..74a9448146 100644 --- a/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt +++ b/runtime/protocol/aws-xml-protocols/common/src/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializer.kt @@ -6,7 +6,7 @@ package aws.smithy.kotlin.runtime.awsprotocol.xml import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.awsprotocol.ErrorDetails -import aws.smithy.kotlin.runtime.serde.* +import aws.smithy.kotlin.runtime.serde.getOrDeserializeErr import aws.smithy.kotlin.runtime.serde.xml.XmlTagReader import aws.smithy.kotlin.runtime.serde.xml.data import aws.smithy.kotlin.runtime.serde.xml.xmlTagReader @@ -26,18 +26,8 @@ internal data class XmlError( override val message: String?, ) : RestXmlErrorDetails -/** - * Deserializes rest XML protocol errors as specified by: - * https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#error-response-serialization - * - * Returns parsed data in normalized form or throws [DeserializationException] if response cannot be parsed. - */ -@Deprecated("use parseRestXmlErrorResponseNoSuspend") @InternalApi -public suspend fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails = - parseRestXmlErrorResponseNoSuspend(payload) - -public fun parseRestXmlErrorResponseNoSuspend(payload: ByteArray): ErrorDetails { +public fun parseRestXmlErrorResponse(payload: ByteArray): ErrorDetails { val details = XmlErrorDeserializer.deserialize(xmlTagReader(payload)) return ErrorDetails(details.code, details.message, details.requestId) } diff --git a/runtime/protocol/aws-xml-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializerTest.kt b/runtime/protocol/aws-xml-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializerTest.kt index 6dfb531a5a..d2b4a6972e 100644 --- a/runtime/protocol/aws-xml-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializerTest.kt +++ b/runtime/protocol/aws-xml-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/xml/Ec2QueryErrorDeserializerTest.kt @@ -26,7 +26,7 @@ class Ec2QueryErrorDeserializerTest { foo-request """.trimIndent().encodeToByteArray() - val actual = parseEc2QueryErrorResponseNoSuspend(payload) + val actual = parseEc2QueryErrorResponse(payload) assertEquals("InvalidGreeting", actual.code) assertEquals("Hi", actual.message) assertEquals("foo-request", actual.requestId) @@ -61,7 +61,7 @@ class Ec2QueryErrorDeserializerTest { for (payload in tests) { assertFailsWith { - parseEc2QueryErrorResponseNoSuspend(payload) + parseEc2QueryErrorResponse(payload) } } } @@ -90,7 +90,7 @@ class Ec2QueryErrorDeserializerTest { ).map { it.trimIndent().encodeToByteArray() } for (payload in tests) { - val actual = parseEc2QueryErrorResponseNoSuspend(payload) + val actual = parseEc2QueryErrorResponse(payload) assertNull(actual.code) assertNull(actual.message) assertEquals("foo-request", actual.requestId) diff --git a/runtime/protocol/aws-xml-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializerTest.kt b/runtime/protocol/aws-xml-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializerTest.kt index 9b94c7c31c..1f08e7bd3c 100644 --- a/runtime/protocol/aws-xml-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializerTest.kt +++ b/runtime/protocol/aws-xml-protocols/common/test/aws/smithy/kotlin/runtime/awsprotocol/xml/RestXmlErrorDeserializerTest.kt @@ -6,7 +6,10 @@ package aws.smithy.kotlin.runtime.awsprotocol.xml import aws.smithy.kotlin.runtime.serde.DeserializationException import kotlinx.coroutines.test.runTest -import kotlin.test.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNull class RestXmlErrorDeserializerTest { @@ -36,7 +39,7 @@ class RestXmlErrorDeserializerTest { ) for (payload in tests) { - val actual = parseRestXmlErrorResponseNoSuspend(payload) + val actual = parseRestXmlErrorResponse(payload) assertEquals("InvalidGreeting", actual.code) assertEquals("Hi", actual.message) assertEquals("foo-id", actual.requestId) @@ -70,7 +73,7 @@ class RestXmlErrorDeserializerTest { for (payload in tests) { assertFailsWith { - parseRestXmlErrorResponseNoSuspend(payload) + parseRestXmlErrorResponse(payload) } } } @@ -92,7 +95,7 @@ class RestXmlErrorDeserializerTest { ) for (payload in tests) { - val error = parseRestXmlErrorResponseNoSuspend(payload) + val error = parseRestXmlErrorResponse(payload) assertEquals("foo-id", error.requestId) assertNull(error.code) assertNull(error.message) diff --git a/runtime/protocol/http-client-engines/http-client-engine-okhttp/api/http-client-engine-okhttp.api b/runtime/protocol/http-client-engines/http-client-engine-okhttp/api/http-client-engine-okhttp.api index 25c2339550..77e6120478 100644 --- a/runtime/protocol/http-client-engines/http-client-engine-okhttp/api/http-client-engine-okhttp.api +++ b/runtime/protocol/http-client-engines/http-client-engine-okhttp/api/http-client-engine-okhttp.api @@ -87,7 +87,7 @@ public final class aws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngineConf } public final class aws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngineKt { - public static final fun buildClient (Laws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngineConfig;Laws/smithy/kotlin/runtime/http/engine/internal/HttpClientMetrics;)Lokhttp3/OkHttpClient; + public static final fun buildClient (Laws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngineConfig;Laws/smithy/kotlin/runtime/http/engine/internal/HttpClientMetrics;[Lokhttp3/EventListener;)Lokhttp3/OkHttpClient; } public final class aws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpHeadersAdapter : aws/smithy/kotlin/runtime/http/Headers { diff --git a/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/ConnectionIdleMonitor.kt b/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/ConnectionMonitoringEventListener.kt similarity index 78% rename from runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/ConnectionIdleMonitor.kt rename to runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/ConnectionMonitoringEventListener.kt index 3f4c366f70..131466e66e 100644 --- a/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/ConnectionIdleMonitor.kt +++ b/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/ConnectionMonitoringEventListener.kt @@ -4,12 +4,10 @@ */ package aws.smithy.kotlin.runtime.http.engine.okhttp +import aws.smithy.kotlin.runtime.io.Closeable import aws.smithy.kotlin.runtime.telemetry.logging.logger import kotlinx.coroutines.* -import okhttp3.Call -import okhttp3.Connection -import okhttp3.ConnectionListener -import okhttp3.ExperimentalOkHttpApi +import okhttp3.* import okhttp3.internal.closeQuietly import okio.IOException import okio.buffer @@ -22,12 +20,20 @@ import kotlin.coroutines.coroutineContext import kotlin.time.Duration import kotlin.time.measureTime -@OptIn(ExperimentalOkHttpApi::class) -internal class ConnectionIdleMonitor(val pollInterval: Duration) : ConnectionListener() { +/** + * An [okhttp3.EventListener] implementation that monitors connections for remote closure. + * This replaces the functionality previously provided by the now-internal [okhttp3.ConnectionListener]. + */ +internal class ConnectionMonitoringEventListener(private val pollInterval: Duration) : + EventListener(), + Closeable { private val monitorScope = CoroutineScope(Dispatchers.IO + SupervisorJob()) private val monitors = ConcurrentHashMap() - fun close(): Unit = runBlocking { + /** + * Close all active connection monitors. + */ + override fun close(): Unit = runBlocking { val monitorJob = requireNotNull(monitorScope.coroutineContext[Job]) { "Connection idle monitor scope cannot be cancelled because it does not have a job: $this" } @@ -40,13 +46,16 @@ internal class ConnectionIdleMonitor(val pollInterval: Duration) : ConnectionLis ?.callContext ?: Dispatchers.IO - override fun connectionAcquired(connection: Connection, call: Call) { + // Cancel monitoring when a connection is acquired + override fun connectionAcquired(call: Call, connection: Connection) { + super.connectionAcquired(call, connection) + // Non-locking map access is okay here because this code will only execute synchronously as part of a // `connectionAcquired` event and will be complete before any future `connectionReleased` event could fire for // the same connection. monitors.remove(connection)?.let { monitor -> val context = call.callContext() - val logger = context.logger() + val logger = context.logger() logger.trace { "Cancel monitoring for $connection" } // Use `runBlocking` because this _must_ finish before OkHttp goes to use the connection @@ -58,13 +67,18 @@ internal class ConnectionIdleMonitor(val pollInterval: Duration) : ConnectionLis } } - override fun connectionReleased(connection: Connection, call: Call) { + // Start monitoring when a connection is released + override fun connectionReleased(call: Call, connection: Connection) { + super.connectionReleased(call, connection) + val connId = System.identityHashCode(connection) val callContext = call.callContext() + + // Start monitoring val monitor = monitorScope.launch(CoroutineName("okhttp-conn-monitor-for-$connId")) { doMonitor(connection, callContext) } - callContext.logger().trace { "Launched coroutine $monitor to monitor $connection" } + callContext.logger().trace { "Launched coroutine $monitor to monitor $connection" } // Non-locking map access is okay here because this code will only execute synchronously as part of a // `connectionReleased` event and will be complete before any future `connectionAcquired` event could fire for @@ -73,7 +87,7 @@ internal class ConnectionIdleMonitor(val pollInterval: Duration) : ConnectionLis } private suspend fun doMonitor(conn: Connection, callContext: CoroutineContext) { - val logger = callContext.logger() + val logger = callContext.logger() val socket = conn.socket() val source = try { diff --git a/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/EventListenerChain.kt b/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/EventListenerChain.kt new file mode 100644 index 0000000000..c18bd331f3 --- /dev/null +++ b/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/EventListenerChain.kt @@ -0,0 +1,115 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.http.engine.okhttp + +import aws.smithy.kotlin.runtime.io.closeIfCloseable +import okhttp3.* +import java.io.IOException +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Proxy + +/** + * An [okhttp3.EventListener] that delegates to a chain of EventListeners. + * Start events are sent in forward order, terminal events are sent in reverse order + */ +internal class EventListenerChain( + private val listeners: List, +) : EventListener() { + private val reverseListeners = listeners.reversed() + + fun close() { + listeners.forEach { + it.closeIfCloseable() + } + } + + override fun callStart(call: Call): Unit = + listeners.forEach { it.callStart(call) } + + override fun dnsStart(call: Call, domainName: String): Unit = + listeners.forEach { it.dnsStart(call, domainName) } + + override fun dnsEnd(call: Call, domainName: String, inetAddressList: List): Unit = + reverseListeners.forEach { it.dnsEnd(call, domainName, inetAddressList) } + + override fun proxySelectStart(call: Call, url: HttpUrl): Unit = + listeners.forEach { it.proxySelectStart(call, url) } + + override fun proxySelectEnd(call: Call, url: HttpUrl, proxies: List): Unit = + reverseListeners.forEach { it.proxySelectEnd(call, url, proxies) } + + override fun connectStart(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy): Unit = + listeners.forEach { it.connectStart(call, inetSocketAddress, proxy) } + + override fun secureConnectStart(call: Call): Unit = + listeners.forEach { it.secureConnectStart(call) } + + override fun secureConnectEnd(call: Call, handshake: Handshake?): Unit = + reverseListeners.forEach { it.secureConnectEnd(call, handshake) } + + override fun connectEnd(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, protocol: Protocol?): Unit = + reverseListeners.forEach { it.connectEnd(call, inetSocketAddress, proxy, protocol) } + + override fun connectFailed(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, protocol: Protocol?, ioe: IOException): Unit = + reverseListeners.forEach { it.connectFailed(call, inetSocketAddress, proxy, protocol, ioe) } + + override fun connectionAcquired(call: Call, connection: Connection): Unit = + listeners.forEach { it.connectionAcquired(call, connection) } + + override fun connectionReleased(call: Call, connection: Connection): Unit = + reverseListeners.forEach { it.connectionReleased(call, connection) } + + override fun requestHeadersStart(call: Call): Unit = + listeners.forEach { it.requestHeadersStart(call) } + + override fun requestHeadersEnd(call: Call, request: Request): Unit = + reverseListeners.forEach { it.requestHeadersEnd(call, request) } + + override fun requestBodyStart(call: Call): Unit = + listeners.forEach { it.requestBodyStart(call) } + + override fun requestBodyEnd(call: Call, byteCount: Long): Unit = + reverseListeners.forEach { it.requestBodyEnd(call, byteCount) } + + override fun requestFailed(call: Call, ioe: IOException): Unit = + reverseListeners.forEach { it.requestFailed(call, ioe) } + + override fun responseHeadersStart(call: Call): Unit = + listeners.forEach { it.responseHeadersStart(call) } + + override fun responseHeadersEnd(call: Call, response: Response): Unit = + reverseListeners.forEach { it.responseHeadersEnd(call, response) } + + override fun responseBodyStart(call: Call): Unit = + listeners.forEach { it.responseBodyStart(call) } + + override fun responseBodyEnd(call: Call, byteCount: Long): Unit = + reverseListeners.forEach { it.responseBodyEnd(call, byteCount) } + + override fun responseFailed(call: Call, ioe: IOException): Unit = + reverseListeners.forEach { it.responseFailed(call, ioe) } + + override fun callEnd(call: Call): Unit = + reverseListeners.forEach { it.callEnd(call) } + + override fun callFailed(call: Call, ioe: IOException): Unit = + reverseListeners.forEach { it.callFailed(call, ioe) } + + override fun canceled(call: Call): Unit = + reverseListeners.forEach { it.canceled(call) } + + override fun satisfactionFailure(call: Call, response: Response): Unit = + reverseListeners.forEach { it.satisfactionFailure(call, response) } + + override fun cacheConditionalHit(call: Call, cachedResponse: Response): Unit = + listeners.forEach { it.cacheConditionalHit(call, cachedResponse) } + + override fun cacheHit(call: Call, response: Response): Unit = + listeners.forEach { it.cacheHit(call, response) } + + override fun cacheMiss(call: Call): Unit = + listeners.forEach { it.cacheMiss(call) } +} diff --git a/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngine.kt b/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngine.kt index a20387d0a0..297e015eed 100644 --- a/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngine.kt +++ b/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngine.kt @@ -8,9 +8,13 @@ package aws.smithy.kotlin.runtime.http.engine.okhttp import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.http.HttpCall import aws.smithy.kotlin.runtime.http.config.EngineFactory -import aws.smithy.kotlin.runtime.http.engine.* +import aws.smithy.kotlin.runtime.http.engine.AlpnId +import aws.smithy.kotlin.runtime.http.engine.HttpClientEngineBase +import aws.smithy.kotlin.runtime.http.engine.TlsContext +import aws.smithy.kotlin.runtime.http.engine.callContext import aws.smithy.kotlin.runtime.http.engine.internal.HttpClientMetrics import aws.smithy.kotlin.runtime.http.request.HttpRequest +import aws.smithy.kotlin.runtime.io.closeIfCloseable import aws.smithy.kotlin.runtime.net.TlsVersion import aws.smithy.kotlin.runtime.operation.ExecutionContext import aws.smithy.kotlin.runtime.time.Instant @@ -18,7 +22,6 @@ import aws.smithy.kotlin.runtime.time.fromEpochMilliseconds import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.job import okhttp3.* -import okhttp3.ConnectionPool import okhttp3.coroutines.executeAsync import java.util.concurrent.TimeUnit import kotlin.time.toJavaDuration @@ -44,9 +47,14 @@ public class OkHttpEngine( override val engineConstructor: (OkHttpEngineConfig.Builder.() -> Unit) -> OkHttpEngine = ::invoke } + // Create a single shared connection monitoring listener if idle polling is enabled + private val connectionMonitoringListener: EventListener? = + config.connectionIdlePollingInterval?.let { + ConnectionMonitoringEventListener(it) + } + private val metrics = HttpClientMetrics(TELEMETRY_SCOPE, config.telemetryProvider) - private val connectionIdleMonitor = config.connectionIdlePollingInterval?.let { ConnectionIdleMonitor(it) } - private val client = config.buildClientWithConnectionListener(metrics, connectionIdleMonitor) + private val client = config.buildClient(metrics, connectionMonitoringListener) @OptIn(ExperimentalCoroutinesApi::class) override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { @@ -73,16 +81,20 @@ public class OkHttpEngine( } override fun shutdown() { - connectionIdleMonitor?.close() + connectionMonitoringListener?.closeIfCloseable() client.connectionPool.evictAll() client.dispatcher.executorService.shutdown() metrics.close() } } -private fun OkHttpEngineConfig.buildClientFromConfig( +/** + * Convert SDK version of HTTP configuration to OkHttp specific configuration and return the configured client + */ +@InternalApi +public fun OkHttpEngineConfig.buildClient( metrics: HttpClientMetrics, - poolOverride: ConnectionPool? = null, + vararg clientScopedEventListeners: EventListener?, ): OkHttpClient { val config = this @@ -102,7 +114,7 @@ private fun OkHttpEngineConfig.buildClientFromConfig( writeTimeout(config.socketWriteTimeout.toJavaDuration()) // use our own pool configured with the timeout settings taken from config - val pool = poolOverride ?: ConnectionPool( + val pool = ConnectionPool( maxIdleConnections = 5, // The default from the no-arg ConnectionPool() constructor keepAliveDuration = config.connectionIdleTimeout.inWholeMilliseconds, TimeUnit.MILLISECONDS, @@ -116,7 +128,14 @@ private fun OkHttpEngineConfig.buildClientFromConfig( dispatcher(dispatcher) // Log events coming from okhttp. Allocate a new listener per-call to facilitate dedicated trace spans. - eventListenerFactory { call -> HttpEngineEventListener(pool, config.hostResolver, dispatcher, metrics, call) } + eventListenerFactory { call -> + EventListenerChain( + listOfNotNull( + HttpEngineEventListener(pool, config.hostResolver, dispatcher, metrics, call), + *clientScopedEventListeners, + ), + ) + } // map protocols if (config.tlsContext.alpn.isNotEmpty()) { @@ -140,34 +159,6 @@ private fun OkHttpEngineConfig.buildClientFromConfig( }.build() } -/** - * Convert SDK version of HTTP configuration to OkHttp specific configuration and return the configured client - */ -// Used by OkHttp4Engine - OkHttp4 does NOT have `connectionListener` -// TODO - Refactor in next minor version - Move this to OkHttp4Engine and make it private -@InternalApi -public fun OkHttpEngineConfig.buildClient( - metrics: HttpClientMetrics, -): OkHttpClient = this.buildClientFromConfig(metrics) - -/** - * Convert SDK version of HTTP configuration to OkHttp specific configuration and return the configured client - */ -// Used by OkHttpEngine - OkHttp5 does have `connectionListener` -@OptIn(ExperimentalOkHttpApi::class) -private fun OkHttpEngineConfig.buildClientWithConnectionListener( - metrics: HttpClientMetrics, - connectionListener: ConnectionIdleMonitor?, -): OkHttpClient = this.buildClientFromConfig( - metrics, - ConnectionPool( - maxIdleConnections = 5, // The default from the no-arg ConnectionPool() constructor - keepAliveDuration = this.connectionIdleTimeout.inWholeMilliseconds, - timeUnit = TimeUnit.MILLISECONDS, - connectionListener = connectionListener ?: ConnectionListener.NONE, - ), -) - private fun minTlsConnectionSpec(tlsContext: TlsContext): ConnectionSpec { val minVersion = tlsContext.minVersion ?: TlsVersion.TLS_1_2 val okHttpTlsVersions = SdkTlsVersion diff --git a/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/test/aws/smithy/kotlin/runtime/http/engine/okhttp/EventListenerChainTest.kt b/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/test/aws/smithy/kotlin/runtime/http/engine/okhttp/EventListenerChainTest.kt new file mode 100644 index 0000000000..c9acb07cef --- /dev/null +++ b/runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/test/aws/smithy/kotlin/runtime/http/engine/okhttp/EventListenerChainTest.kt @@ -0,0 +1,281 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.http.engine.okhttp + +import okhttp3.* +import java.io.Closeable +import java.io.IOException +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Proxy +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class EventListenerChainTest { + @Test + fun testForwardEvents() { + val eventOrder = mutableListOf() + + val listener1 = TestEventListener("listener1", eventOrder) + val listener2 = TestEventListener("listener2", eventOrder) + + val chain = EventListenerChain(listOf(listener1, listener2)) + + val call = createMockCall() + + // Test forward events + chain.callStart(call) + chain.dnsStart(call, "example.com") + chain.proxySelectStart(call, createHttpUrl()) + + // Verify forward events were called in order (listener1 first, then listener2) + assertEquals("listener1:callStart", eventOrder[0]) + assertEquals("listener2:callStart", eventOrder[1]) + assertEquals("listener1:dnsStart", eventOrder[2]) + assertEquals("listener2:dnsStart", eventOrder[3]) + assertEquals("listener1:proxySelectStart", eventOrder[4]) + assertEquals("listener2:proxySelectStart", eventOrder[5]) + } + + @Test + fun testReverseEvents() { + val eventOrder = mutableListOf() + + val listener1 = TestEventListener("listener1", eventOrder) + val listener2 = TestEventListener("listener2", eventOrder) + + val chain = EventListenerChain(listOf(listener1, listener2)) + + val call = createMockCall() + + // Test reverse events + chain.dnsEnd(call, "example.com", listOf()) + chain.proxySelectEnd(call, createHttpUrl(), listOf()) + chain.callEnd(call) + + // Verify reverse events were called in reverse order (listener2 first, then listener1) + assertEquals("listener2:dnsEnd", eventOrder[0]) + assertEquals("listener1:dnsEnd", eventOrder[1]) + assertEquals("listener2:proxySelectEnd", eventOrder[2]) + assertEquals("listener1:proxySelectEnd", eventOrder[3]) + assertEquals("listener2:callEnd", eventOrder[4]) + assertEquals("listener1:callEnd", eventOrder[5]) + } + + @Test + fun testClose() { + val eventOrder = mutableListOf() + + val listener1 = TestEventListener("listener1", eventOrder) + val listener2 = TestEventListener("listener2", eventOrder) + + val chain = EventListenerChain(listOf(listener1, listener2)) + + // Close the chain + chain.close() + + // Verify all listeners were closed + assertTrue(listener1.closed) + assertTrue(listener2.closed) + } + + @Test + fun testMixedEvents() { + val eventOrder = mutableListOf() + + val listener1 = TestEventListener("listener1", eventOrder) + val listener2 = TestEventListener("listener2", eventOrder) + + val chain = EventListenerChain(listOf(listener1, listener2)) + + val call = createMockCall() + + // Test mixed forward and reverse events + chain.callStart(call) + chain.dnsStart(call, "example.com") + chain.dnsEnd(call, "example.com", listOf()) + + // Verify the order of events + assertEquals("listener1:callStart", eventOrder[0]) // listener1 first (forward) + assertEquals("listener2:callStart", eventOrder[1]) // listener2 second (forward) + assertEquals("listener1:dnsStart", eventOrder[2]) // listener1 first (forward) + assertEquals("listener2:dnsStart", eventOrder[3]) // listener2 second (forward) + assertEquals("listener2:dnsEnd", eventOrder[4]) // listener2 first (reverse) + assertEquals("listener1:dnsEnd", eventOrder[5]) // listener1 second (reverse) + + // Clear event order + eventOrder.clear() + + // Test more events to verify the sequence + chain.requestHeadersStart(call) // forward event + chain.requestHeadersEnd(call, Request.Builder().url("https://example.com").build()) // reverse event + chain.responseHeadersStart(call) // forward event + chain.responseHeadersEnd( + call, + Response.Builder() + .request(Request.Builder().url("https://example.com").build()) + .protocol(Protocol.HTTP_2) + .code(200) + .message("OK") + .build(), + ) // reverse event + + // Verify the sequence of events + assertEquals("listener1:requestHeadersStart", eventOrder[0]) // listener1 first (forward) + assertEquals("listener2:requestHeadersStart", eventOrder[1]) // listener2 second (forward) + assertEquals("listener2:requestHeadersEnd", eventOrder[2]) // listener2 first (reverse) + assertEquals("listener1:requestHeadersEnd", eventOrder[3]) // listener1 second (reverse) + assertEquals("listener1:responseHeadersStart", eventOrder[4]) // listener1 first (forward) + assertEquals("listener2:responseHeadersStart", eventOrder[5]) // listener2 second (forward) + assertEquals("listener2:responseHeadersEnd", eventOrder[6]) // listener2 first (reverse) + assertEquals("listener1:responseHeadersEnd", eventOrder[7]) // listener1 second (reverse) + } + + // A test EventListener that records the order of calls + private class TestEventListener(val name: String, val eventOrder: MutableList) : + EventListener(), + Closeable { + var closed = false + + override fun callStart(call: Call) { + eventOrder.add("$name:callStart") + } + + override fun dnsStart(call: Call, domainName: String) { + eventOrder.add("$name:dnsStart") + } + + override fun dnsEnd(call: Call, domainName: String, inetAddressList: List) { + eventOrder.add("$name:dnsEnd") + } + + override fun proxySelectStart(call: Call, url: HttpUrl) { + eventOrder.add("$name:proxySelectStart") + } + + override fun proxySelectEnd(call: Call, url: HttpUrl, proxies: List) { + eventOrder.add("$name:proxySelectEnd") + } + + override fun connectStart(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy) { + eventOrder.add("$name:connectStart") + } + + override fun secureConnectStart(call: Call) { + eventOrder.add("$name:secureConnectStart") + } + + override fun secureConnectEnd(call: Call, handshake: Handshake?) { + eventOrder.add("$name:secureConnectEnd") + } + + override fun connectEnd(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, protocol: Protocol?) { + eventOrder.add("$name:connectEnd") + } + + override fun connectFailed(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy, protocol: Protocol?, ioe: IOException) { + eventOrder.add("$name:connectFailed") + } + + override fun connectionAcquired(call: Call, connection: Connection) { + eventOrder.add("$name:connectionAcquired") + } + + override fun connectionReleased(call: Call, connection: Connection) { + eventOrder.add("$name:connectionReleased") + } + + override fun requestHeadersStart(call: Call) { + eventOrder.add("$name:requestHeadersStart") + } + + override fun requestHeadersEnd(call: Call, request: Request) { + eventOrder.add("$name:requestHeadersEnd") + } + + override fun requestBodyStart(call: Call) { + eventOrder.add("$name:requestBodyStart") + } + + override fun requestBodyEnd(call: Call, byteCount: Long) { + eventOrder.add("$name:requestBodyEnd") + } + + override fun requestFailed(call: Call, ioe: IOException) { + eventOrder.add("$name:requestFailed") + } + + override fun responseHeadersStart(call: Call) { + eventOrder.add("$name:responseHeadersStart") + } + + override fun responseHeadersEnd(call: Call, response: Response) { + eventOrder.add("$name:responseHeadersEnd") + } + + override fun responseBodyStart(call: Call) { + eventOrder.add("$name:responseBodyStart") + } + + override fun responseBodyEnd(call: Call, byteCount: Long) { + eventOrder.add("$name:responseBodyEnd") + } + + override fun responseFailed(call: Call, ioe: IOException) { + eventOrder.add("$name:responseFailed") + } + + override fun callEnd(call: Call) { + eventOrder.add("$name:callEnd") + } + + override fun callFailed(call: Call, ioe: IOException) { + eventOrder.add("$name:callFailed") + } + + override fun canceled(call: Call) { + eventOrder.add("$name:canceled") + } + + override fun satisfactionFailure(call: Call, response: Response) { + eventOrder.add("$name:satisfactionFailure") + } + + override fun cacheConditionalHit(call: Call, cachedResponse: Response) { + eventOrder.add("$name:cacheConditionalHit") + } + + override fun cacheHit(call: Call, response: Response) { + eventOrder.add("$name:cacheHit") + } + + override fun cacheMiss(call: Call) { + eventOrder.add("$name:cacheMiss") + } + + override fun close() { + closed = true + } + } + + // Helper methods to create mock objects + private fun createMockCall(): Call = object : Call { + override fun cancel() {} + override fun clone(): Call = this + override fun enqueue(responseCallback: Callback) {} + override fun execute(): Response = throw UnsupportedOperationException() + override fun isCanceled(): Boolean = false + override fun isExecuted(): Boolean = false + override fun request(): Request = Request.Builder().url("https://example.com").build() + override fun timeout(): okio.Timeout = okio.Timeout() + } + + private fun createHttpUrl(): HttpUrl = HttpUrl.Builder() + .scheme("https") + .host("example.com") + .build() +} diff --git a/runtime/protocol/http-client/api/http-client.api b/runtime/protocol/http-client/api/http-client.api index 9a752b6788..1d8deaa751 100644 --- a/runtime/protocol/http-client/api/http-client.api +++ b/runtime/protocol/http-client/api/http-client.api @@ -29,6 +29,8 @@ public abstract interface class aws/smithy/kotlin/runtime/http/config/HttpEngine public abstract fun getHttpClient ()Laws/smithy/kotlin/runtime/http/engine/HttpClientEngine; public abstract fun httpClient (Laws/smithy/kotlin/runtime/http/config/EngineFactory;Lkotlin/jvm/functions/Function1;)V public abstract fun httpClient (Lkotlin/jvm/functions/Function1;)V + public static synthetic fun httpClient$default (Laws/smithy/kotlin/runtime/http/config/HttpEngineConfig$Builder;Laws/smithy/kotlin/runtime/http/config/EngineFactory;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun httpClient$default (Laws/smithy/kotlin/runtime/http/config/HttpEngineConfig$Builder;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public abstract fun setHttpClient (Laws/smithy/kotlin/runtime/http/engine/HttpClientEngine;)V } @@ -309,7 +311,7 @@ public final class aws/smithy/kotlin/runtime/http/interceptors/ContinueIntercept } public final class aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpointErrorInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { - public fun (Lkotlin/reflect/KClass;Lkotlin/jvm/functions/Function1;)V + public fun (Lkotlin/reflect/KClass;Lkotlin/jvm/functions/Function2;)V public fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -495,11 +497,13 @@ public abstract interface class aws/smithy/kotlin/runtime/http/operation/Endpoin public abstract fun resolve (Laws/smithy/kotlin/runtime/http/operation/ResolveEndpointRequest;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } -public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpDeserialize { - public abstract fun deserialize (Laws/smithy/kotlin/runtime/operation/ExecutionContext;Laws/smithy/kotlin/runtime/http/HttpCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpDeserializer { + public static final field Companion Laws/smithy/kotlin/runtime/http/operation/HttpDeserializer$Companion; } -public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpDeserializer { +public final class aws/smithy/kotlin/runtime/http/operation/HttpDeserializer$Companion { + public final fun getIdentity ()Laws/smithy/kotlin/runtime/http/operation/HttpDeserializer; + public final fun getUnit ()Laws/smithy/kotlin/runtime/http/operation/HttpDeserializer; } public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpDeserializer$NonStreaming : aws/smithy/kotlin/runtime/http/operation/HttpDeserializer { @@ -523,11 +527,12 @@ public final class aws/smithy/kotlin/runtime/http/operation/HttpOperationContext public final fun getSdkInvocationId ()Laws/smithy/kotlin/runtime/collections/AttributeKey; } -public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpSerialize { - public abstract fun serialize (Laws/smithy/kotlin/runtime/operation/ExecutionContext;Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpSerializer { + public static final field Companion Laws/smithy/kotlin/runtime/http/operation/HttpSerializer$Companion; } -public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpSerializer { +public final class aws/smithy/kotlin/runtime/http/operation/HttpSerializer$Companion { + public final fun getUnit ()Laws/smithy/kotlin/runtime/http/operation/HttpSerializer; } public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpSerializer$NonStreaming : aws/smithy/kotlin/runtime/http/operation/HttpSerializer { @@ -538,13 +543,8 @@ public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpSer public abstract fun serialize (Laws/smithy/kotlin/runtime/operation/ExecutionContext;Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } -public final class aws/smithy/kotlin/runtime/http/operation/IdentityDeserializer : aws/smithy/kotlin/runtime/http/operation/HttpDeserialize { - public static final field INSTANCE Laws/smithy/kotlin/runtime/http/operation/IdentityDeserializer; - public fun deserialize (Laws/smithy/kotlin/runtime/operation/ExecutionContext;Laws/smithy/kotlin/runtime/http/HttpCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; -} - public abstract interface class aws/smithy/kotlin/runtime/http/operation/InitializeMiddleware : aws/smithy/kotlin/runtime/io/middleware/Middleware { - public abstract fun install (Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation;)V + public fun install (Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation;)V } public final class aws/smithy/kotlin/runtime/http/operation/InitializeMiddleware$DefaultImpls { @@ -556,7 +556,7 @@ public abstract interface class aws/smithy/kotlin/runtime/http/operation/InlineM } public abstract interface class aws/smithy/kotlin/runtime/http/operation/ModifyRequestMiddleware : aws/smithy/kotlin/runtime/io/middleware/ModifyRequest { - public abstract fun install (Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation;)V + public fun install (Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation;)V } public final class aws/smithy/kotlin/runtime/http/operation/ModifyRequestMiddleware$DefaultImpls { @@ -564,7 +564,7 @@ public final class aws/smithy/kotlin/runtime/http/operation/ModifyRequestMiddlew } public abstract interface class aws/smithy/kotlin/runtime/http/operation/MutateMiddleware : aws/smithy/kotlin/runtime/io/middleware/Middleware { - public abstract fun install (Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation;)V + public fun install (Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation;)V } public final class aws/smithy/kotlin/runtime/http/operation/MutateMiddleware$DefaultImpls { @@ -633,7 +633,7 @@ public final class aws/smithy/kotlin/runtime/http/operation/OperationTelemetryKt } public abstract interface class aws/smithy/kotlin/runtime/http/operation/ReceiveMiddleware : aws/smithy/kotlin/runtime/io/middleware/Middleware { - public abstract fun install (Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation;)V + public fun install (Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation;)V } public final class aws/smithy/kotlin/runtime/http/operation/ReceiveMiddleware$DefaultImpls { @@ -675,20 +675,16 @@ public final class aws/smithy/kotlin/runtime/http/operation/SdkHttpOperationBuil public final fun build ()Laws/smithy/kotlin/runtime/http/operation/SdkHttpOperation; public final fun getContext ()Laws/smithy/kotlin/runtime/operation/ExecutionContext; public final fun getDeserializeWith ()Laws/smithy/kotlin/runtime/http/operation/HttpDeserializer; - public final fun getDeserializer ()Laws/smithy/kotlin/runtime/http/operation/HttpDeserialize; public final fun getExecution ()Laws/smithy/kotlin/runtime/http/operation/SdkOperationExecution; public final fun getHostPrefix ()Ljava/lang/String; public final fun getOperationName ()Ljava/lang/String; public final fun getSerializeWith ()Laws/smithy/kotlin/runtime/http/operation/HttpSerializer; - public final fun getSerializer ()Laws/smithy/kotlin/runtime/http/operation/HttpSerialize; public final fun getServiceName ()Ljava/lang/String; public final fun getTelemetry ()Laws/smithy/kotlin/runtime/http/operation/SdkOperationTelemetry; public final fun setDeserializeWith (Laws/smithy/kotlin/runtime/http/operation/HttpDeserializer;)V - public final fun setDeserializer (Laws/smithy/kotlin/runtime/http/operation/HttpDeserialize;)V public final fun setHostPrefix (Ljava/lang/String;)V public final fun setOperationName (Ljava/lang/String;)V public final fun setSerializeWith (Laws/smithy/kotlin/runtime/http/operation/HttpSerializer;)V - public final fun setSerializer (Laws/smithy/kotlin/runtime/http/operation/HttpSerialize;)V public final fun setServiceName (Ljava/lang/String;)V } @@ -731,14 +727,3 @@ public final class aws/smithy/kotlin/runtime/http/operation/SdkOperationTelemetr public final fun setSpanName (Ljava/lang/String;)V } -public final class aws/smithy/kotlin/runtime/http/operation/UnitDeserializer : aws/smithy/kotlin/runtime/http/operation/HttpDeserialize { - public static final field INSTANCE Laws/smithy/kotlin/runtime/http/operation/UnitDeserializer; - public fun deserialize (Laws/smithy/kotlin/runtime/operation/ExecutionContext;Laws/smithy/kotlin/runtime/http/HttpCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; -} - -public final class aws/smithy/kotlin/runtime/http/operation/UnitSerializer : aws/smithy/kotlin/runtime/http/operation/HttpSerialize { - public static final field INSTANCE Laws/smithy/kotlin/runtime/http/operation/UnitSerializer; - public synthetic fun serialize (Laws/smithy/kotlin/runtime/operation/ExecutionContext;Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun serialize (Laws/smithy/kotlin/runtime/operation/ExecutionContext;Lkotlin/Unit;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; -} - diff --git a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpointErrorInterceptor.kt b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpointErrorInterceptor.kt index ce1ef41857..adb4babd48 100644 --- a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpointErrorInterceptor.kt +++ b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpointErrorInterceptor.kt @@ -23,7 +23,7 @@ import kotlin.reflect.KClass @InternalApi public class DiscoveredEndpointErrorInterceptor( private val errorType: KClass, - private val invalidate: (ExecutionContext) -> Unit, + private val invalidate: suspend (ExecutionContext) -> Unit, ) : HttpInterceptor { override suspend fun modifyBeforeAttemptCompletion( context: ResponseInterceptorContext, diff --git a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/HttpSerde.kt b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/HttpSerde.kt index 5e50a3cca9..b50a05d037 100644 --- a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/HttpSerde.kt +++ b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/HttpSerde.kt @@ -15,6 +15,13 @@ import aws.smithy.kotlin.runtime.operation.ExecutionContext */ @InternalApi public sealed interface HttpSerializer { + @InternalApi + public companion object { + public val Unit: HttpSerializer = object : NonStreaming { + override fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder = + HttpRequestBuilder() + } + } /** * Serializer for streaming operations that need full control over serialization of the body @@ -38,6 +45,17 @@ public sealed interface HttpSerializer { */ @InternalApi public sealed interface HttpDeserializer { + @InternalApi + public companion object { + public val Identity: HttpDeserializer = object : NonStreaming { + override fun deserialize(context: ExecutionContext, call: HttpCall, payload: ByteArray?): HttpResponse = + call.response + } + + public val Unit: HttpDeserializer = object : NonStreaming { + override fun deserialize(context: ExecutionContext, call: HttpCall, payload: ByteArray?) { } + } + } /** * Deserializer for streaming operations that need full control over deserialization of the body @@ -56,66 +74,3 @@ public sealed interface HttpDeserializer { public fun deserialize(context: ExecutionContext, call: HttpCall, payload: ByteArray?): T } } - -/** - * Implemented by types that know how to serialize to the HTTP protocol. - */ -@Deprecated("use HttpSerializer.Streaming") -@InternalApi -public fun interface HttpSerialize { - public suspend fun serialize(context: ExecutionContext, input: T): HttpRequestBuilder -} - -@Suppress("DEPRECATION") -private class LegacyHttpSerializeAdapter(val serializer: HttpSerialize) : HttpSerializer.Streaming { - override suspend fun serialize(context: ExecutionContext, input: T): HttpRequestBuilder = - serializer.serialize(context, input) -} - -@Suppress("DEPRECATION") -internal fun HttpSerialize.intoSerializer(): HttpSerializer = LegacyHttpSerializeAdapter(this) - -/** - * Implemented by types that know how to deserialize from the HTTP protocol. - */ -@Deprecated("use HttpDeserializer.Streaming") -@InternalApi -public fun interface HttpDeserialize { - public suspend fun deserialize(context: ExecutionContext, call: HttpCall): T -} - -@Suppress("DEPRECATION") -private class LegacyHttpDeserializeAdapter(val deserializer: HttpDeserialize) : HttpDeserializer.Streaming { - override suspend fun deserialize(context: ExecutionContext, call: HttpCall): T = - deserializer.deserialize(context, call) -} - -@Suppress("DEPRECATION") -internal fun HttpDeserialize.intoDeserializer(): HttpDeserializer = LegacyHttpDeserializeAdapter(this) - -/** - * Convenience deserialize implementation for a type with no output type - */ -@Suppress("DEPRECATION") -@InternalApi -public object UnitDeserializer : HttpDeserialize { - override suspend fun deserialize(context: ExecutionContext, call: HttpCall) {} -} - -/** - * Convenience serialize implementation for a type with no input type - */ -@Suppress("DEPRECATION") -@InternalApi -public object UnitSerializer : HttpSerialize { - override suspend fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder = HttpRequestBuilder() -} - -/** - * Convenience deserialize implementation that returns the response without modification - */ -@Suppress("DEPRECATION") -@InternalApi -public object IdentityDeserializer : HttpDeserialize { - override suspend fun deserialize(context: ExecutionContext, call: HttpCall): HttpResponse = call.response -} diff --git a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/SdkHttpOperation.kt b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/SdkHttpOperation.kt index 775da6c891..abb3b8cc46 100644 --- a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/SdkHttpOperation.kt +++ b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/SdkHttpOperation.kt @@ -36,17 +36,6 @@ public class SdkHttpOperation internal constructor( internal val typeInfo: OperationTypeInfo, internal val telemetry: SdkOperationTelemetry, ) { - - @Suppress("DEPRECATION") - internal constructor( - execution: SdkOperationExecution, - context: ExecutionContext, - serializer: HttpSerialize, - deserializer: HttpDeserialize, - typeInfo: OperationTypeInfo, - telemetry: SdkOperationTelemetry, - ) : this(execution, context, serializer.intoSerializer(), deserializer.intoDeserializer(), typeInfo, telemetry) - init { context[HttpOperationContext.SdkInvocationId] = Uuid.random().toString() } @@ -142,27 +131,8 @@ public class SdkHttpOperationBuilder( private val outputType: KClass<*>, ) { public val telemetry: SdkOperationTelemetry = SdkOperationTelemetry() - - @Suppress("DEPRECATION") - @Deprecated("use serializeWith") - public var serializer: HttpSerialize? = null - set(value) { - field = value - serializeWith = value?.intoSerializer() - } - public var serializeWith: HttpSerializer? = null - - @Suppress("DEPRECATION") - @Deprecated("use deserializeWith") - public var deserializer: HttpDeserialize? = null - set(value) { - field = value - deserializeWith = value?.intoDeserializer() - } - public var deserializeWith: HttpDeserializer? = null - public val execution: SdkOperationExecution = SdkOperationExecution() public val context: ExecutionContext = ExecutionContext() diff --git a/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/operation/SdkHttpOperationTest.kt b/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/operation/SdkHttpOperationTest.kt index 22e36d1cf1..e564bedcc2 100644 --- a/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/operation/SdkHttpOperationTest.kt +++ b/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/operation/SdkHttpOperationTest.kt @@ -36,8 +36,8 @@ class SdkHttpOperationTest { val ex = assertFailsWith { @Suppress("DEPRECATION") SdkHttpOperation.build { - serializer = UnitSerializer - deserializer = UnitDeserializer + serializeWith = HttpSerializer.Unit + deserializeWith = HttpDeserializer.Unit } } diff --git a/runtime/runtime-core/api/runtime-core.api b/runtime/runtime-core/api/runtime-core.api index fa4c1d2d2e..91855f6e4d 100644 --- a/runtime/runtime-core/api/runtime-core.api +++ b/runtime/runtime-core/api/runtime-core.api @@ -177,6 +177,12 @@ public final class aws/smithy/kotlin/runtime/collections/CollectionExtKt { public static final fun createOrAppend (Ljava/util/List;Ljava/lang/Object;)Ljava/util/List; } +public abstract interface class aws/smithy/kotlin/runtime/collections/ExpiringKeyedCache { + public abstract fun get (Ljava/lang/Object;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun getSize ()I + public abstract fun invalidate (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class aws/smithy/kotlin/runtime/collections/LruCache { public fun (I)V public final fun get (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -188,9 +194,9 @@ public final class aws/smithy/kotlin/runtime/collections/LruCache { } public abstract interface class aws/smithy/kotlin/runtime/collections/MultiMap : java/util/Map, kotlin/jvm/internal/markers/KMappedMarker { - public abstract fun contains (Ljava/lang/Object;Ljava/lang/Object;)Z + public fun contains (Ljava/lang/Object;Ljava/lang/Object;)Z public abstract fun getEntryValues ()Lkotlin/sequences/Sequence; - public abstract fun toMutableMultiMap ()Laws/smithy/kotlin/runtime/collections/MutableMultiMap; + public fun toMutableMultiMap ()Laws/smithy/kotlin/runtime/collections/MutableMultiMap; } public final class aws/smithy/kotlin/runtime/collections/MultiMap$DefaultImpls { @@ -213,15 +219,15 @@ public abstract interface class aws/smithy/kotlin/runtime/collections/MutableMul public abstract fun add (Ljava/lang/Object;Ljava/lang/Object;)Z public abstract fun addAll (Ljava/lang/Object;ILjava/util/Collection;)Z public abstract fun addAll (Ljava/lang/Object;Ljava/util/Collection;)Z - public abstract fun addAll (Ljava/util/Map;)V - public abstract fun contains (Ljava/lang/Object;Ljava/lang/Object;)Z + public fun addAll (Ljava/util/Map;)V + public fun contains (Ljava/lang/Object;Ljava/lang/Object;)Z public abstract fun getEntryValues ()Lkotlin/sequences/Sequence; - public abstract fun put (Ljava/lang/Object;Ljava/lang/Object;)Ljava/util/List; + public fun put (Ljava/lang/Object;Ljava/lang/Object;)Ljava/util/List; public abstract fun removeAll (Ljava/lang/Object;Ljava/util/Collection;)Ljava/lang/Boolean; public abstract fun removeAt (Ljava/lang/Object;I)Ljava/lang/Object; public abstract fun removeElement (Ljava/lang/Object;Ljava/lang/Object;)Z public abstract fun retainAll (Ljava/lang/Object;Ljava/util/Collection;)Ljava/lang/Boolean; - public abstract fun toMultiMap ()Laws/smithy/kotlin/runtime/collections/MultiMap; + public fun toMultiMap ()Laws/smithy/kotlin/runtime/collections/MultiMap; } public final class aws/smithy/kotlin/runtime/collections/MutableMultiMap$DefaultImpls { @@ -235,12 +241,12 @@ public final class aws/smithy/kotlin/runtime/collections/MutableMultiMapKt { public static final fun mutableMultiMapOf ([Lkotlin/Pair;)Laws/smithy/kotlin/runtime/collections/MutableMultiMap; } -public final class aws/smithy/kotlin/runtime/collections/ReadThroughCache { +public final class aws/smithy/kotlin/runtime/collections/PeriodicSweepCache : aws/smithy/kotlin/runtime/collections/ExpiringKeyedCache { public synthetic fun (JLaws/smithy/kotlin/runtime/time/Clock;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (JLaws/smithy/kotlin/runtime/time/Clock;Lkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun get (Ljava/lang/Object;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun getSize ()I - public final fun invalidate (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun get (Ljava/lang/Object;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun getSize ()I + public fun invalidate (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class aws/smithy/kotlin/runtime/collections/StackKt { @@ -254,10 +260,10 @@ public final class aws/smithy/kotlin/runtime/collections/StackKt { public abstract interface class aws/smithy/kotlin/runtime/collections/ValuesMap { public abstract fun contains (Ljava/lang/String;)Z - public abstract fun contains (Ljava/lang/String;Ljava/lang/Object;)Z + public fun contains (Ljava/lang/String;Ljava/lang/Object;)Z public abstract fun entries ()Ljava/util/Set; - public abstract fun forEach (Lkotlin/jvm/functions/Function2;)V - public abstract fun get (Ljava/lang/String;)Ljava/lang/Object; + public fun forEach (Lkotlin/jvm/functions/Function2;)V + public fun get (Ljava/lang/String;)Ljava/lang/Object; public abstract fun getAll (Ljava/lang/String;)Ljava/util/List; public abstract fun getCaseInsensitiveName ()Z public abstract fun isEmpty ()Z @@ -707,6 +713,7 @@ public abstract interface class aws/smithy/kotlin/runtime/hashing/HashFunction { public abstract fun getDigestSizeBytes ()I public abstract fun reset ()V public abstract fun update ([BII)V + public static synthetic fun update$default (Laws/smithy/kotlin/runtime/hashing/HashFunction;[BIIILjava/lang/Object;)V } public final class aws/smithy/kotlin/runtime/hashing/HashFunction$DefaultImpls { @@ -903,6 +910,7 @@ public abstract interface class aws/smithy/kotlin/runtime/io/SdkBufferedSink : a public abstract fun outputStream ()Ljava/io/OutputStream; public abstract fun write (Laws/smithy/kotlin/runtime/io/SdkSource;J)V public abstract fun write ([BII)V + public static synthetic fun write$default (Laws/smithy/kotlin/runtime/io/SdkBufferedSink;[BIIILjava/lang/Object;)V public abstract fun writeAll (Laws/smithy/kotlin/runtime/io/SdkSource;)J public abstract fun writeByte (B)V public abstract fun writeInt (I)V @@ -912,6 +920,7 @@ public abstract interface class aws/smithy/kotlin/runtime/io/SdkBufferedSink : a public abstract fun writeShort (S)V public abstract fun writeShortLe (S)V public abstract fun writeUtf8 (Ljava/lang/String;II)V + public static synthetic fun writeUtf8$default (Laws/smithy/kotlin/runtime/io/SdkBufferedSink;Ljava/lang/String;IIILjava/lang/Object;)V } public final class aws/smithy/kotlin/runtime/io/SdkBufferedSink$DefaultImpls { @@ -925,6 +934,7 @@ public abstract interface class aws/smithy/kotlin/runtime/io/SdkBufferedSource : public abstract fun inputStream ()Ljava/io/InputStream; public abstract fun peek ()Laws/smithy/kotlin/runtime/io/SdkBufferedSource; public abstract fun read ([BII)I + public static synthetic fun read$default (Laws/smithy/kotlin/runtime/io/SdkBufferedSource;[BIIILjava/lang/Object;)I public abstract fun readAll (Laws/smithy/kotlin/runtime/io/SdkSink;)J public abstract fun readByte ()B public abstract fun readByteArray ()[B @@ -947,7 +957,7 @@ public final class aws/smithy/kotlin/runtime/io/SdkBufferedSource$DefaultImpls { } public abstract interface class aws/smithy/kotlin/runtime/io/SdkByteChannel : aws/smithy/kotlin/runtime/io/SdkByteReadChannel, aws/smithy/kotlin/runtime/io/SdkByteWriteChannel { - public abstract fun close ()V + public fun close ()V } public final class aws/smithy/kotlin/runtime/io/SdkByteChannel$DefaultImpls { @@ -994,6 +1004,7 @@ public abstract interface class aws/smithy/kotlin/runtime/io/SdkByteWriteChannel public abstract fun getTotalBytesWritten ()J public abstract fun isClosedForWrite ()Z public abstract fun write (Laws/smithy/kotlin/runtime/io/SdkBuffer;JLkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun write$default (Laws/smithy/kotlin/runtime/io/SdkByteWriteChannel;Laws/smithy/kotlin/runtime/io/SdkBuffer;JLkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } public final class aws/smithy/kotlin/runtime/io/SdkByteWriteChannel$DefaultImpls { @@ -1217,6 +1228,7 @@ public final class aws/smithy/kotlin/runtime/net/HostKt { public abstract interface class aws/smithy/kotlin/runtime/net/HostResolver { public static final field Companion Laws/smithy/kotlin/runtime/net/HostResolver$Companion; public abstract fun purgeCache (Laws/smithy/kotlin/runtime/net/HostAddress;)V + public static synthetic fun purgeCache$default (Laws/smithy/kotlin/runtime/net/HostResolver;Laws/smithy/kotlin/runtime/net/HostAddress;ILjava/lang/Object;)V public abstract fun reportFailure (Laws/smithy/kotlin/runtime/net/HostAddress;)V public abstract fun resolve (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } @@ -2094,7 +2106,6 @@ public final class aws/smithy/kotlin/runtime/smoketests/SmokeTestsFunctionsJVMKt public final class aws/smithy/kotlin/runtime/smoketests/SmokeTestsFunctionsKt { public static final fun getDefaultPrinter ()Ljava/lang/Appendable; - public static final fun printExceptionStackTrace (Ljava/lang/Exception;)V } public final class aws/smithy/kotlin/runtime/text/Scanner { @@ -2150,8 +2161,8 @@ public final class aws/smithy/kotlin/runtime/text/encoding/Encodable$Companion { public abstract interface class aws/smithy/kotlin/runtime/text/encoding/Encoding { public static final field Companion Laws/smithy/kotlin/runtime/text/encoding/Encoding$Companion; public abstract fun decode (Ljava/lang/String;)Ljava/lang/String; - public abstract fun encodableFromDecoded (Ljava/lang/String;)Laws/smithy/kotlin/runtime/text/encoding/Encodable; - public abstract fun encodableFromEncoded (Ljava/lang/String;)Laws/smithy/kotlin/runtime/text/encoding/Encodable; + public fun encodableFromDecoded (Ljava/lang/String;)Laws/smithy/kotlin/runtime/text/encoding/Encodable; + public fun encodableFromEncoded (Ljava/lang/String;)Laws/smithy/kotlin/runtime/text/encoding/Encodable; public abstract fun encode (Ljava/lang/String;)Ljava/lang/String; public abstract fun getName ()Ljava/lang/String; } diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/ExpiringKeyedCache.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/ExpiringKeyedCache.kt new file mode 100644 index 0000000000..b4588d9280 --- /dev/null +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/ExpiringKeyedCache.kt @@ -0,0 +1,42 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.collections + +import aws.smithy.kotlin.runtime.util.ExpiringValue + +/** + * A multi-value cache which supports retrieval and invalidation via a key paired with each value. The [get] and + * [invalidate] methods are `suspend` functions to allow for cross-context synchronization and potentially-expensive + * value lookup. + * + * Values in the cache _may_ expire and are retrieved as [ExpiringValue]. When a value is absent/expired in the cache, + * invoking [get] will cause a lookup to occur via the function's `valueLookup` parameter. + * + * @param K The type of the keys of this cache + * @param V The type of the values of this cache + */ +public interface ExpiringKeyedCache { + /** + * The number of values currently stored in the cache + */ + public val size: Int + + /** + * Gets the value associated with this key from the cache. If the cache does not contain the given key, + * implementations are expected to invoke [valueLookup], although they _may_ perform other actions such as throw + * exceptions, fall back to other caches, etc. + * @param key The key for which to look up a value + * @param valueLookup A possibly-suspending function which returns the read-through value associated with a given + * key. This function is invoked when the cache does not contain the given [key] or when the value is expired. + */ + public suspend fun get(key: K, valueLookup: suspend (K) -> ExpiringValue): V + + /** + * Invalidates the value (if any) for the given key, removing it from the cache regardless. This method has no + * effect if the given key is not present in the cache. + * @param key The key for which to invalidate a value + */ + public suspend fun invalidate(key: K) +} diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/ReadThroughCache.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/PeriodicSweepCache.kt similarity index 74% rename from runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/ReadThroughCache.kt rename to runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/PeriodicSweepCache.kt index 529a57d450..0940aaec2d 100644 --- a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/ReadThroughCache.kt +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/PeriodicSweepCache.kt @@ -12,8 +12,9 @@ import kotlinx.coroutines.sync.withLock import kotlin.time.Duration /** - * An object which caches values and allows retrieving them by key. The values expire after a time. If a value is - * expired or absent from the cache, it will be read from a `valueLookup` parameter passed to [get] and then cached. + * A cache which allows retrieving values by a key. Looking up a value for a key which does not exist in the cache (or + * where the value has expired) are resolved by calling [valueLookup]. The expiry for a value is included in the result + * returned from [valueLookup]. * * A sweep operation will run prior to a [get] or [invalidate] that happens after [minimumSweepPeriod] has elapsed from * the last sweep (or from the initialization of the cache). This sweep will search for and remove expired entries from @@ -29,10 +30,10 @@ import kotlin.time.Duration * @param clock The [Clock] to use for measuring time. Defaults to [Clock.System]. */ @InternalApi -public class ReadThroughCache( +public class PeriodicSweepCache( private val minimumSweepPeriod: Duration, private val clock: Clock = Clock.System, -) { +) : ExpiringKeyedCache { private val map = mutableMapOf>() private val mutex = Mutex() private var nextSweep = clock.now() + minimumSweepPeriod @@ -44,7 +45,7 @@ public class ReadThroughCache( * @param valueLookup A possibly-suspending function which returns the read-through value associated with a given * key. This function is invoked when the cache, for a given key, does not contain a value or the value is expired. */ - public suspend fun get(key: K, valueLookup: suspend (K) -> ExpiringValue): V = mutex.withLock { + override suspend fun get(key: K, valueLookup: suspend (K) -> ExpiringValue): V = mutex.withLock { if (clock.now() > nextSweep) sweep() val current = map[key] @@ -59,20 +60,29 @@ public class ReadThroughCache( * Invalidates the value (if any) for the given key, removing it from the cache regardless of its expiry. * @param key The key for which to invalidate a value. */ - public suspend fun invalidate(key: K): Unit = mutex.withLock { + override suspend fun invalidate(key: K): Unit = mutex.withLock { map.remove(key) if (clock.now() > nextSweep) sweep() } + /** + * Indicates whether this value is expired according to its [ExpiringValue.expiresAt] property and the cache's + * [clock] + */ private val ExpiringValue<*>.isExpired: Boolean get() = clock.now() >= expiresAt /** - * Gets the number of values currently stored in the cache. + * Gets the number of values currently stored in the cache. Note that this property is non-volatile and may reflect + * stale information in highly-concurrent scenarios. */ - public val size: Int + override val size: Int get() = map.size + /** + * Sweeps the cache to remove expired entries and schedule the next sweep. This method _must_ be invoked under mutex + * lock. + */ private fun sweep() { val iterator = map.iterator() while (iterator.hasNext()) { diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/smoketests/SmokeTestsFunctions.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/smoketests/SmokeTestsFunctions.kt index 3da3bb8c23..9cab81bdbc 100644 --- a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/smoketests/SmokeTestsFunctions.kt +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/smoketests/SmokeTestsFunctions.kt @@ -2,27 +2,6 @@ package aws.smithy.kotlin.runtime.smoketests public expect fun exitProcess(status: Int): Nothing -/** - * Prints an exceptions stack trace using test anything protocol (TAP) format e.g. - * - * #java.lang.ArithmeticException: / by zero - * # at FileKt.main(File.kt:3) - * # at FileKt.main(File.kt) - * # at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method) - * # at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(Unknown Source) - * # at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source) - * # at java.base/java.lang.reflect.Method.invoke(Unknown Source) - * # at executors.JavaRunnerExecutor$Companion.main(JavaRunnerExecutor.kt:27) - * # at executors.JavaRunnerExecutor.main(JavaRunnerExecutor.kt) - */ -@Deprecated( - message = "No longer used, target for removal in 1.5", - replaceWith = ReplaceWith("println(exception.stackTraceToString().prependIndent(\"#\"))"), - level = DeprecationLevel.WARNING, -) -public fun printExceptionStackTrace(exception: Exception): Unit = - println(exception.stackTraceToString().split("\n").joinToString("\n") { "#$it" }) - public class SmokeTestsException(message: String) : Exception(message) /** diff --git a/runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/ReadThroughCacheTest.kt b/runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/PeriodicSweepCacheTest.kt similarity index 91% rename from runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/ReadThroughCacheTest.kt rename to runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/PeriodicSweepCacheTest.kt index 4d85ece15f..f070669c1f 100644 --- a/runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/ReadThroughCacheTest.kt +++ b/runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/PeriodicSweepCacheTest.kt @@ -12,13 +12,13 @@ import kotlin.test.assertEquals import kotlin.time.Duration.Companion.minutes import kotlin.time.Duration.Companion.seconds -class ReadThroughCacheTest { +class PeriodicSweepCacheTest { @Test - fun testReadThrough() = runTest { + fun testGet() = runTest { val clock = ManualClock() var counter = 0 fun uncachedValue() = ExpiringValue(counter++, clock.now() + 2.seconds) - val cache = ReadThroughCache(1.minutes, clock) + val cache = PeriodicSweepCache(1.minutes, clock) // Basic read through assertEquals(0, cache.get("a") { uncachedValue() }) @@ -41,7 +41,7 @@ class ReadThroughCacheTest { val clock = ManualClock() var counter = 0 fun uncachedValue() = ExpiringValue(counter++, clock.now() + 2.seconds) - val cache = ReadThroughCache(4.seconds, clock) + val cache = PeriodicSweepCache(4.seconds, clock) // Pre-populate values assertEquals(0, cache.get("a") { uncachedValue() }) diff --git a/runtime/serde/serde-xml/api/serde-xml.api b/runtime/serde/serde-xml/api/serde-xml.api index 736408bd97..eabc8febb8 100644 --- a/runtime/serde/serde-xml/api/serde-xml.api +++ b/runtime/serde/serde-xml/api/serde-xml.api @@ -47,16 +47,6 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlCollectionValueNamespa public synthetic fun (Ljava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V } -public final class aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer : aws/smithy/kotlin/runtime/serde/Deserializer { - public fun (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;Z)V - public synthetic fun (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;ZILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun ([BZ)V - public synthetic fun ([BZILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun deserializeList (Laws/smithy/kotlin/runtime/serde/SdkFieldDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$ElementIterator; - public fun deserializeMap (Laws/smithy/kotlin/runtime/serde/SdkFieldDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$EntryIterator; - public fun deserializeStruct (Laws/smithy/kotlin/runtime/serde/SdkObjectDescriptor;)Laws/smithy/kotlin/runtime/serde/Deserializer$FieldIterator; -} - public final class aws/smithy/kotlin/runtime/serde/xml/XmlError : aws/smithy/kotlin/runtime/serde/FieldTrait { public static final field INSTANCE Laws/smithy/kotlin/runtime/serde/xml/XmlError; public final fun getErrorTag ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken$QualifiedName; @@ -155,8 +145,10 @@ public abstract interface class aws/smithy/kotlin/runtime/serde/xml/XmlStreamRea public abstract fun getLastToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public abstract fun nextToken ()Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public abstract fun peek (I)Laws/smithy/kotlin/runtime/serde/xml/XmlToken; + public static synthetic fun peek$default (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;IILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlToken; public abstract fun skipNext ()V public abstract fun subTreeReader (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader$SubtreeStartDepth;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader; + public static synthetic fun subTreeReader$default (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader;Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader$SubtreeStartDepth;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamReader; } public final class aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader$DefaultImpls { @@ -178,13 +170,17 @@ public final class aws/smithy/kotlin/runtime/serde/xml/XmlStreamReaderKt { public abstract interface class aws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter { public abstract fun attribute (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; + public static synthetic fun attribute$default (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; public abstract fun endDocument ()V public abstract fun endTag (Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; + public static synthetic fun endTag$default (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; public abstract fun getBytes ()[B public abstract fun getText ()Ljava/lang/String; public abstract fun namespacePrefix (Ljava/lang/String;Ljava/lang/String;)V + public static synthetic fun namespacePrefix$default (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)V public abstract fun startDocument ()V public abstract fun startTag (Ljava/lang/String;Ljava/lang/String;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; + public static synthetic fun startTag$default (Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; public abstract fun text (Ljava/lang/String;)Laws/smithy/kotlin/runtime/serde/xml/XmlStreamWriter; } diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt deleted file mode 100644 index b950e888b6..0000000000 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlDeserializer.kt +++ /dev/null @@ -1,431 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package aws.smithy.kotlin.runtime.serde.xml - -import aws.smithy.kotlin.runtime.InternalApi -import aws.smithy.kotlin.runtime.content.BigDecimal -import aws.smithy.kotlin.runtime.content.BigInteger -import aws.smithy.kotlin.runtime.content.Document -import aws.smithy.kotlin.runtime.serde.* -import aws.smithy.kotlin.runtime.text.encoding.decodeBase64Bytes -import aws.smithy.kotlin.runtime.time.Instant -import aws.smithy.kotlin.runtime.time.TimestampFormat - -private const val FIRST_FIELD_INDEX: Int = 0 - -// Represents aspects of SdkFieldDescriptor that are particular to the Xml format -internal sealed class FieldLocation { - // specifies the mapping to a sdk field index - abstract val fieldIndex: Int - - data class Text(override val fieldIndex: Int) : FieldLocation() // Xml nodes have only one associated Text element - data class Attribute(override val fieldIndex: Int, val names: Set) : FieldLocation() -} - -/** - * Provides a deserializer for XML documents - * - * @param reader underlying [XmlStreamReader] from which tokens are read - * @param validateRootElement Flag indicating if the root XML document [XmlToken.BeginElement] should be validated against - * the descriptor passed to [deserializeStruct]. This only affects the root element, not nested struct elements. Some - * restXml based services DO NOT always send documents with a root element name that matches the shape ID name - * (S3 in particular). This means there is nothing in the model that gives you enough information to validate the tag. - */ -@Deprecated("XmlDeserializer is deprecated and will be removed in a future release") -@InternalApi -public class XmlDeserializer( - private val reader: XmlStreamReader, - private val validateRootElement: Boolean = false, -) : Deserializer { - - public constructor(input: ByteArray, validateRootElement: Boolean = false) : this(xmlStreamReader(input), validateRootElement) - - private var firstStructCall = true - - override fun deserializeStruct(descriptor: SdkObjectDescriptor): Deserializer.FieldIterator { - if (firstStructCall) { - if (!descriptor.hasTrait()) throw DeserializationException("Top-level struct $descriptor requires a XmlSerialName trait but has none.") - - firstStructCall = false - - reader.nextToken() // Matching field descriptors to children tags so consume the start element of top-level struct - - val structToken = if (descriptor.hasTrait()) { - reader.seek { it.name == descriptor.expectTrait().errorTag } - } else { - reader.seek() - } ?: throw DeserializationException("Could not find a begin element for new struct") - - if (validateRootElement) { - descriptor.requireNameMatch(structToken.name.tag) - } - } - - // Consume any remaining terminating tokens from previous deserialization - reader.seek() - - // Because attributes set on the root node of the struct, we must read the values before creating the subtree - val attribFields = reader.tokenAttributesToFieldLocations(descriptor) - val parentToken = if (reader.lastToken is XmlToken.BeginElement) { - reader.lastToken as XmlToken.BeginElement - } else { - throw DeserializationException("Expected last parsed token to be ${XmlToken.BeginElement::class} but was ${reader.lastToken}") - } - - val unwrapped = descriptor.hasTrait() - return XmlStructDeserializer(descriptor, reader.subTreeReader(XmlStreamReader.SubtreeStartDepth.CURRENT), parentToken, attribFields, unwrapped) - } - - override fun deserializeList(descriptor: SdkFieldDescriptor): Deserializer.ElementIterator { - val depth = when (descriptor.hasTrait()) { - true -> XmlStreamReader.SubtreeStartDepth.CURRENT - else -> XmlStreamReader.SubtreeStartDepth.CHILD - } - - return XmlListDeserializer(reader.subTreeReader(depth), descriptor) - } - - override fun deserializeMap(descriptor: SdkFieldDescriptor): Deserializer.EntryIterator { - val depth = when (descriptor.hasTrait()) { - true -> XmlStreamReader.SubtreeStartDepth.CURRENT - else -> XmlStreamReader.SubtreeStartDepth.CHILD - } - - return XmlMapDeserializer(reader.subTreeReader(depth), descriptor) - } -} - -/** - * Deserializes specific XML structures into forms that can produce Maps - * - * @param reader underlying [XmlStreamReader] from which tokens are read - * @param descriptor associated [SdkFieldDescriptor] which represents the expected Map - * @param primitiveDeserializer used to deserialize primitive values - */ -internal class XmlMapDeserializer( - private val reader: XmlStreamReader, - private val descriptor: SdkFieldDescriptor, - private val primitiveDeserializer: PrimitiveDeserializer = XmlPrimitiveDeserializer(reader, descriptor), -) : PrimitiveDeserializer by primitiveDeserializer, - Deserializer.EntryIterator { - private val mapTrait = descriptor.findTrait() ?: XmlMapName.Default - - override fun hasNextEntry(): Boolean { - val compareTo = when (descriptor.hasTrait()) { - true -> descriptor.findTrait()?.name ?: mapTrait.key // Prefer seeking to XmlSerialName if the trait exists - false -> mapTrait.entry - } - - // Seek to either the XML serial name, entry, or key token depending on the flatness of the map and if the name trait is present - val nextEntryToken = when (descriptor.hasTrait()) { - true -> reader.peekSeek { it.name.local == compareTo } - false -> reader.seek { it.name.local == compareTo } - } - - return nextEntryToken != null - } - - override fun key(): String { - // Seek to the key begin token - reader.seek { it.name.local == mapTrait.key } - ?: error("Unable to find key $mapTrait.key in $descriptor") - - val keyValueToken = reader.takeNextAs() - reader.nextToken() // Consume the end wrapper - - return keyValueToken.value ?: throw DeserializationException("Key unspecified in $descriptor") - } - - override fun nextHasValue(): Boolean { - // Expect a begin and value (or another begin) token if Map entry has a value - val peekBeginToken = reader.peek(1) ?: throw DeserializationException("Unexpected termination of token stream in $descriptor") - val peekValueToken = reader.peek(2) ?: throw DeserializationException("Unexpected termination of token stream in $descriptor") - - return peekBeginToken !is XmlToken.EndElement && peekValueToken !is XmlToken.EndElement - } -} - -/** - * Deserializes specific XML structures into forms that can produce Lists - * - * @param reader underlying [XmlStreamReader] from which tokens are read - * @param descriptor associated [SdkFieldDescriptor] which represents the expected Map - * @param primitiveDeserializer used to deserialize primitive values - */ -internal class XmlListDeserializer( - private val reader: XmlStreamReader, - private val descriptor: SdkFieldDescriptor, - private val primitiveDeserializer: PrimitiveDeserializer = XmlPrimitiveDeserializer(reader, descriptor), -) : PrimitiveDeserializer by primitiveDeserializer, - Deserializer.ElementIterator { - private var firstCall = true - private val flattened = descriptor.hasTrait() - private val elementName = (descriptor.findTrait() ?: XmlCollectionName.Default).element - - override fun hasNextElement(): Boolean { - if (!flattened && firstCall) { - val nextToken = reader.peek() - val matchedListDescriptor = nextToken is XmlToken.BeginElement && descriptor.nameMatches(nextToken.name.tag) - val hasChildren = if (nextToken == null) false else nextToken.depth >= reader.lastToken!!.depth - - if (!matchedListDescriptor && !hasChildren) return false - - // Discard the wrapper and move to the first element in the list - if (matchedListDescriptor) reader.nextToken() - - firstCall = false - } - - if (flattened) { - // Because our subtree is not CHILD, we cannot rely on the subtree boundary to determine end of collection. - // Rather, we search for either the next begin token matching the (flat) list member name which should - // be immediately after the current token - - // peek at the next token if there is one, in the case of a list of structs, the next token is actually - // the end of the current flat list element in which case we need to peek twice - val next = when (val peeked = reader.peek()) { - is XmlToken.EndElement -> { - if (peeked.name.local == descriptor.serialName.name) { - // consume the end token - reader.nextToken() - reader.peek() - } else { - peeked - } - } - else -> peeked - } - - val tokens = listOfNotNull(reader.lastToken, next) - - // Iterate over the token stream until begin token matching name is found or end element matching list is found. - return tokens - .filterIsInstance() - .any { it.name.local == descriptor.serialName.name } - } else { - // If we can find another begin token w/ the element name, we have more elements to process - return reader.seek { it.name.local == elementName }.isNotTerminal() - } - } - - override fun nextHasValue(): Boolean = reader.peek() !is XmlToken.EndElement -} - -/** - * Deserializes specific XML structures into forms that can produce structures - * - * @param objDescriptor associated [SdkObjectDescriptor] which represents the expected structure - * @param reader underlying [XmlStreamReader] from which tokens are read - * @param parentToken initial token of associated structure - * @param parsedFieldLocations list of [FieldLocation] representing values able to be loaded into deserialized instances - */ -private class XmlStructDeserializer( - private val objDescriptor: SdkObjectDescriptor, - reader: XmlStreamReader, - private val parentToken: XmlToken.BeginElement, - private val parsedFieldLocations: MutableList = mutableListOf(), - private val unwrapped: Boolean, -) : Deserializer.FieldIterator { - // Used to track direct deserialization or further nesting between calls to findNextFieldIndex() and deserialize() - private var reentryFlag: Boolean = false - - private val reader: XmlStreamReader = if (unwrapped) reader else reader.subTreeReader(XmlStreamReader.SubtreeStartDepth.CHILD) - - override fun findNextFieldIndex(): Int? { - if (unwrapped) { - return if (reader.peek() is XmlToken.Text) FIRST_FIELD_INDEX else null - } - if (inNestedMode()) { - // Returning from a nested struct call. Nested deserializer consumed - // tokens so clear them here to avoid processing stale state - parsedFieldLocations.clear() - } - - if (parsedFieldLocations.isEmpty()) { - val matchedFieldLocations = when (val token = reader.nextToken()) { - null, is XmlToken.EndDocument -> return null - is XmlToken.EndElement -> return findNextFieldIndex() - is XmlToken.BeginElement -> { - val nextToken = reader.peek() ?: return null - val objectFields = objDescriptor.fields - val memberFields = objectFields.filter { field -> objDescriptor.fieldTokenMatcher(field, token) } - val matchingFields = memberFields.mapNotNull { it.findFieldLocation(token, nextToken) } - matchingFields - } - else -> return findNextFieldIndex() - } - - // Sorting ensures attribs are processed before text, as processing the Text token pushes the parser on to the next token. - parsedFieldLocations.addAll(matchedFieldLocations.sortedBy { it is FieldLocation.Text }) - } - - return parsedFieldLocations.firstOrNull()?.fieldIndex ?: Deserializer.FieldIterator.UNKNOWN_FIELD - } - - private fun deserializeValue(transform: ((String) -> T)): T { - if (unwrapped) { - val value = reader.takeNextAs().value ?: "" - return transform(value) - } - // Set and validate mode - reentryFlag = false - if (parsedFieldLocations.isEmpty()) throw DeserializationException("matchedFields is empty, was findNextFieldIndex() called?") - - // Take the first FieldLocation and attempt to parse it into the value specified by the descriptor. - return when (val nextField = parsedFieldLocations.removeFirst()) { - is FieldLocation.Text -> { - val value = when (val peekToken = reader.peek()) { - is XmlToken.Text -> reader.takeNextAs().value ?: "" - is XmlToken.EndElement -> "" - else -> throw DeserializationException("Unexpected token $peekToken") - } - transform(value) - } - is FieldLocation.Attribute -> { - transform( - nextField - .names - .mapNotNull { parentToken.attributes[it] } - .firstOrNull() ?: throw DeserializationException("Expected attrib value ${nextField.names.first()} not found in ${parentToken.name}"), - ) - } - } - } - - override fun skipValue() = reader.skipNext() - - override fun deserializeByte(): Byte = deserializeValue { it.toIntOrNull()?.toByte() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeInt(): Int = deserializeValue { it.toIntOrNull() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeShort(): Short = deserializeValue { it.toIntOrNull()?.toShort() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeLong(): Long = deserializeValue { it.toLongOrNull() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeFloat(): Float = deserializeValue { it.toFloatOrNull() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeDouble(): Double = deserializeValue { it.toDoubleOrNull() ?: throw DeserializationException("Unable to deserialize $it") } - - override fun deserializeBigInteger(): BigInteger = deserializeValue { - runCatching { BigInteger(it) } - .getOrElse { throw DeserializationException("Unable to deserialize $it as BigInteger") } - } - - override fun deserializeBigDecimal(): BigDecimal = deserializeValue { - runCatching { BigDecimal(it) } - .getOrElse { throw DeserializationException("Unable to deserialize $it as BigDecimal") } - } - - override fun deserializeString(): String = deserializeValue { it } - - override fun deserializeBoolean(): Boolean = deserializeValue { it.toBoolean() } - - override fun deserializeDocument(): Document = throw DeserializationException("cannot deserialize unsupported Document type in xml") - - override fun deserializeByteArray(): ByteArray = deserializeString().decodeBase64Bytes() - - override fun deserializeInstant(format: TimestampFormat): Instant = when (format) { - TimestampFormat.EPOCH_SECONDS -> deserializeString().let { Instant.fromEpochSeconds(it) } - TimestampFormat.ISO_8601 -> deserializeString().let { Instant.fromIso8601(it) } - TimestampFormat.RFC_5322 -> deserializeString().let { Instant.fromRfc5322(it) } - else -> throw DeserializationException("unknown timestamp format: $format") - } - - override fun deserializeNull(): Nothing? { - reader.takeNextAs() - return null - } - - // A struct deserializer can be called in two "modes": - // 1. to deserialize a value. This calls findNextFieldIndex() followed by deserialize() - // 2. to deserialize a nested container. This calls findNextFieldIndex() followed by a call to another deserialize() - // Because state is built in findNextFieldIndex() that is intended to be used directly in deserialize() (mode 1) - // and there is no explicit way that this type knows which mode is in use, the state built must be cleared. - // this is done by flipping a bit between the two calls. If the bit has not been flipped on any call to findNextFieldIndex() - // it is determined that the nested mode was used and any existing state should be cleared. - // if the state is not cleared, deserialization goes into an infinite loop because the deserializer sees pending fields to pull from the stream - // which are never consumed by the (missing) call to deserialize() - private fun inNestedMode(): Boolean = when (reentryFlag) { - true -> true - false -> { - reentryFlag = true - false - } - } -} - -// Extract the attributes from the last-read token and match them to [FieldLocation] on the [SdkObjectDescriptor]. -private fun XmlStreamReader.tokenAttributesToFieldLocations(descriptor: SdkObjectDescriptor): MutableList = - if (descriptor.hasXmlAttributes && lastToken is XmlToken.BeginElement) { - val attribFields = descriptor.fields.filter { it.hasTrait() } - val matchedAttribFields = attribFields.filter { it.findFieldLocation(lastToken as XmlToken.BeginElement, peek() ?: throw DeserializationException("Unexpected end of tokens")) != null } - matchedAttribFields.map { FieldLocation.Attribute(it.index, it.toQualifiedNames()) } - .toMutableList() - } else { - mutableListOf() - } - -// Returns a [FieldLocation] if the field maps to the current token -private fun SdkFieldDescriptor.findFieldLocation( - currentToken: XmlToken.BeginElement, - nextToken: XmlToken, -): FieldLocation? = when (val property = toFieldLocation()) { - is FieldLocation.Text -> { - when { - nextToken is XmlToken.Text -> property - nextToken is XmlToken.BeginElement -> property - // The following allows for struct primitives to remain unvisited if no value - // but causes nested deserializers to be called even if they contain no value - nextToken is XmlToken.EndElement && currentToken.name == nextToken.name -> property - else -> null - } - } - is FieldLocation.Attribute -> { - val foundMatch = property.names.any { currentToken.attributes[it]?.isNotBlank() == true } - if (foundMatch) property else null - } -} - -// Produce a [FieldLocation] type based on presence of traits of field -// A field without an attribute trait is assumed to be a text token -private fun SdkFieldDescriptor.toFieldLocation(): FieldLocation = - when (findTrait()) { - null -> FieldLocation.Text(index) // Assume a text value if no attributes defined. - else -> FieldLocation.Attribute(index, toQualifiedNames()) - } - -// Matches fields and tokens with matching qualified name -private fun SdkObjectDescriptor.fieldTokenMatcher(fieldDescriptor: SdkFieldDescriptor, beginElement: XmlToken.BeginElement): Boolean { - if (fieldDescriptor.kind == SerialKind.List && fieldDescriptor.hasTrait()) { - val fieldName = fieldDescriptor.findTrait() ?: XmlCollectionName.Default - val tokenQname = beginElement.name - - // It may be that we are matching a flattened list element or matching a list itself. In the latter - // case the following predicate will not work, so if we fail to match the member - // try again (below) to match against the container. - if (fieldName.element == tokenQname.local) return true - } - - return fieldDescriptor.nameMatches(beginElement.name.tag) -} - -/** - * Return the next token of the specified type or throw [DeserializationException] if incorrect type. - */ -internal inline fun XmlStreamReader.takeNextAs(): TExpected { - val token = this.nextToken() ?: throw DeserializationException("Expected ${TExpected::class} but instead found null") - requireToken(token) - return token as TExpected -} - -/** - * Require that the given token be of type [TExpected] or else throw an exception - */ -internal inline fun requireToken(token: XmlToken) { - if (token::class != TExpected::class) { - throw DeserializationException("Expected ${TExpected::class}; found ${token::class} ($token)") - } -} diff --git a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt index a43250b267..760f52a833 100644 --- a/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt +++ b/runtime/serde/serde-xml/common/src/aws/smithy/kotlin/runtime/serde/xml/XmlStreamReader.kt @@ -6,6 +6,7 @@ package aws.smithy.kotlin.runtime.serde.xml import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.serde.DeserializationException import aws.smithy.kotlin.runtime.serde.xml.deserialization.LexingXmlStreamReader import aws.smithy.kotlin.runtime.serde.xml.deserialization.StringTextStream import aws.smithy.kotlin.runtime.serde.xml.deserialization.XmlLexer @@ -114,6 +115,24 @@ public inline fun XmlStreamReader.peekSeek(selectionPredi return null } +/** + * Return the next token of the specified type or throw [DeserializationException] if incorrect type. + */ +internal inline fun XmlStreamReader.takeNextAs(): TExpected { + val token = this.nextToken() ?: throw DeserializationException("Expected ${TExpected::class} but instead found null") + requireToken(token) + return token as TExpected +} + +/** + * Require that the given token be of type [TExpected] or else throw an exception + */ +private inline fun requireToken(token: XmlToken) { + if (token::class != TExpected::class) { + throw DeserializationException("Expected ${TExpected::class}; found ${token::class} ($token)") + } +} + /** * Creates an [XmlStreamReader] instance */ diff --git a/runtime/smithy-client/api/smithy-client.api b/runtime/smithy-client/api/smithy-client.api index b6132b26b8..823ff0467d 100644 --- a/runtime/smithy-client/api/smithy-client.api +++ b/runtime/smithy-client/api/smithy-client.api @@ -30,47 +30,25 @@ public final class aws/smithy/kotlin/runtime/client/IdempotencyTokenProvider$Com } public abstract interface class aws/smithy/kotlin/runtime/client/Interceptor { - public abstract fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public abstract fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public abstract fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public abstract fun modifyBeforeRetryLoop (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public abstract fun modifyBeforeSerialization (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public abstract fun modifyBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public abstract fun modifyBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public abstract fun readAfterAttempt (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public abstract fun readAfterDeserialization (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public abstract fun readAfterExecution (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public abstract fun readAfterSerialization (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public abstract fun readAfterSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public abstract fun readAfterTransmit (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;)V - public abstract fun readBeforeAttempt (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public abstract fun readBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;)V - public abstract fun readBeforeExecution (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;)V - public abstract fun readBeforeSerialization (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;)V - public abstract fun readBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public abstract fun readBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V -} - -public final class aws/smithy/kotlin/runtime/client/Interceptor$DefaultImpls { - public static fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static fun modifyBeforeRetryLoop (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static fun modifyBeforeSerialization (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static fun modifyBeforeSigning (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static fun modifyBeforeTransmit (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public static fun readAfterAttempt (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public static fun readAfterDeserialization (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public static fun readAfterExecution (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public static fun readAfterSerialization (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public static fun readAfterSigning (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public static fun readAfterTransmit (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;)V - public static fun readBeforeAttempt (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public static fun readBeforeDeserialization (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;)V - public static fun readBeforeExecution (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;)V - public static fun readBeforeSerialization (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;)V - public static fun readBeforeSigning (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public static fun readBeforeTransmit (Laws/smithy/kotlin/runtime/client/Interceptor;Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V + public fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun modifyBeforeRetryLoop (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun modifyBeforeSerialization (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun modifyBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun modifyBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun readAfterAttempt (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V + public fun readAfterDeserialization (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V + public fun readAfterExecution (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V + public fun readAfterSerialization (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V + public fun readAfterSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V + public fun readAfterTransmit (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;)V + public fun readBeforeAttempt (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V + public fun readBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;)V + public fun readBeforeExecution (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;)V + public fun readBeforeSerialization (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;)V + public fun readBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V + public fun readBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V } public abstract class aws/smithy/kotlin/runtime/client/LogMode { @@ -187,7 +165,7 @@ public abstract interface class aws/smithy/kotlin/runtime/client/SdkClientConfig public abstract interface class aws/smithy/kotlin/runtime/client/SdkClientFactory { public abstract fun builder ()Laws/smithy/kotlin/runtime/client/SdkClient$Builder; - public abstract fun invoke (Lkotlin/jvm/functions/Function1;)Laws/smithy/kotlin/runtime/client/SdkClient; + public fun invoke (Lkotlin/jvm/functions/Function1;)Laws/smithy/kotlin/runtime/client/SdkClient; } public final class aws/smithy/kotlin/runtime/client/SdkClientFactory$DefaultImpls { @@ -225,7 +203,7 @@ public abstract interface class aws/smithy/kotlin/runtime/client/config/Compress public abstract interface class aws/smithy/kotlin/runtime/client/config/CompressionClientConfig$Builder { public abstract fun getRequestCompression ()Laws/smithy/kotlin/runtime/client/config/RequestCompressionConfig$Builder; - public abstract fun requestCompression (Lkotlin/jvm/functions/Function1;)V + public fun requestCompression (Lkotlin/jvm/functions/Function1;)V } public final class aws/smithy/kotlin/runtime/client/config/CompressionClientConfig$Builder$DefaultImpls { @@ -353,3 +331,15 @@ public final class aws/smithy/kotlin/runtime/client/endpoints/functions/Url { public fun toString ()Ljava/lang/String; } +public abstract interface class aws/smithy/kotlin/runtime/client/region/RegionProvider { + public abstract fun getRegion (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public class aws/smithy/kotlin/runtime/client/region/RegionProviderChain : aws/smithy/kotlin/runtime/client/region/RegionProvider { + public fun (Ljava/util/List;)V + public fun ([Laws/smithy/kotlin/runtime/client/region/RegionProvider;)V + protected final fun getProviders ()[Laws/smithy/kotlin/runtime/client/region/RegionProvider; + public fun getRegion (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun toString ()Ljava/lang/String; +} + diff --git a/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/Interceptor.kt b/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/Interceptor.kt index 8d2b5c68bc..4fab16b9a7 100644 --- a/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/Interceptor.kt +++ b/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/Interceptor.kt @@ -5,6 +5,7 @@ package aws.smithy.kotlin.runtime.client +import aws.smithy.kotlin.runtime.client.util.MpJvmDefaultWithoutCompatibility import aws.smithy.kotlin.runtime.operation.ExecutionContext /** @@ -20,6 +21,7 @@ import aws.smithy.kotlin.runtime.operation.ExecutionContext * **MUST** not modify state even if it is possible to do so (it is not always possible or performant to provide an * immutable view of every type). */ +@MpJvmDefaultWithoutCompatibility public interface Interceptor< Input, Output, diff --git a/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/region/RegionProvider.kt b/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/region/RegionProvider.kt new file mode 100644 index 0000000000..9e82ec4509 --- /dev/null +++ b/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/region/RegionProvider.kt @@ -0,0 +1,17 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.client.region + +/** + * Interface for providing AWS region information. Implementations are free to use any strategy for + * providing region information + */ +public interface RegionProvider { + /** + * Return the region name to use. If region information is not available, implementations should return null + */ + public suspend fun getRegion(): String? +} diff --git a/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/region/RegionProviderChain.kt b/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/region/RegionProviderChain.kt new file mode 100644 index 0000000000..0df388d11f --- /dev/null +++ b/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/region/RegionProviderChain.kt @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.client.region + +import aws.smithy.kotlin.runtime.telemetry.logging.logger +import aws.smithy.kotlin.runtime.util.asyncLazy +import kotlin.coroutines.coroutineContext + +/** + * Composite [RegionProvider] that delegates to a chain of providers. + * [providers] are consulted in the order given and the first region found is returned + * + * @param providers the list of providers to delegate to + */ +public open class RegionProviderChain( + protected vararg val providers: RegionProvider, +) : RegionProvider { + + public constructor(providers: List) : this(*providers.toTypedArray()) + + private val resolvedRegion = asyncLazy(::resolveRegion) + + init { + require(providers.isNotEmpty()) { "at least one provider must be in the chain" } + } + + override fun toString(): String = + (listOf(this) + providers).map { it::class.simpleName }.joinToString(" -> ") + + override suspend fun getRegion(): String? = resolvedRegion.get() + + private suspend fun resolveRegion(): String? { + val logger = coroutineContext.logger() + for (provider in providers) { + try { + val region = provider.getRegion() + if (region != null) { + logger.debug { "resolved region ($region) from $provider " } + return region + } + logger.debug { "failed to resolve region from $provider" } + } catch (ex: Exception) { + logger.debug { "unable to load region from $provider: ${ex.message}" } + } + } + + return null + } +} diff --git a/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/util/Annotations.kt b/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/util/Annotations.kt new file mode 100644 index 0000000000..f7b9449ebc --- /dev/null +++ b/runtime/smithy-client/common/src/aws/smithy/kotlin/runtime/client/util/Annotations.kt @@ -0,0 +1,13 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.client.util + +/** + * A KMP-compatible variant of [kotlin.jvm.JvmDefaultWithoutCompatibility] + */ +@Retention(AnnotationRetention.SOURCE) +@Target(AnnotationTarget.CLASS) +public expect annotation class MpJvmDefaultWithoutCompatibility() diff --git a/runtime/smithy-client/common/test/aws/smithy/kotlin/runtime/client/region/RegionProviderChainTest.kt b/runtime/smithy-client/common/test/aws/smithy/kotlin/runtime/client/region/RegionProviderChainTest.kt new file mode 100644 index 0000000000..361336ac2d --- /dev/null +++ b/runtime/smithy-client/common/test/aws/smithy/kotlin/runtime/client/region/RegionProviderChainTest.kt @@ -0,0 +1,47 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.client.region + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFails + +class RegionProviderChainTest { + @Test + fun testNoProviders() { + assertFails("at least one provider") { + RegionProviderChain() + } + } + data class TestProvider(val region: String? = null) : RegionProvider { + override suspend fun getRegion(): String? = region + } + + @Test + fun testChain() = runTest { + val chain = RegionProviderChain( + TestProvider(null), + TestProvider("us-east-1"), + TestProvider("us-east-2"), + ) + + assertEquals("us-east-1", chain.getRegion()) + } + + @Test + fun testChainList() = runTest { + val providers = listOf( + TestProvider(null), + TestProvider("us-east-1"), + TestProvider("us-east-2"), + ) + + val chain = RegionProviderChain(providers) + + assertEquals("us-east-1", chain.getRegion()) + } +} diff --git a/runtime/smithy-client/jvm/src/aws/smithy/kotlin/runtime/client/util/AnnotationsJVM.kt b/runtime/smithy-client/jvm/src/aws/smithy/kotlin/runtime/client/util/AnnotationsJVM.kt new file mode 100644 index 0000000000..db7512a9df --- /dev/null +++ b/runtime/smithy-client/jvm/src/aws/smithy/kotlin/runtime/client/util/AnnotationsJVM.kt @@ -0,0 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.client.util + +public actual typealias MpJvmDefaultWithoutCompatibility = kotlin.jvm.JvmDefaultWithoutCompatibility diff --git a/runtime/smithy-client/native/src/aws/smithy/kotlin/runtime/client/util/AnnotationsNative.kt b/runtime/smithy-client/native/src/aws/smithy/kotlin/runtime/client/util/AnnotationsNative.kt new file mode 100644 index 0000000000..118ea69bda --- /dev/null +++ b/runtime/smithy-client/native/src/aws/smithy/kotlin/runtime/client/util/AnnotationsNative.kt @@ -0,0 +1,10 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.client.util + +public actual annotation class MpJvmDefaultWithoutCompatibility { + // No-op on non-JVM platforms +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 2c8d3df112..2738eb0d29 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -120,6 +120,7 @@ include(":tests:codegen:paginator-tests") include(":tests:codegen:serde-tests") include(":tests:codegen:serde-codegen-support") include(":tests:codegen:waiter-tests") +include(":tests:codegen:service-codegen-tests") include(":tests:integration:slf4j-1x-consumer") include(":tests:integration:slf4j-2x-consumer") include(":tests:integration:slf4j-hybrid-consumer") diff --git a/tests/codegen/service-codegen-tests/build.gradle.kts b/tests/codegen/service-codegen-tests/build.gradle.kts new file mode 100644 index 0000000000..a7d734d896 --- /dev/null +++ b/tests/codegen/service-codegen-tests/build.gradle.kts @@ -0,0 +1,72 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +import aws.sdk.kotlin.gradle.dsl.skipPublishing + +plugins { + id(libs.plugins.kotlin.jvm.get().pluginId) + alias(libs.plugins.kotlinx.serialization) + alias(libs.plugins.aws.kotlin.repo.tools.smithybuild) +} + +skipPublishing() + +val optinAnnotations = listOf("kotlin.RequiresOptIn") +kotlin.sourceSets.all { + optinAnnotations.forEach { languageSettings.optIn(it) } +} + +// Create a task to run the DefaultServiceGeneratorTestKt file +val runServiceGenerator by tasks.registering(JavaExec::class) { + group = "verification" + description = "Run the DefaultServiceGeneratorTestKt file" + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("com.test.DefaultServiceGeneratorTestKt") +} + +tasks.test { + dependsOn(runServiceGenerator) + useJUnitPlatform() + testLogging { + events("passed", "skipped", "failed") + showStandardStreams = true + } +} + +kotlin { + compilerOptions { + freeCompilerArgs.addAll( + "-opt-in=kotlin.io.path.ExperimentalPathApi", + "-opt-in=kotlinx.serialization.ExperimentalSerializationApi", + ) + } +} + +dependencies { + + compileOnly(project(":codegen:smithy-kotlin-codegen")) + + implementation(project(":codegen:smithy-kotlin-codegen")) + implementation(project(":codegen:smithy-aws-kotlin-codegen")) + implementation(project(":codegen:smithy-kotlin-codegen-testutils")) + + implementation(libs.kotlinx.serialization.json) + implementation(libs.kotlinx.serialization.cbor) + + testImplementation(libs.junit.jupiter) + testImplementation(libs.kotest.assertions.core.jvm) + testImplementation(libs.kotlin.test) + testImplementation(libs.kotlin.test.junit5) + testImplementation(project(":codegen:smithy-kotlin-codegen-testutils")) + testImplementation(project(":codegen:smithy-kotlin-codegen")) + testImplementation(project(":codegen:smithy-aws-kotlin-codegen")) + testApi(project(":runtime:auth:aws-signing-common")) + testApi(project(":runtime:auth:http-auth-aws")) + testApi(project(":runtime:auth:aws-signing-default")) + + testImplementation(gradleTestKit()) + + testImplementation(libs.kotlinx.serialization.json) + testImplementation(libs.kotlinx.serialization.cbor) +} diff --git a/tests/codegen/service-codegen-tests/model/service-authentication-test.smithy b/tests/codegen/service-codegen-tests/model/service-authentication-test.smithy new file mode 100644 index 0000000000..3d3a5dc117 --- /dev/null +++ b/tests/codegen/service-codegen-tests/model/service-authentication-test.smithy @@ -0,0 +1,122 @@ +$version: "2.0" + +namespace com.authentication + +use smithy.api#httpBearerAuth +use aws.auth#sigv4 +use aws.auth#sigv4a +use aws.protocols#restJson1 + +@restJson1 +@auth([sigv4a, sigv4, httpBearerAuth]) +@httpBearerAuth +@sigv4(name: "service-1") +@sigv4a(name: "service-1") +service AuthenticationServiceTest { + version: "1.0.0" + operations: [ + OnlyBearerTest + OnlySigV4Test + SigV4ATest + AllAuthenticationTest + NoAuthenticationTest + SigV4AuthenticationWithBodyTest + SigV4AAuthenticationWithBodyTest + ] +} + +@auth([httpBearerAuth]) +@http(method: "POST", uri: "/only-bearer", code: 201) +operation OnlyBearerTest { + input: OnlyBearerTestInput + output: OnlyBearerTestOutput +} + +@input +structure OnlyBearerTestInput {} + +@output +structure OnlyBearerTestOutput {} + +@auth([sigv4]) +@http(method: "POST", uri: "/only-sigv4", code: 201) +operation OnlySigV4Test { + input: OnlySigV4TestInput + output: OnlySigV4TestOutput +} + +@input +structure OnlySigV4TestInput {} + +@output +structure OnlySigV4TestOutput {} + +@auth([sigv4a, sigv4]) +@http(method: "POST", uri: "/sigv4a", code: 201) +operation SigV4ATest { + input: SigV4ATestInput + output: SigV4ATestOutput +} + +@input +structure SigV4ATestInput {} + +@output +structure SigV4ATestOutput {} + +@http(method: "POST", uri: "/all-authentication", code: 201) +operation AllAuthenticationTest { + input: AllAuthenticationTestInput + output: AllAuthenticationTestOutput +} + +@input +structure AllAuthenticationTestInput {} + +@output +structure AllAuthenticationTestOutput {} + +@auth([]) +@http(method: "POST", uri: "/no-authentication", code: 201) +operation NoAuthenticationTest { + input: NoAuthenticationTestInput + output: NoAuthenticationTestOutput +} + +@input +structure NoAuthenticationTestInput {} + +@output +structure NoAuthenticationTestOutput {} + +@auth([sigv4]) +@http(method: "POST", uri: "/sigv4-authentication-body", code: 201) +operation SigV4AuthenticationWithBodyTest { + input: SigV4AuthenticationWithBodyTestInput + output: SigV4AuthenticationWithBodyTestOutput +} + +@input +structure SigV4AuthenticationWithBodyTestInput { + input1: String +} + +@output +structure SigV4AuthenticationWithBodyTestOutput {} + + +@auth([sigv4a, sigv4]) +@http(method: "POST", uri: "/sigv4a-authentication-body", code: 201) +operation SigV4AAuthenticationWithBodyTest { + input: SigV4AAuthenticationWithBodyTestInput + output: SigV4AAuthenticationWithBodyTestOutput +} + +@input +structure SigV4AAuthenticationWithBodyTestInput { + input1: String +} + +@output +structure SigV4AAuthenticationWithBodyTestOutput {} + diff --git a/tests/codegen/service-codegen-tests/model/service-cbor-test.smithy b/tests/codegen/service-codegen-tests/model/service-cbor-test.smithy new file mode 100644 index 0000000000..90314f61e8 --- /dev/null +++ b/tests/codegen/service-codegen-tests/model/service-cbor-test.smithy @@ -0,0 +1,87 @@ +$version: "2.0" + +namespace com.cbor + +use smithy.protocols#rpcv2Cbor + +@rpcv2Cbor +service CborServiceTest { + version: "1.0.0" + operations: [ + PostTest + AuthTest + ErrorTest + HttpErrorTest + ] +} + +@http(method: "POST", uri: "/post", code: 201) +operation PostTest { + input: PostTestInput + output: PostTestOutput +} + +@input +structure PostTestInput { + input1: String + input2: Integer +} + +@output +structure PostTestOutput { + output1: String + output2: Integer +} + +@http(method: "POST", uri: "/auth", code: 201) +operation AuthTest { + input: AuthTestInput + output: AuthTestOutput +} + +@input +structure AuthTestInput { + input1: String +} + +@output +structure AuthTestOutput { + output1: String +} + +@http(method: "POST", uri: "/error", code: 200) +operation ErrorTest { + input: ErrorTestInput + output: ErrorTestOutput +} + +@input +structure ErrorTestInput { + input1: String +} + +@output +structure ErrorTestOutput { + output1: String +} + + +@http(method: "POST", uri: "/http-error", code: 200) +operation HttpErrorTest { + input: HttpErrorTestInput + output: HttpErrorTestOutput + errors: [HttpError] +} + +@input +structure HttpErrorTestInput {} + +@output +structure HttpErrorTestOutput {} + +@error("client") +@httpError(456) +structure HttpError { + msg: String + num: Integer +} \ No newline at end of file diff --git a/tests/codegen/service-codegen-tests/model/service-constraints-test.smithy b/tests/codegen/service-codegen-tests/model/service-constraints-test.smithy new file mode 100644 index 0000000000..51cacdfedd --- /dev/null +++ b/tests/codegen/service-codegen-tests/model/service-constraints-test.smithy @@ -0,0 +1,165 @@ +$version: "2.0" + +namespace com.constraints + +use smithy.protocols#rpcv2Cbor + +@rpcv2Cbor +service ServiceConstraintsTest { + version: "1.0.0" + operations: [ + RequiredConstraintTest, + LengthConstraintTest, + PatternConstraintTest, + RangeConstraintTest, + UniqueItemsConstraintTest, + NestedUniqueItemsConstraintTest, + DoubleNestedUniqueItemsConstraintTest, + ] +} + +@http(method: "POST", uri: "/required-constraint", code: 201) +operation RequiredConstraintTest { + input: RequiredConstraintTestInput + output: RequiredConstraintTestOutput +} + +@input +structure RequiredConstraintTestInput { + @required + requiredInput: String + notRequiredInput: String +} + +@output +structure RequiredConstraintTestOutput {} + +@http(method: "POST", uri: "/length-constraint", code: 201) +operation LengthConstraintTest { + input: LengthConstraintTestInput + output: LengthConstraintTestOutput +} + +@input +structure LengthConstraintTestInput { + @length(min: 3) + greaterLengthInput: String + + @length(max: 3) + smallerLengthInput: NotUniqueItemsList + + @length(min: 1, max: 2) + betweenLengthInput: MyMap +} + +@output +structure LengthConstraintTestOutput {} + +@http(method: "POST", uri: "/pattern-constraint", code: 201) +operation PatternConstraintTest { + input: PatternConstraintTestInput + output: PatternConstraintTestOutput +} + +@input +structure PatternConstraintTestInput { + @pattern("^[A-Za-z]+$") + patternInput1: String + + @pattern("[1-9]+") + patternInput2: String + + +} + +@output +structure PatternConstraintTestOutput {} + + +@http(method: "POST", uri: "/range-constraint", code: 201) +operation RangeConstraintTest { + input: RangeConstraintTestInput + output: RangeConstraintTestOutput +} + +@input +structure RangeConstraintTestInput { + @range(min: 0, max: 5) + betweenInput: Integer + + @range(min: -10) + greaterInput: Double + + @range(max: 9) + smallerInput: Float +} + +@output +structure RangeConstraintTestOutput {} + +@http(method: "POST", uri: "/unique-items-constraint", code: 201) +operation UniqueItemsConstraintTest { + input: UniqueItemsConstraintTestInput + output: UniqueItemsConstraintTestOutput +} + +@input +structure UniqueItemsConstraintTestInput { + notUniqueItemsListInput: NotUniqueItemsList + uniqueItemsListInput: UniqueItemsList +} + +@output +structure UniqueItemsConstraintTestOutput {} + +@http(method: "POST", uri: "/nested-unique-items-constraint", code: 201) +operation NestedUniqueItemsConstraintTest { + input: NestedUniqueItemsConstraintTestInput + output: NestedUniqueItemsConstraintTestOutput +} + +@input +structure NestedUniqueItemsConstraintTestInput { + nestedUniqueItemsListInput: UniqueItemsListWrap +} + +@output +structure NestedUniqueItemsConstraintTestOutput {} + +@http(method: "POST", uri: "/double-nested-unique-items-constraint", code: 201) +operation DoubleNestedUniqueItemsConstraintTest { + input: DoubleNestedUniqueItemsConstraintTestInput + output: DoubleNestedUniqueItemsConstraintTestOutput +} + +@input +structure DoubleNestedUniqueItemsConstraintTestInput { + doubleNestedUniqueItemsListInput: UniqueItemsListWrapContainer +} + +@output +structure DoubleNestedUniqueItemsConstraintTestOutput {} + +list NotUniqueItemsList { + member: String +} + +@uniqueItems +list UniqueItemsList { + member: String +} + +@uniqueItems +list UniqueItemsListWrap { + member: UniqueItemsList +} + +@uniqueItems +list UniqueItemsListWrapContainer { + member: UniqueItemsListWrap +} + +map MyMap { + key: String + value: String +} \ No newline at end of file diff --git a/tests/codegen/service-codegen-tests/model/service-json-test.smithy b/tests/codegen/service-codegen-tests/model/service-json-test.smithy new file mode 100644 index 0000000000..c48ef7c279 --- /dev/null +++ b/tests/codegen/service-codegen-tests/model/service-json-test.smithy @@ -0,0 +1,198 @@ +$version: "2.0" + +namespace com.json + +use aws.protocols#restJson1 + +@restJson1 +service JsonServiceTest { + version: "1.0.0" + operations: [ + HttpHeaderTest + HttpLabelTest + HttpQueryTest + HttpStringPayloadTest + HttpStructurePayloadTest + TimestampTest + JsonNameTest + HttpErrorTest + ] +} + +@http(method: "POST", uri: "/http-header", code: 201) +operation HttpHeaderTest { + input: HttpHeaderTestInput + output: HttpHeaderTestOutput +} + +@input +structure HttpHeaderTestInput { + @httpHeader("X-Request-Header") + header: String + + @httpPrefixHeaders("X-Request-Headers-") + headers: MapOfStrings +} + +@output +structure HttpHeaderTestOutput { + @httpHeader("X-Response-Header") + header: String + + @httpPrefixHeaders("X-Response-Headers-") + headers: MapOfStrings +} + + +@http(method: "GET", uri: "/http-label/{foo}", code: 200) +operation HttpLabelTest { + input: HttpLabelTestInput + output: HttpLabelTestOutput +} + +@input +structure HttpLabelTestInput { + @required + @httpLabel + foo: String +} + +@output +structure HttpLabelTestOutput { + output: String +} + +@http(method: "DELETE", uri: "/http-query", code: 200) +operation HttpQueryTest { + input: HttpQueryTestInput + output: HttpQueryTestOutput +} + +@input +structure HttpQueryTestInput { + @httpQuery("query") + query: Integer + + @httpQueryParams + params: MapOfStrings +} + +@output +structure HttpQueryTestOutput { + output: String +} + + +@http(method: "POST", uri: "/http-payload/string", code: 201) +operation HttpStringPayloadTest { + input: HttpStringPayloadTestInput + output: HttpStringPayloadTestOutput +} + +@input +structure HttpStringPayloadTestInput { + @httpPayload + content: String +} + +@output +structure HttpStringPayloadTestOutput { + @httpPayload + content: String +} + +@http(method: "POST", uri: "/http-payload/structure", code: 201) +operation HttpStructurePayloadTest { + input: HttpStructurePayloadTestInput + output: HttpStructurePayloadTestOutput +} + +@input +structure HttpStructurePayloadTestInput { + @httpPayload + content: HttpStructurePayloadTestStructure +} + +@output +structure HttpStructurePayloadTestOutput { + @httpPayload + content: HttpStructurePayloadTestStructure +} + + +@http(method: "POST", uri: "/timestamp", code: 201) +operation TimestampTest { + input: TimestampTestInput + output: TimestampTestOutput +} + +@input +structure TimestampTestInput { + default: Timestamp + @timestampFormat("date-time") + dateTime: Timestamp + @timestampFormat("http-date") + httpDate: Timestamp + @timestampFormat("epoch-seconds") + epochSeconds: Timestamp +} + +@output +structure TimestampTestOutput { + default: Timestamp + @timestampFormat("date-time") + dateTime: Timestamp + @timestampFormat("http-date") + httpDate: Timestamp + @timestampFormat("epoch-seconds") + epochSeconds: Timestamp +} + +@http(method: "POST", uri: "/json-name", code: 201) +operation JsonNameTest { + input: JsonNameTestInput + output: JsonNameTestOutput +} + +@input +structure JsonNameTestInput { + @jsonName("requestName") + content: String +} + +@output +structure JsonNameTestOutput { + @jsonName("responseName") + content: String +} + +@http(method: "POST", uri: "/http-error", code: 200) +operation HttpErrorTest { + input: HttpErrorTestInput + output: HttpErrorTestOutput + errors: [HttpError] +} + +@input +structure HttpErrorTestInput {} + +@output +structure HttpErrorTestOutput {} + +@error("client") +@httpError(456) +structure HttpError { + msg: String + num: Integer +} + +structure HttpStructurePayloadTestStructure { + content1: String + content2: Integer + content3: Float +} + +map MapOfStrings { + key: String + value: String +} \ No newline at end of file diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/AuthenticationServiceTestGenerator.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/AuthenticationServiceTestGenerator.kt new file mode 100644 index 0000000000..c339e08b05 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/AuthenticationServiceTestGenerator.kt @@ -0,0 +1,119 @@ +package com.test + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin +import software.amazon.smithy.model.loader.ModelAssembler +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +internal fun generateAuthenticationConstraintsTest() { + val modelPath: Path = Paths.get("model", "service-authentication-test.smithy") + val defaultModel = ModelAssembler() + .discoverModels() + .addImport(modelPath) + .assemble() + .unwrap() + val serviceName = "AuthenticationServiceTest" + val packageName = "com.authentication" + val outputDirName = "service-authentication-test" + + val packagePath = packageName.replace('.', '/') + + val settings: ObjectNode = ObjectNode.builder() + .withMember("service", Node.from("$packageName#$serviceName")) + .withMember( + "package", + ObjectNode.builder() + .withMember("name", Node.from(packageName)) + .withMember("version", Node.from("1.0.0")) + .build(), + ) + .withMember( + "build", + ObjectNode.builder() + .withMember("rootProject", true) + .withMember("generateServiceProject", true) + .withMember( + "optInAnnotations", + Node.arrayNode( + Node.from("aws.smithy.kotlin.runtime.InternalApi"), + Node.from("kotlinx.serialization.ExperimentalSerializationApi"), + ), + ) + .build(), + ) + .withMember( + "serviceStub", + ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), + ) + .build() + val outputDir: Path = Paths.get("build", outputDirName).also { Files.createDirectories(it) } + val manifest: FileManifest = FileManifest.create(outputDir) + + val context: PluginContext = PluginContext.builder() + .model(defaultModel) + .fileManifest(manifest) + .settings(settings) + .build() + KotlinCodegenPlugin().execute(context) + + val bearerValidation = """ + package $packageName.auth + + internal object BearerValidation { + public fun bearerValidation(token: String): UserPrincipal? { + // TODO: implement me + if (token == "correctToken") return UserPrincipal("Authenticated User") else return null + } + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/auth/Validation.kt", bearerValidation) + + val awsValidation = """ + package $packageName.auth + + import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials + + internal object SigV4CredentialStore { + private val table: Map = mapOf( + "AKIACORRECTEXAMPLEACCESSKEY" to Credentials(accessKeyId = "AKIACORRECTEXAMPLEACCESSKEY", secretAccessKey = "CORRECTEXAMPLESECRETKEY"), + ) + internal fun get(accessKeyId: String): Credentials? { + return table[accessKeyId] + } + } + + internal object SigV4aPublicKeyStore { + private val table: MutableMap = mutableMapOf() + + init { + val pem = ""${'"'} + -----BEGIN PUBLIC KEY----- + MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE/VTdYVFt+jAgj1N4Q+Dnpcho/XeI + 655JtWjFvxscKZJbDNa8F6hzo/s3lQNwMozl2p3KqmmjYwlIu9tQQkFZvQ== + -----END PUBLIC KEY----- + ""${'"'}.trimIndent() + val clean = pem.replace("-----BEGIN PUBLIC KEY-----", "").replace("-----END PUBLIC KEY-----", "").replace("\\s".toRegex(), "") + val keyBytes = java.util.Base64.getDecoder().decode(clean) + val spec = java.security.spec.X509EncodedKeySpec(keyBytes) + val kf = java.security.KeyFactory.getInstance("EC") + table["AKIACORRECTEXAMPLEACCESSKEY"] = kf.generatePublic(spec) + } + + internal fun get(accessKeyId: String): java.security.PublicKey? { + return table[accessKeyId] + } + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/auth/AWSValidation.kt", awsValidation) + + val settingGradleKts = """ + rootProject.name = "service-authentication-test" + includeBuild("../../../../../") + """.trimIndent() + manifest.writeFile("settings.gradle.kts", settingGradleKts) +} diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/CborServiceTestGenerator.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/CborServiceTestGenerator.kt new file mode 100644 index 0000000000..b9e69a8642 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/CborServiceTestGenerator.kt @@ -0,0 +1,119 @@ +package com.test + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin +import software.amazon.smithy.model.loader.ModelAssembler +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +internal fun generateCborServiceTest() { + val modelPath: Path = Paths.get("model", "service-cbor-test.smithy") + val defaultModel = ModelAssembler() + .discoverModels() + .addImport(modelPath) + .assemble() + .unwrap() + val serviceName = "CborServiceTest" + val packageName = "com.cbor" + val outputDirName = "service-cbor-test" + + val packagePath = packageName.replace('.', '/') + + val settings: ObjectNode = ObjectNode.builder() + .withMember("service", Node.from("$packageName#$serviceName")) + .withMember( + "package", + ObjectNode.builder() + .withMember("name", Node.from(packageName)) + .withMember("version", Node.from("1.0.0")) + .build(), + ) + .withMember( + "build", + ObjectNode.builder() + .withMember("rootProject", true) + .withMember("generateServiceProject", true) + .withMember( + "optInAnnotations", + Node.arrayNode( + Node.from("aws.smithy.kotlin.runtime.InternalApi"), + Node.from("kotlinx.serialization.ExperimentalSerializationApi"), + ), + ) + .build(), + ) + .withMember( + "serviceStub", + ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), + ) + .build() + val outputDir: Path = Paths.get("build", outputDirName).also { Files.createDirectories(it) } + val manifest: FileManifest = FileManifest.create(outputDir) + + val context: PluginContext = PluginContext.builder() + .model(defaultModel) + .fileManifest(manifest) + .settings(settings) + .build() + KotlinCodegenPlugin().execute(context) + + val postTestOperation = """ + package $packageName.operations + + import $packageName.model.PostTestRequest + import $packageName.model.PostTestResponse + + public fun handlePostTestRequest(req: PostTestRequest): PostTestResponse { + val response = PostTestResponse.Builder() + val input1 = req.input1 ?: "" + val input2 = req.input2 ?: 0 + response.output1 = input1 + " world!" + response.output2 = input2 + 1 + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/PostTestOperation.kt", postTestOperation) + + val errorTestOperation = """ + package $packageName.operations + + import $packageName.model.ErrorTestRequest + import $packageName.model.ErrorTestResponse + + public fun handleErrorTestRequest(req: ErrorTestRequest): ErrorTestResponse { + val variable: String? = null + val error = variable!!.length + return ErrorTestResponse.Builder().build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/ErrorTestOperation.kt", errorTestOperation) + + val httpErrorTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpErrorTestRequest + import $packageName.model.HttpErrorTestResponse + import $packageName.model.HttpError + + public fun handleHttpErrorTestRequest(req: HttpErrorTestRequest): HttpErrorTestResponse { + + val error = HttpError.Builder() + error.msg = "this is an error message" + error.num = 444 + throw error.build() + + return HttpErrorTestResponse.Builder().build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpErrorTestOperation.kt", httpErrorTestOperation) + + val settingGradleKts = """ + rootProject.name = "service-cbor-test" + includeBuild("../../../../../") + """.trimIndent() + manifest.writeFile("settings.gradle.kts", settingGradleKts) +} diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/ConstraintsServiceTestGenerator.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/ConstraintsServiceTestGenerator.kt new file mode 100644 index 0000000000..adcf520f61 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/ConstraintsServiceTestGenerator.kt @@ -0,0 +1,69 @@ +package com.test + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin +import software.amazon.smithy.model.loader.ModelAssembler +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +internal fun generateServiceConstraintsTest() { + val modelPath: Path = Paths.get("model", "service-constraints-test.smithy") + val defaultModel = ModelAssembler() + .discoverModels() + .addImport(modelPath) + .assemble() + .unwrap() + val serviceName = "ServiceConstraintsTest" + val packageName = "com.constraints" + val outputDirName = "service-constraints-test" + + val packagePath = packageName.replace('.', '/') + + val settings: ObjectNode = ObjectNode.builder() + .withMember("service", Node.from("$packageName#$serviceName")) + .withMember( + "package", + ObjectNode.builder() + .withMember("name", Node.from(packageName)) + .withMember("version", Node.from("1.0.0")) + .build(), + ) + .withMember( + "build", + ObjectNode.builder() + .withMember("rootProject", true) + .withMember("generateServiceProject", true) + .withMember( + "optInAnnotations", + Node.arrayNode( + Node.from("aws.smithy.kotlin.runtime.InternalApi"), + Node.from("kotlinx.serialization.ExperimentalSerializationApi"), + ), + ) + .build(), + ) + .withMember( + "serviceStub", + ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), + ) + .build() + val outputDir: Path = Paths.get("build", outputDirName).also { Files.createDirectories(it) } + val manifest: FileManifest = FileManifest.create(outputDir) + + val context: PluginContext = PluginContext.builder() + .model(defaultModel) + .fileManifest(manifest) + .settings(settings) + .build() + KotlinCodegenPlugin().execute(context) + + val settingGradleKts = """ + rootProject.name = "service-constraints-test" + includeBuild("../../../../../") + """.trimIndent() + manifest.writeFile("settings.gradle.kts", settingGradleKts) +} diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt new file mode 100644 index 0000000000..3a816637ed --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt @@ -0,0 +1,8 @@ +package com.test + +internal fun main() { + generateCborServiceTest() + generateJsonServiceTest() + generateServiceConstraintsTest() + generateAuthenticationConstraintsTest() +} diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/JsonServiceTestGenerator.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/JsonServiceTestGenerator.kt new file mode 100644 index 0000000000..b21e4e1a6b --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/JsonServiceTestGenerator.kt @@ -0,0 +1,195 @@ +package com.test + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin +import software.amazon.smithy.model.loader.ModelAssembler +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +internal fun generateJsonServiceTest() { + val modelPath: Path = Paths.get("model", "service-json-test.smithy") + val defaultModel = ModelAssembler() + .discoverModels() + .addImport(modelPath) + .assemble() + .unwrap() + val serviceName = "JsonServiceTest" + val packageName = "com.json" + val outputDirName = "service-json-test" + + val packagePath = packageName.replace('.', '/') + + val settings: ObjectNode = ObjectNode.builder() + .withMember("service", Node.from("$packageName#$serviceName")) + .withMember( + "package", + ObjectNode.builder() + .withMember("name", Node.from(packageName)) + .withMember("version", Node.from("1.0.0")) + .build(), + ) + .withMember( + "build", + ObjectNode.builder() + .withMember("rootProject", true) + .withMember("generateServiceProject", true) + .withMember( + "optInAnnotations", + Node.arrayNode( + Node.from("aws.smithy.kotlin.runtime.InternalApi"), + Node.from("kotlinx.serialization.ExperimentalSerializationApi"), + ), + ) + .build(), + ) + .withMember( + "serviceStub", + ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), + ) + .build() + val outputDir: Path = Paths.get("build", outputDirName).also { Files.createDirectories(it) } + val manifest: FileManifest = FileManifest.create(outputDir) + + val context: PluginContext = PluginContext.builder() + .model(defaultModel) + .fileManifest(manifest) + .settings(settings) + .build() + KotlinCodegenPlugin().execute(context) + + val httpHeaderTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpHeaderTestRequest + import $packageName.model.HttpHeaderTestResponse + + public fun handleHttpHeaderTestRequest(req: HttpHeaderTestRequest): HttpHeaderTestResponse { + val response = HttpHeaderTestResponse.Builder() + response.header = req.headers?.get("hhh") + response.headers = mapOf("hhh" to (req.header ?: "")) + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpHeaderTestOperation.kt", httpHeaderTestOperation) + + val httpLabelTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpLabelTestRequest + import $packageName.model.HttpLabelTestResponse + + public fun handleHttpLabelTestRequest(req: HttpLabelTestRequest): HttpLabelTestResponse { + val response = HttpLabelTestResponse.Builder() + response.output = req.foo + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpLabelTestOperation.kt", httpLabelTestOperation) + + val httpQueryTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpQueryTestRequest + import $packageName.model.HttpQueryTestResponse + + public fun handleHttpQueryTestRequest(req: HttpQueryTestRequest): HttpQueryTestResponse { + val response = HttpQueryTestResponse.Builder() + response.output = req.query.toString() + (req.params?.get("qqq") ?: "") + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpQueryTestOperation.kt", httpQueryTestOperation) + + val httpStringPayloadTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpStringPayloadTestRequest + import $packageName.model.HttpStringPayloadTestResponse + + public fun handleHttpStringPayloadTestRequest(req: HttpStringPayloadTestRequest): HttpStringPayloadTestResponse { + val response = HttpStringPayloadTestResponse.Builder() + response.content = req.content + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpStringPayloadTestOperation.kt", httpStringPayloadTestOperation) + + val httpStructurePayloadTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpStructurePayloadTestRequest + import $packageName.model.HttpStructurePayloadTestResponse + import $packageName.model.HttpStructurePayloadTestStructure + + public fun handleHttpStructurePayloadTestRequest(req: HttpStructurePayloadTestRequest): HttpStructurePayloadTestResponse { + val response = HttpStructurePayloadTestResponse.Builder() + val content = HttpStructurePayloadTestStructure.Builder() + content.content1 = req.content?.content1 + content.content2 = req.content?.content2 + content.content3 = req.content?.content3 + response.content = content.build() + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpStructurePayloadTestOperation.kt", httpStructurePayloadTestOperation) + + val timestampTestOperation = """ + package $packageName.operations + + import $packageName.model.TimestampTestRequest + import $packageName.model.TimestampTestResponse + + public fun handleTimestampTestRequest(req: TimestampTestRequest): TimestampTestResponse { + val response = TimestampTestResponse.Builder() + response.default = req.default + response.dateTime = req.dateTime + response.httpDate = req.httpDate + response.epochSeconds = req.epochSeconds + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/TimestampTestOperation.kt", timestampTestOperation) + + val jsonNameTestOperation = """ + package $packageName.operations + + import $packageName.model.JsonNameTestRequest + import $packageName.model.JsonNameTestResponse + + public fun handleJsonNameTestRequest(req: JsonNameTestRequest): JsonNameTestResponse { + val response = JsonNameTestResponse.Builder() + response.content = req.content + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/JsonNameTestOperation.kt", jsonNameTestOperation) + + val httpErrorTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpErrorTestRequest + import $packageName.model.HttpErrorTestResponse + import $packageName.model.HttpError + + public fun handleHttpErrorTestRequest(req: HttpErrorTestRequest): HttpErrorTestResponse { + + val error = HttpError.Builder() + error.msg = "this is an error message" + error.num = 444 + throw error.build() + + return HttpErrorTestResponse.Builder().build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpErrorTestOperation.kt", httpErrorTestOperation) + + val settingGradleKts = """ + rootProject.name = "service-json-test" + includeBuild("../../../../../") + """.trimIndent() + manifest.writeFile("settings.gradle.kts", settingGradleKts) +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/AuthenticationServiceTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/AuthenticationServiceTest.kt new file mode 100644 index 0000000000..0bfb55872b --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/AuthenticationServiceTest.kt @@ -0,0 +1,656 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials +import aws.smithy.kotlin.runtime.auth.awssigning.AwsSigningAlgorithm +import kotlinx.serialization.json.Json +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import java.net.ServerSocket +import java.net.http.HttpResponse +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class AuthenticationServiceTest { + val closeGracePeriodMillis = TestParams.CLOSE_GRACE_PERIOD_MILLIS + val closeTimeoutMillis = TestParams.CLOSE_TIMEOUT_MILLIS + val gracefulWindow = TestParams.GRACEFUL_WINDOW + val requestBodyLimit = TestParams.REQUEST_BODY_LIMIT + val portListenerTimeout = TestParams.PORT_LISTENER_TIMEOUT + + val port: Int = ServerSocket(0).use { it.localPort } + val baseUrl = "http://localhost:$port" + + val projectDir: Path = Paths.get("build/service-authentication-test") + + private lateinit var proc: Process + + @BeforeAll + fun boot() { + proc = startService("netty", port, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit, projectDir) + val ready = waitForPort(port, portListenerTimeout) + assertTrue(ready, "Service did not start within $portListenerTimeout s") + } + + @AfterAll + fun shutdown() = cleanupService(proc, gracefulWindow) + + @Test + fun `checks bearer authentication with correct token`() { + val response = sendRequest( + "$baseUrl/only-bearer", + "POST", + null, + "application/json", + "application/json", + "correctToken", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks bearer authentication with wrong token`() { + val response = sendRequest( + "$baseUrl/only-bearer", + "POST", + null, + "application/json", + "application/json", + "wrongToken", + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Invalid or expired authentication", body.message) + } + + @Test + fun `checks bearer authentication without token`() { + val response = sendRequest( + "$baseUrl/only-bearer", + "POST", + null, + "application/json", + "application/json", + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Missing bearer token", body.message) + } + + @Test + fun `checks sigv4 authentication with correct signature`() { + val region = "us-east-1" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4, + ) + + val response = sendRequest( + "$baseUrl/only-sigv4", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks sigv4 authentication with wrong region`() { + val region = "us-east-2" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4, + ) + + val response = sendRequest( + "$baseUrl/only-sigv4", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Invalid or expired authentication", body.message) + } + + @Test + fun `checks sigv4 authentication with wrong service name`() { + val region = "us-east-1" + val service = "service-2" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4, + ) + + val response = sendRequest( + "$baseUrl/only-sigv4", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Invalid or expired authentication", body.message) + } + + @Test + fun `checks sigv4 authentication with wrong access key`() { + val region = "us-east-1" + val service = "service-1" + + val accessKey = "AKIAWRONGEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4, + ) + + val response = sendRequest( + "$baseUrl/only-sigv4", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Invalid or expired authentication", body.message) + } + + @Test + fun `checks sigv4 authentication with wrong secret key`() { + val region = "us-east-1" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "WRONGEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4, + ) + + val response = sendRequest( + "$baseUrl/only-sigv4", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Invalid or expired authentication", body.message) + } + + @Test + fun `checks sigv4a authentication with correct signature`() { + val region = "*" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4_ASYMMETRIC, + ) + + val response = sendRequest( + "$baseUrl/sigv4a", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks sigv4a authentication with specific region`() { + val region = "us-east-1" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4_ASYMMETRIC, + ) + + val response = sendRequest( + "$baseUrl/sigv4a", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks sigv4a authentication with multi regions`() { + val region = "us-east-*" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4_ASYMMETRIC, + ) + + val response = sendRequest( + "$baseUrl/sigv4a", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks sigv4a authentication with wrong region`() { + val region = "us-east-2" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4_ASYMMETRIC, + ) + + val response = sendRequest( + "$baseUrl/sigv4a", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Invalid or expired authentication", body.message) + } + + @Test + fun `checks sigv4a authentication with wrong access key`() { + val region = "*" + val service = "service-1" + + val accessKey = "AKIAWRONGEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4_ASYMMETRIC, + ) + + val response = sendRequest( + "$baseUrl/sigv4a", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Invalid or expired authentication", body.message) + } + + @Test + fun `checks sigv4a authentication with wrong secret key`() { + val region = "us-east-1" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "WRONGEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4_ASYMMETRIC, + ) + + val response = sendRequest( + "$baseUrl/sigv4a", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + + val body = Json.decodeFromString( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(401, body.code) + assertEquals("Invalid or expired authentication", body.message) + } + + @Test + fun `checks multi authentications with bearer token`() { + val response = sendRequest( + "$baseUrl/all-authentication", + "POST", + null, + "application/json", + "application/json", + "correctToken", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks multi authentications with sigv4 token`() { + val region = "us-east-1" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4, + ) + + val response = sendRequest( + "$baseUrl/all-authentication", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks multi authentications with sigv4a token`() { + val region = "*" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4_ASYMMETRIC, + ) + + val response = sendRequest( + "$baseUrl/all-authentication", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks no authentication`() { + val response = sendRequest( + "$baseUrl/no-authentication", + "POST", + null, + "application/json", + "application/json", + null, + mapOf(), + null, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks sigv4 authentication with body with correct signature`() { + val region = "us-east-1" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4, + ) + + val requestJson = Json.encodeToJsonElement( + AuthTestRequest.serializer(), + AuthTestRequest("this is a test input"), + ) + + val response = sendRequest( + "$baseUrl/sigv4-authentication-body", + "POST", + requestJson, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks sigv4a authentication with body with correct signature`() { + val region = "*" + val service = "service-1" + + val accessKey = "AKIACORRECTEXAMPLEACCESSKEY" + val secretKey = "CORRECTEXAMPLESECRETKEY" + + val creds = Credentials(accessKey, secretKey) + + val signingOptions = AwsSigningOptions( + credentials = creds, + service = service, + region = region, + algorithm = AwsSigningAlgorithm.SIGV4_ASYMMETRIC, + ) + + val requestJson = Json.encodeToJsonElement( + AuthTestRequest.serializer(), + AuthTestRequest("this is a test input"), + ) + + val response = sendRequest( + "$baseUrl/sigv4a-authentication-body", + "POST", + requestJson, + "application/json", + "application/json", + null, + mapOf(), + signingOptions, + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/CborServiceTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/CborServiceTest.kt new file mode 100644 index 0000000000..600a68d72f --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/CborServiceTest.kt @@ -0,0 +1,329 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import kotlinx.serialization.cbor.Cbor +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import java.net.ServerSocket +import java.net.http.HttpResponse +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class CborServiceTest { + val closeGracePeriodMillis = TestParams.CLOSE_GRACE_PERIOD_MILLIS + val closeTimeoutMillis = TestParams.CLOSE_TIMEOUT_MILLIS + val gracefulWindow = TestParams.GRACEFUL_WINDOW + val requestBodyLimit = TestParams.REQUEST_BODY_LIMIT + val portListenerTimeout = TestParams.PORT_LISTENER_TIMEOUT + + val port: Int = ServerSocket(0).use { it.localPort } + val baseUrl = "http://localhost:$port" + + val projectDir: Path = Paths.get("build/service-cbor-test") + + private lateinit var proc: Process + + @BeforeAll + fun boot() { + proc = startService("netty", port, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit, projectDir) + val ready = waitForPort(port, portListenerTimeout) + assertTrue(ready, "Service did not start within $portListenerTimeout s") + } + + @AfterAll + fun shutdown() = cleanupService(proc, gracefulWindow) + + @Test + fun `checks correct POST request`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + + val body = cbor.decodeFromByteArray( + PostTestResponse.serializer(), + response.body(), + ) + + assertEquals("Hello world!", body.output1) + assertEquals(input2 + 1, body.output2) + } + + @Test + fun `checks unhandled runtime exception in handler`() { + val cbor = Cbor { } + val input1 = "Hello" + val requestBytes = cbor.encodeToByteArray( + ErrorTestRequest.serializer(), + ErrorTestRequest(input1), + ) + + val response = sendRequest( + "$baseUrl/error", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(500, response.statusCode(), "Expected 500") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(500, body.code) + assertEquals("Unexpected error", body.message) + } + + @Test + fun `checks wrong content type`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/json", + "application/cbor", + ) + assertIs>(response) + assertEquals(415, response.statusCode(), "Expected 415") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(415, body.code) + assertEquals("Not acceptable Content‑Type found: 'application/json'. Accepted content types: application/cbor", body.message) + } + + @Test + fun `checks missing content type`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + acceptType = "application/cbor", + ) + assertIs>(response) + assertEquals(415, response.statusCode(), "Expected 415") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(415, body.code) + assertEquals("Not acceptable Content‑Type found: '*/*'. Accepted content types: application/cbor", body.message) + } + + @Test + fun `checks wrong accept type`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + "application/json", + ) + assertIs>(response) + assertEquals(406, response.statusCode(), "Expected 406") + + assertEquals("""{"code":406,"message":"Not acceptable Accept type found: '[application/json]'. Accepted types: application/cbor"}""", response.body()) + } + + @Test + fun `checks missing accept type`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks malformed input`() { + val cbor = Cbor { } + val input1 = 123 + val input2 = "Hello" + val requestBytes = cbor.encodeToByteArray( + MalformedPostTestRequest.serializer(), + MalformedPostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("Unexpected EOF: expected 109 more bytes; consumed: 14", body.message) + } + + @Test + fun `checks route not found`() { + val cbor = Cbor { } + val requestBytes = ByteArray(0) + val response = sendRequest( + "$baseUrl/does-not-exist", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(404, response.statusCode(), "Expected 404") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(404, body.code) + assertEquals("Resource not found", body.message) + } + + @Test + fun `checks method not allowed`() { + val cbor = Cbor { } + val input1 = 123 + val input2 = "Hello" + val requestBytes = cbor.encodeToByteArray( + MalformedPostTestRequest.serializer(), + MalformedPostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "PUT", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(405, response.statusCode(), "Expected 405") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(405, body.code) + assertEquals("Method not allowed for this resource", body.message) + } + + @Test + fun `checks request body limit`() { + val cbor = Cbor { } + val overLimitPayload = "x".repeat(requestBodyLimit.toInt() + 1) + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(overLimitPayload, input2), + ) + require(requestBytes.size > 10 * 1024 * 1024) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(413, response.statusCode(), "Expected 413") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(413, body.code) + assertEquals("Request is larger than the limit of 10485760 bytes", body.message) + } + + @Test + fun `checks http error`() { + val cbor = Cbor { } + + val response = sendRequest( + "$baseUrl/http-error", + "POST", + null, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + + assertEquals(456, response.statusCode(), "Expected 456") + val body = cbor.decodeFromByteArray( + HttpError.serializer(), + response.body(), + ) + + assertEquals(444, body.num) + assertEquals("this is an error message", body.msg) + } +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/JsonServiceTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/JsonServiceTest.kt new file mode 100644 index 0000000000..2df32dfc30 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/JsonServiceTest.kt @@ -0,0 +1,221 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import kotlinx.serialization.builtins.serializer +import kotlinx.serialization.json.Json +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import java.net.ServerSocket +import java.net.http.HttpResponse +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class JsonServiceTest { + val closeGracePeriodMillis = TestParams.CLOSE_GRACE_PERIOD_MILLIS + val closeTimeoutMillis = TestParams.CLOSE_TIMEOUT_MILLIS + val gracefulWindow = TestParams.GRACEFUL_WINDOW + val requestBodyLimit = TestParams.REQUEST_BODY_LIMIT + val portListenerTimeout = TestParams.PORT_LISTENER_TIMEOUT + + val port: Int = ServerSocket(0).use { it.localPort } + val baseUrl = "http://localhost:$port" + + val projectDir: Path = Paths.get("build/service-json-test") + + private lateinit var proc: Process + + @BeforeAll + fun boot() { + proc = startService("netty", port, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit, projectDir) + val ready = waitForPort(port, portListenerTimeout) + assertTrue(ready, "Service did not start within $portListenerTimeout s") + } + + @AfterAll + fun shutdown() = cleanupService(proc, gracefulWindow) + + @Test + fun `checks http-header`() { + val response = sendRequest( + "$baseUrl/http-header", + "POST", + null, + "application/json", + "application/json", + headers = mapOf("X-Request-Header" to "header", "X-Request-Headers-hhh" to "headers"), + ) + assertIs>(response) + + assertEquals(201, response.statusCode(), "Expected 201") + + assertEquals("headers", response.headers().firstValue("X-Response-Header").get()) + assertEquals("header", response.headers().firstValue("X-Response-Headers-hhh").get()) + } + + @Test + fun `checks http-label`() { + val response = sendRequest( + "$baseUrl/http-label/labelValue", + "GET", + null, + "application/json", + "application/json", + ) + assertIs>(response) + + assertEquals(200, response.statusCode(), "Expected 200") + val body = Json.decodeFromString( + HttpLabelTestOutputResponse.serializer(), + response.body(), + ) + assertEquals("labelValue", body.output) + } + + @Test + fun `checks http-query`() { + val response = sendRequest( + "$baseUrl/http-query?query=123&qqq=kotlin", + "DELETE", + null, + "application/json", + "application/json", + ) + assertIs>(response) + + assertEquals(200, response.statusCode(), "Expected 200") + val body = Json.decodeFromString( + HttpQueryTestOutputResponse.serializer(), + response.body(), + ) + assertEquals("123kotlin", body.output) + } + + @Test + fun `checks http-payload string`() { + val response = sendRequest( + "$baseUrl/http-payload/string", + "POST", + "This is the entire content", + "text/plain", + "text/plain", + ) + assertIs>(response) + + assertEquals(201, response.statusCode(), "Expected 201") + assertEquals("This is the entire content", response.body()) + } + + @Test + fun `checks http-payload structure`() { + val requestJson = Json.encodeToJsonElement( + HttpStructurePayloadTestStructure.serializer(), + HttpStructurePayloadTestStructure( + "content", + 123, + 456.toFloat(), + ), + ) + + val response = sendRequest( + "$baseUrl/http-payload/structure", + "POST", + requestJson, + "application/json", + "application/json", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + val body = Json.decodeFromString( + HttpStructurePayloadTestStructure.serializer(), + response.body(), + ) + assertEquals("content", body.content1) + assertEquals(123, body.content2) + assertEquals(456.toFloat(), body.content3) + } + + @Test + fun `checks timestamp`() { + val requestJson = Json.encodeToJsonElement( + TimestampTestRequestResponse.serializer(), + TimestampTestRequestResponse( + 1515531081.123, + "1985-04-12T23:20:50.520Z", + "Tue, 29 Apr 2014 18:30:38 GMT", + 1234567890.123, + ), + ) + + val response = sendRequest( + "$baseUrl/timestamp", + "POST", + requestJson, + "application/json", + "application/json", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + val body = Json.decodeFromString( + TimestampTestRequestResponse.serializer(), + response.body(), + ) + assertEquals(1515531081.123, body.default) + assertEquals("1985-04-12T23:20:50.520Z", body.dateTime) + assertEquals("Tue, 29 Apr 2014 18:30:38 GMT", body.httpDate) + assertEquals(1234567890.123, body.epochSeconds) + } + + @Test + fun `checks json name`() { + val requestJson = Json.encodeToJsonElement( + JsonNameTestRequest.serializer(), + JsonNameTestRequest("Hello Kotlin Team"), + ) + + val response = sendRequest( + "$baseUrl/json-name", + "POST", + requestJson, + "application/json", + "application/json", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + val body = Json.decodeFromString( + JsonNameTestResponse.serializer(), + response.body(), + ) + assertEquals("Hello Kotlin Team", body.responseName) + } + + @Test + fun `checks http error`() { + val response = sendRequest( + "$baseUrl/http-error", + "POST", + null, + "application/json", + "application/json", + ) + assertIs>(response) + + assertEquals(456, response.statusCode(), "Expected 456") + val body = Json.decodeFromString( + HttpError.serializer(), + response.body(), + ) + + assertEquals(444, body.num) + assertEquals("this is an error message", body.msg) + } +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceConstraintsTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceConstraintsTest.kt new file mode 100644 index 0000000000..0f5c9cff1c --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceConstraintsTest.kt @@ -0,0 +1,552 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import kotlinx.serialization.cbor.Cbor +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import java.net.ServerSocket +import java.net.http.HttpResponse +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue + +/* Tests for checking constraint traits work */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class ServiceConstraintsTest { + val closeGracePeriodMillis = TestParams.CLOSE_GRACE_PERIOD_MILLIS + val closeTimeoutMillis = TestParams.CLOSE_TIMEOUT_MILLIS + val gracefulWindow = TestParams.GRACEFUL_WINDOW + val requestBodyLimit = TestParams.REQUEST_BODY_LIMIT + val portListenerTimeout = TestParams.PORT_LISTENER_TIMEOUT + + val port: Int = ServerSocket(0).use { it.localPort } + val baseUrl = "http://localhost:$port" + + val projectDir: Path = Paths.get("build/service-constraints-test") + + private lateinit var proc: Process + + @BeforeAll + fun boot() { + proc = startService("netty", port, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit, projectDir) + val ready = waitForPort(port, portListenerTimeout) + assertTrue(ready, "Service did not start within $portListenerTimeout s") + } + + @AfterAll + fun shutdown() = cleanupService(proc, gracefulWindow) + + @Test + fun `checks required constraint providing all data`() { + val cbor = Cbor { } + val requiredInput = "Hello" + val notRequiredInput = "World" + val requestBytes = cbor.encodeToByteArray( + RequiredConstraintTestRequest.serializer(), + RequiredConstraintTestRequest(requiredInput, notRequiredInput), + ) + + val response = sendRequest( + "$baseUrl/required-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks required constraint without providing non-required data`() { + val cbor = Cbor { } + val requiredInput = "Hello" + val requestBytes = cbor.encodeToByteArray( + RequiredConstraintTestRequest.serializer(), + RequiredConstraintTestRequest(requiredInput, null), + ) + + val response = sendRequest( + "$baseUrl/required-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks required constraint without providing required data`() { + val cbor = Cbor { } + val nonRequiredInput = "World" + val requestBytes = cbor.encodeToByteArray( + RequiredConstraintTestRequest.serializer(), + RequiredConstraintTestRequest(null, nonRequiredInput), + ) + + val response = sendRequest( + "$baseUrl/required-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("`requiredInput` must be provided", body.message) + } + + @Test + fun `checks length constraint providing correct data`() { + val cbor = Cbor { } + val greaterLengthInput = "1234567890" + val smallerLengthInput = listOf("1", "2", "3") + val betweenLengthInput = mapOf("1" to "2", "3" to "4") + + val requestBytes = cbor.encodeToByteArray( + LengthConstraintTestRequest.serializer(), + LengthConstraintTestRequest(greaterLengthInput, smallerLengthInput, betweenLengthInput), + ) + + val response = sendRequest( + "$baseUrl/length-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks length constraint violating greater than or equal to`() { + val cbor = Cbor { } + val greaterLengthInput = "1" + val smallerLengthInput = listOf("1", "2", "3") + val betweenLengthInput = mapOf("1" to "2", "3" to "4") + + val requestBytes = cbor.encodeToByteArray( + LengthConstraintTestRequest.serializer(), + LengthConstraintTestRequest(greaterLengthInput, smallerLengthInput, betweenLengthInput), + ) + + val response = sendRequest( + "$baseUrl/length-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("The size of `greaterLengthInput` must be greater than or equal to 3", body.message) + } + + @Test + fun `checks length constraint violating smaller than or equal to`() { + val cbor = Cbor { } + val greaterLengthInput = "123456789" + val smallerLengthInput = listOf("1", "2", "3", "4", "5", "6") + val betweenLengthInput = mapOf("1" to "2", "3" to "4") + + val requestBytes = cbor.encodeToByteArray( + LengthConstraintTestRequest.serializer(), + LengthConstraintTestRequest(greaterLengthInput, smallerLengthInput, betweenLengthInput), + ) + + val response = sendRequest( + "$baseUrl/length-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("The size of `smallerLengthInput` must be less than or equal to 3", body.message) + } + + @Test + fun `checks length constraint violating between`() { + val cbor = Cbor { } + val greaterLengthInput = "123456789" + val smallerLengthInput = listOf("1", "2") + val betweenLengthInput = mapOf("1" to "2", "3" to "4", "5" to "6", "7" to "8") + + val requestBytes = cbor.encodeToByteArray( + LengthConstraintTestRequest.serializer(), + LengthConstraintTestRequest(greaterLengthInput, smallerLengthInput, betweenLengthInput), + ) + + val response = sendRequest( + "$baseUrl/length-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("The size of `betweenLengthInput` must be between 1 and 2 (inclusive)", body.message) + } + + @Test + fun `checks pattern constraint providing correct data`() { + val cbor = Cbor { } + val patternInput1 = "qwertyuiop" + val patternInput2 = "qwe123rty" + + val requestBytes = cbor.encodeToByteArray( + PatternConstraintTestRequest.serializer(), + PatternConstraintTestRequest(patternInput1, patternInput2), + ) + + val response = sendRequest( + "$baseUrl/pattern-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks pattern constraint providing incorrect pattern 1`() { + val cbor = Cbor { } + val patternInput1 = "qwertyuiop1" + val patternInput2 = "qwe123rty" + + val requestBytes = cbor.encodeToByteArray( + PatternConstraintTestRequest.serializer(), + PatternConstraintTestRequest(patternInput1, patternInput2), + ) + + val response = sendRequest( + "$baseUrl/pattern-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("Value `qwertyuiop1` does not match required pattern: `^[A-Za-z]+\$`", body.message) + } + + @Test + fun `checks pattern constraint providing incorrect pattern 2`() { + val cbor = Cbor { } + val patternInput1 = "qwertyuiop" + val patternInput2 = "qwerty" + + val requestBytes = cbor.encodeToByteArray( + PatternConstraintTestRequest.serializer(), + PatternConstraintTestRequest(patternInput1, patternInput2), + ) + + val response = sendRequest( + "$baseUrl/pattern-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("Value `qwerty` does not match required pattern: `[1-9]+`", body.message) + } + + @Test + fun `checks range constraint providing correct data`() { + val cbor = Cbor { } + val betweenInput = 3 + val greaterInput = (-1).toDouble() + val smallerInput = 8.toFloat() + + val requestBytes = cbor.encodeToByteArray( + RangeConstraintTestRequest.serializer(), + RangeConstraintTestRequest(betweenInput, greaterInput, smallerInput), + ) + + val response = sendRequest( + "$baseUrl/range-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks range constraint violating greater than or equal to`() { + val cbor = Cbor { } + val betweenInput = 3 + val greaterInput = (-100).toDouble() + val smallerInput = 8.toFloat() + + val requestBytes = cbor.encodeToByteArray( + RangeConstraintTestRequest.serializer(), + RangeConstraintTestRequest(betweenInput, greaterInput, smallerInput), + ) + + val response = sendRequest( + "$baseUrl/range-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("`greaterInput` must be greater than or equal to -10", body.message) + } + + @Test + fun `checks range constraint violating smaller than or equal to`() { + val cbor = Cbor { } + val betweenInput = 3 + val greaterInput = (-1).toDouble() + val smallerInput = 10.toFloat() + + val requestBytes = cbor.encodeToByteArray( + RangeConstraintTestRequest.serializer(), + RangeConstraintTestRequest(betweenInput, greaterInput, smallerInput), + ) + + val response = sendRequest( + "$baseUrl/range-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("`smallerInput` must be less than or equal to 9", body.message) + } + + @Test + fun `checks range constraint violating between`() { + val cbor = Cbor { } + val betweenInput = -1 + val greaterInput = (-1).toDouble() + val smallerInput = 8.toFloat() + + val requestBytes = cbor.encodeToByteArray( + RangeConstraintTestRequest.serializer(), + RangeConstraintTestRequest(betweenInput, greaterInput, smallerInput), + ) + + val response = sendRequest( + "$baseUrl/range-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("`betweenInput` must be between 0 and 5 (inclusive)", body.message) + } + + @Test + fun `checks unique items constraint providing correct data`() { + val cbor = Cbor { } + val notUniqueInput = listOf("1", "2", "3", "4", "5", "1", "2", "3", "3", "4", "5") + val uniqueInput = listOf("1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11") + + val requestBytes = cbor.encodeToByteArray( + UniqueItemsConstraintTestRequest.serializer(), + UniqueItemsConstraintTestRequest(notUniqueInput, uniqueInput), + ) + + val response = sendRequest( + "$baseUrl/unique-items-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks unique items constraint providing non unique list`() { + val cbor = Cbor { } + val notUniqueInput = listOf("1", "2", "3", "4", "5", "1", "2", "3", "3", "4", "5") + val uniqueInput = listOf("1", "2", "3", "4", "5", "1", "2", "3", "3", "4", "5") + + val requestBytes = cbor.encodeToByteArray( + UniqueItemsConstraintTestRequest.serializer(), + UniqueItemsConstraintTestRequest(notUniqueInput, uniqueInput), + ) + + val response = sendRequest( + "$baseUrl/unique-items-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("`uniqueItemsListInput` must contain only unique items, duplicate values are not allowed", body.message) + } + + @Test + fun `checks unique items constraint providing unique nested list`() { + val cbor = Cbor { } + val nestedUniqueItemsListInput = listOf(listOf("1"), listOf("2", "3"), listOf("4"), listOf("5", "6", "7")) + + val requestBytes = cbor.encodeToByteArray( + NestedUniqueItemsConstraintTestRequest.serializer(), + NestedUniqueItemsConstraintTestRequest(nestedUniqueItemsListInput), + ) + + val response = sendRequest( + "$baseUrl/nested-unique-items-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks unique items constraint providing non-unique nested list`() { + val cbor = Cbor { } + val nestedUniqueItemsListInput = listOf(listOf("1"), listOf("2", "2"), listOf("4"), listOf("5", "6", "7")) + + val requestBytes = cbor.encodeToByteArray( + NestedUniqueItemsConstraintTestRequest.serializer(), + NestedUniqueItemsConstraintTestRequest(nestedUniqueItemsListInput), + ) + + val response = sendRequest( + "$baseUrl/nested-unique-items-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("`member` must contain only unique items, duplicate values are not allowed", body.message) + } + + @Test + fun `checks unique items constraint providing non-unique nested nested list`() { + val cbor = Cbor { } + val doubleNestedUniqueItemsListInput = listOf( + listOf(listOf("0"), listOf("1", "2"), listOf("6"), listOf("9", "10", "11")), + listOf(listOf("2"), listOf("7", "2"), listOf("4"), listOf("5", "6", "5")), + listOf(listOf("1"), listOf("1", "2"), listOf("4"), listOf("5", "6", "7")), + ) + + val requestBytes = cbor.encodeToByteArray( + DoubleNestedUniqueItemsConstraintTestRequest.serializer(), + DoubleNestedUniqueItemsConstraintTestRequest(doubleNestedUniqueItemsListInput), + ) + + val response = sendRequest( + "$baseUrl/double-nested-unique-items-constraint", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + + val body = cbor.decodeFromByteArray( + ErrorResponse.serializer(), + response.body(), + ) + assertEquals(400, body.code) + assertEquals("`member` must contain only unique items, duplicate values are not allowed", body.message) + } +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceDataClasses.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceDataClasses.kt new file mode 100644 index 0000000000..6ea3bb88ac --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceDataClasses.kt @@ -0,0 +1,76 @@ +package com.test + +import kotlinx.serialization.Serializable + +@Serializable +data class ErrorResponse(val code: Int, val message: String) + +@Serializable +data class MalformedPostTestRequest(val input1: Int, val input2: String) + +@Serializable +data class PostTestRequest(val input1: String, val input2: Int) + +@Serializable +data class PostTestResponse(val output1: String? = null, val output2: Int? = null) + +@Serializable +data class AuthTestRequest(val input1: String) + +@Serializable +data class ErrorTestRequest(val input1: String) + +@Serializable +data class HttpError(val msg: String, val num: Int) + +@Serializable +data class RequiredConstraintTestRequest(val requiredInput: String? = null, val notRequiredInput: String? = null) + +@Serializable +data class LengthConstraintTestRequest( + val greaterLengthInput: String, + val smallerLengthInput: List, + val betweenLengthInput: Map, +) + +@Serializable +data class PatternConstraintTestRequest(val patternInput1: String, val patternInput2: String) + +@Serializable +data class RangeConstraintTestRequest(val betweenInput: Int, val greaterInput: Double, val smallerInput: Float) + +@Serializable +data class UniqueItemsConstraintTestRequest(val notUniqueItemsListInput: List, val uniqueItemsListInput: List) + +@Serializable +data class NestedUniqueItemsConstraintTestRequest(val nestedUniqueItemsListInput: List>) + +@Serializable +data class DoubleNestedUniqueItemsConstraintTestRequest(val doubleNestedUniqueItemsListInput: List>>) + +@Serializable +data class HttpLabelTestOutputResponse(val output: String) + +@Serializable +data class HttpQueryTestOutputResponse(val output: String) + +@Serializable +data class HttpStructurePayloadTestStructure( + val content1: String, + val content2: Int, + val content3: Float, +) + +@Serializable +data class TimestampTestRequestResponse( + val default: Double, + val dateTime: String, + val httpDate: String, + val epochSeconds: Double, +) + +@Serializable +data class JsonNameTestRequest(val requestName: String) + +@Serializable +data class JsonNameTestResponse(val responseName: String) diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceEngineFactoryTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceEngineFactoryTest.kt new file mode 100644 index 0000000000..560412f1f5 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceEngineFactoryTest.kt @@ -0,0 +1,51 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import org.junit.jupiter.api.TestInstance +import java.net.ServerSocket +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.test.Test +import kotlin.test.assertTrue + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class ServiceEngineFactoryTest { + val closeGracePeriodMillis = TestParams.CLOSE_GRACE_PERIOD_MILLIS + val closeTimeoutMillis = TestParams.CLOSE_TIMEOUT_MILLIS + val gracefulWindow = TestParams.GRACEFUL_WINDOW + val requestBodyLimit = TestParams.REQUEST_BODY_LIMIT + val portListenerTimeout = TestParams.PORT_LISTENER_TIMEOUT + + val projectDir: Path = Paths.get("build/service-cbor-test") + + @Test + fun `checks service with netty engine`() { + val nettyPort: Int = ServerSocket(0).use { it.localPort } + val nettyProc = startService("netty", nettyPort, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit, projectDir) + val ready = waitForPort(nettyPort, portListenerTimeout) + assertTrue(ready, "Service did not start within $portListenerTimeout s") + cleanupService(nettyProc, gracefulWindow) + } + + @Test + fun `checks service with cio engine`() { + val cioPort: Int = ServerSocket(0).use { it.localPort } + val cioProc = startService("cio", cioPort, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit, projectDir) + val ready = waitForPort(cioPort, portListenerTimeout) + assertTrue(ready, "Service did not start within $portListenerTimeout s") + cleanupService(cioProc, gracefulWindow) + } + + @Test + fun `checks service with jetty jakarta engine`() { + val jettyPort: Int = ServerSocket(0).use { it.localPort } + val jettyProc = startService("jetty-jakarta", jettyPort, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit, projectDir) + val ready = waitForPort(jettyPort, portListenerTimeout) + assertTrue(ready, "Service did not start within $portListenerTimeout s") + cleanupService(jettyProc, gracefulWindow) + } +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt new file mode 100644 index 0000000000..ffa3a5275b --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt @@ -0,0 +1,40 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.io.path.exists +import kotlin.test.Test +import kotlin.test.assertTrue + +class ServiceFileTest { + val packageName = "com.cbor" + val packagePath = packageName.replace('.', '/') + + val projectDir: Path = Paths.get("build/service-cbor-test") + + @Test + fun `generates service and all necessary files`() { + assertTrue(projectDir.resolve("build.gradle.kts").exists()) + assertTrue(projectDir.resolve("settings.gradle.kts").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/Main.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/Routing.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/config/ServiceFrameworkConfig.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/framework/ServiceFramework.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/plugins/ContentTypeGuard.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/plugins/ErrorHandler.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/utils/Logging.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/auth/Authentication.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/auth/Validation.kt").exists()) + + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/model/PostTestRequest.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/model/PostTestResponse.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/serde/PostTestOperationSerializer.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/serde/PostTestOperationDeserializer.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/operations/PostTestOperation.kt").exists()) + } +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/utils.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/utils.kt new file mode 100644 index 0000000000..d4027c18f3 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/utils.kt @@ -0,0 +1,248 @@ +package com.test + +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials +import aws.smithy.kotlin.runtime.auth.awssigning.AwsSignatureType +import aws.smithy.kotlin.runtime.auth.awssigning.AwsSignedBodyHeader +import aws.smithy.kotlin.runtime.auth.awssigning.AwsSigningAlgorithm +import aws.smithy.kotlin.runtime.auth.awssigning.AwsSigningAttributes +import aws.smithy.kotlin.runtime.auth.awssigning.DefaultAwsSigner +import aws.smithy.kotlin.runtime.auth.awssigning.HashSpecification +import aws.smithy.kotlin.runtime.collections.attributesOf +import aws.smithy.kotlin.runtime.http.HttpBody +import aws.smithy.kotlin.runtime.http.HttpMethod +import aws.smithy.kotlin.runtime.http.auth.AwsHttpSigner +import aws.smithy.kotlin.runtime.http.auth.SignHttpRequest +import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder +import aws.smithy.kotlin.runtime.http.request.headers +import aws.smithy.kotlin.runtime.http.request.url +import aws.smithy.kotlin.runtime.net.url.Url +import aws.smithy.kotlin.runtime.time.Instant +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonElement +import org.gradle.testkit.runner.GradleRunner +import java.io.IOException +import java.net.Socket +import java.net.URI +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import java.nio.file.Files +import java.nio.file.Path +import java.util.Locale +import java.util.concurrent.TimeUnit +import kotlin.test.assertTrue +import kotlin.test.fail + +internal object TestParams { + const val CLOSE_GRACE_PERIOD_MILLIS: Long = 5_000L + const val CLOSE_TIMEOUT_MILLIS: Long = 1_000L + const val GRACEFUL_WINDOW: Long = CLOSE_GRACE_PERIOD_MILLIS + CLOSE_TIMEOUT_MILLIS + const val REQUEST_BODY_LIMIT: Long = 10L * 1024 * 1024 + const val PORT_LISTENER_TIMEOUT: Long = 60L +} + +internal fun startService( + engineFactory: String = "netty", + port: Int = 8080, + closeGracePeriodMillis: Long = 1000, + closeTimeoutMillis: Long = 1000, + requestBodyLimit: Long = 10L * 1024 * 1024, + projectDir: Path, +): Process { + if (!Files.exists(projectDir.resolve("gradlew"))) { + GradleRunner.create() + .withProjectDir(projectDir.toFile()) + .withArguments( + "wrapper", + "--quiet", + ) + .build() + } + + val gradleCmd = if (isWindows()) "gradlew.bat" else "./gradlew" + val baseCmd = if (isWindows()) listOf("cmd", "/c", gradleCmd) else listOf(gradleCmd) + + return ProcessBuilder( + baseCmd + listOf( + "--no-daemon", + "--quiet", + "run", + "--args=--engineFactory $engineFactory " + + "--port $port " + + "--closeGracePeriodMillis ${closeGracePeriodMillis.toInt()} " + + "--closeTimeoutMillis ${closeTimeoutMillis.toInt()} " + + "--requestBodyLimit $requestBodyLimit", + ), + ) + .directory(projectDir.toFile()) + .redirectErrorStream(true) + .start() +} + +internal fun cleanupService(proc: Process, gracefulWindow: Long = 5_000L) { + val okExitCodes = if (isWindows()) { + setOf(0, 1, 143, -1, -1073741510) + } else { + setOf(0, 143) + } + + try { + proc.destroy() + val exited = proc.waitFor(gracefulWindow, TimeUnit.MILLISECONDS) + + if (!exited) { + proc.destroyForcibly() + fail("Service did not shut down within $gracefulWindow ms") + } + + assertTrue( + proc.exitValue() in okExitCodes, + "Service exited with ${proc.exitValue()} – shutdown not graceful?", + ) + } catch (e: Exception) { + proc.destroyForcibly() + throw e + } +} + +private fun isWindows() = System.getProperty("os.name").lowercase().contains("windows") + +internal fun waitForPort(port: Int, timeoutSec: Long = 180): Boolean { + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toNanos(timeoutSec) + while (System.currentTimeMillis() < deadline) { + try { + Socket("localhost", port).use { + return true // Port is available + } + } catch (e: IOException) { + Thread.sleep(100) + } + } + return false +} + +data class AwsSigningOptions( + val credentials: Credentials, + val service: String, // e.g., "execute-api", "s3", "es", "kinesis" + val region: String?, // e.g., "us-west-2" (null when using SigV4A) + val algorithm: AwsSigningAlgorithm, // AwsSigningAlgorithm.SIGV4 or SIGV4A + val signatureType: AwsSignatureType = AwsSignatureType.HTTP_REQUEST_VIA_HEADERS, + val signedBodyHeader: AwsSignedBodyHeader = AwsSignedBodyHeader.NONE, + val useDoubleUriEncode: Boolean = true, + val normalizeUriPath: Boolean = true, +) + +@OptIn(InternalApi::class) +internal fun sendRequest( + url: String, + method: String, + data: Any? = null, + contentType: String? = null, + acceptType: String? = null, + bearerToken: String? = null, + headers: Map = emptyMap(), + awsSigning: AwsSigningOptions? = null, +): HttpResponse<*> { + require(!(awsSigning != null && bearerToken != null)) { + "Cannot use bearerToken and awsSigning together." + } + val httpMethod = method.uppercase(Locale.ROOT) + require(!(httpMethod in setOf("GET", "HEAD") && data != null)) { + "GET/HEAD with a body is not supported." + } + + val client = HttpClient.newHttpClient() + + val uri = URI.create(url) + val defaultPort = if (uri.scheme.equals("https", true)) 443 else 80 + val hostHeader = buildString { + append(uri.host) + if (uri.port != -1 && uri.port != defaultPort) append(":${uri.port}") + } + val baseHeaders = linkedMapOf().apply { + put("Host", hostHeader) + contentType?.let { put("Content-Type", it) } + acceptType?.let { put("Accept", it) } + putAll(headers) + bearerToken?.let { put("Authorization", "Bearer $it") } + } + + val bodyPublisher = when (data) { + null -> HttpRequest.BodyPublishers.noBody() + is ByteArray -> HttpRequest.BodyPublishers.ofByteArray(data) + is String -> HttpRequest.BodyPublishers.ofString(data) + is JsonElement -> HttpRequest.BodyPublishers.ofString(data.toString()) + else -> error("Unsupported body type: ${data::class.qualifiedName}") + } + + val signedHeaders = if (awsSigning != null) { + requireNotNull(awsSigning.region) { "awsSigning.region is required." } + + val unsigned = awsSigning.signedBodyHeader == AwsSignedBodyHeader.NONE + val signer = AwsHttpSigner( + AwsHttpSigner.Config().apply { + this.signer = DefaultAwsSigner + this.service = awsSigning.service + this.isUnsignedPayload = unsigned + this.algorithm = awsSigning.algorithm + }, + ) + + val ktorReq = HttpRequestBuilder().apply { + this.method = HttpMethod.parse(httpMethod) + url(Url.parse(url)) + headers { + baseHeaders.forEach { (k, v) -> append(k, v) } + } + if (!unsigned) { + body = when { + data is ByteArray -> HttpBody.fromBytes(data) + data is String -> HttpBody.fromBytes(data.toByteArray()) + data == null -> HttpBody.Empty + else -> error("Unsupported body type: ${data::class.qualifiedName}") + } + } + } + + val attrs = attributesOf { + Credentials to awsSigning.credentials + AwsSigningAttributes.SigningRegion to awsSigning.region + AwsSigningAttributes.SigningDate to Instant.now() + if (unsigned) { + AwsSigningAttributes.HashSpecification to HashSpecification.UnsignedPayload + AwsSigningAttributes.SignedBodyHeader to AwsSignedBodyHeader.X_AMZ_CONTENT_SHA256 + } else { + AwsSigningAttributes.HashSpecification to HashSpecification.CalculateFromPayload + AwsSigningAttributes.SignedBodyHeader to AwsSignedBodyHeader.X_AMZ_CONTENT_SHA256 + } + } + + runBlocking { + signer.sign(SignHttpRequest(ktorReq, awsSigning.credentials, attrs)) + } + + ktorReq.headers.build().entries().associate { (k, vs) -> k to vs.joinToString(",") } + } else { + baseHeaders + } + + val builder = HttpRequest.newBuilder().uri(uri) + signedHeaders.forEach { (k, v) -> + if (!k.equals("Host", ignoreCase = true)) { + builder.header(k, v) + } + } + val request = builder.method(httpMethod, bodyPublisher).build() + + val bodyHandler = + if (acceptType?.contains("json", true) == true || + acceptType?.startsWith("text", true) == true + ) { + HttpResponse.BodyHandlers.ofString() + } else { + HttpResponse.BodyHandlers.ofByteArray() + } + + return client.send(request, bodyHandler) +} diff --git a/tests/codegen/waiter-tests/build.gradle.kts b/tests/codegen/waiter-tests/build.gradle.kts index 72b244a6c9..992a4d471e 100644 --- a/tests/codegen/waiter-tests/build.gradle.kts +++ b/tests/codegen/waiter-tests/build.gradle.kts @@ -38,7 +38,7 @@ kotlin.sourceSets.getByName("main") { tasks.withType { dependsOn(tasks.generateSmithyProjections) - kotlinOptions { + compilerOptions { // generated code has warnings unfortunately allWarningsAsErrors = false }