Skip to content

Commit 8ce9b6f

Browse files
authored
feat: continue header (#802)
1 parent 9305116 commit 8ce9b6f

File tree

5 files changed

+117
-2
lines changed

5 files changed

+117
-2
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "8f1ff251-7fed-4543-a077-51b2fe0aa684",
3+
"type": "feature",
4+
"description": "Add an interceptor for adding `Expect: 100-continue` headers to HTTP requests",
5+
"issues": [
6+
"awslabs/aws-sdk-kotlin#839"
7+
]
8+
}

codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/ModelTestUtils.kt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin
1010
import software.amazon.smithy.kotlin.codegen.KotlinSettings
1111
import software.amazon.smithy.kotlin.codegen.core.*
1212
import software.amazon.smithy.kotlin.codegen.inferService
13+
import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
1314
import software.amazon.smithy.kotlin.codegen.model.OperationNormalizer
1415
import software.amazon.smithy.kotlin.codegen.model.shapes
1516
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
@@ -116,6 +117,7 @@ fun Model.newTestContext(
116117
packageName: String = TestModelDefault.NAMESPACE,
117118
settings: KotlinSettings = this.defaultSettings(serviceName, packageName),
118119
generator: ProtocolGenerator = MockHttpProtocolGenerator(),
120+
integrations: List<KotlinIntegration> = listOf(),
119121
): TestContext {
120122
val manifest = MockManifest()
121123
val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model = this, rootNamespace = packageName, serviceName = serviceName)
@@ -127,7 +129,7 @@ fun Model.newTestContext(
127129
this,
128130
service,
129131
provider,
130-
listOf(),
132+
integrations,
131133
generator.protocol,
132134
delegator,
133135
)
@@ -238,14 +240,18 @@ fun String.prependNamespaceAndService(
238240
protocol.annotation to imports + listOf(protocol.import)
239241
}
240242

241-
val importExpr = modelImports.map { "use $it" }.joinToString(separator = "\n") { it }
243+
val importExpr = modelImports
244+
.map { "use $it" }
245+
.plus("use aws.api#service")
246+
.joinToString(separator = "\n")
242247

243248
return (
244249
"""
245250
$versionExpr
246251
namespace $namespace
247252
$importExpr
248253
$modelProtocol
254+
@service(sdkId: "$serviceName")
249255
service $serviceName {
250256
version: "${TestModelDefault.MODEL_VERSION}",
251257
operations: $operations

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ object RuntimeTypes {
9292
}
9393

9494
object Interceptors : RuntimeTypePackage(KotlinDependency.HTTP, "interceptors") {
95+
val ContinueInterceptor = symbol("ContinueInterceptor")
9596
val HttpInterceptor = symbol("HttpInterceptor")
9697
val Md5ChecksumInterceptor = symbol("Md5ChecksumInterceptor")
9798
val FlexibleChecksumsRequestInterceptor = symbol("FlexibleChecksumsRequestInterceptor")
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.smithy.kotlin.runtime.http.interceptors
6+
7+
import aws.smithy.kotlin.runtime.InternalApi
8+
import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext
9+
import aws.smithy.kotlin.runtime.http.request.HttpRequest
10+
import aws.smithy.kotlin.runtime.http.request.header
11+
import aws.smithy.kotlin.runtime.http.request.toBuilder
12+
13+
/**
14+
* An interceptor that adds an HTTP `Expect: 100-continue` header to requests with bodies at a certain length threshold.
15+
* Bodies with an unset `contentLength` will get the continue header added regardless of length.
16+
* @param thresholdLengthBytes The body length (in bytes) at which a continue header will be set. Bodies under this
17+
* length will not get a continue header.
18+
*/
19+
@InternalApi
20+
public class ContinueInterceptor(public val thresholdLengthBytes: Long) : HttpInterceptor {
21+
override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext<Any, HttpRequest>): HttpRequest {
22+
val req = context.protocolRequest
23+
24+
return if ((req.body.contentLength ?: Long.MAX_VALUE) >= thresholdLengthBytes) {
25+
req
26+
.toBuilder()
27+
.apply { header("Expect", "100-continue") }
28+
.build()
29+
} else {
30+
req
31+
}
32+
}
33+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.smithy.kotlin.runtime.http.interceptors
6+
7+
import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext
8+
import aws.smithy.kotlin.runtime.http.HttpBody
9+
import aws.smithy.kotlin.runtime.http.HttpMethod
10+
import aws.smithy.kotlin.runtime.http.request.HttpRequest
11+
import aws.smithy.kotlin.runtime.http.request.header
12+
import aws.smithy.kotlin.runtime.http.request.url
13+
import aws.smithy.kotlin.runtime.io.SdkSource
14+
import aws.smithy.kotlin.runtime.net.Url
15+
import aws.smithy.kotlin.runtime.operation.ExecutionContext
16+
import kotlinx.coroutines.ExperimentalCoroutinesApi
17+
import kotlinx.coroutines.test.runTest
18+
import kotlin.test.Test
19+
import kotlin.test.assertEquals
20+
import kotlin.test.assertNull
21+
import kotlin.test.fail
22+
23+
@OptIn(ExperimentalCoroutinesApi::class)
24+
class ContinueInterceptorTest {
25+
private fun context(request: HttpRequest) = object : ProtocolRequestInterceptorContext<Any, HttpRequest> {
26+
override val protocolRequest: HttpRequest = request
27+
override val executionContext: ExecutionContext get() = fail("Shouldn't have invoked `executionContext`")
28+
override val request: Any get() = fail("Shouldn't have invoked `request`")
29+
}
30+
31+
private fun request(contentLength: Long?) = HttpRequest {
32+
method = HttpMethod.POST
33+
url(Url.parse("https://localhost"))
34+
header("foo", "bar")
35+
body = object : HttpBody.SourceContent() {
36+
override val contentLength: Long? = contentLength
37+
override fun readFrom(): SdkSource = fail("Shouldn't have invoked `readFrom`")
38+
}
39+
}
40+
41+
@Test
42+
fun testInterceptorSmallBody() = runTest {
43+
val input = request(50)
44+
val interceptor = ContinueInterceptor(100)
45+
val output = interceptor.modifyBeforeSigning(context(input))
46+
assertEquals("bar", output.headers["foo"])
47+
assertNull(output.headers["Expect"])
48+
}
49+
50+
@Test
51+
fun testInterceptorLargeBody() = runTest {
52+
val input = request(150)
53+
val interceptor = ContinueInterceptor(100)
54+
val output = interceptor.modifyBeforeSigning(context(input))
55+
assertEquals("bar", output.headers["foo"])
56+
assertEquals("100-continue", output.headers["Expect"])
57+
}
58+
59+
@Test
60+
fun testInterceptorUnknownLengthBody() = runTest {
61+
val input = request(null)
62+
val interceptor = ContinueInterceptor(100)
63+
val output = interceptor.modifyBeforeSigning(context(input))
64+
assertEquals("bar", output.headers["foo"])
65+
assertEquals("100-continue", output.headers["Expect"])
66+
}
67+
}

0 commit comments

Comments
 (0)