Skip to content

Commit aa714d9

Browse files
committed
fix: pr feedback v1
1 parent 9beca23 commit aa714d9

File tree

11 files changed

+162
-95
lines changed

11 files changed

+162
-95
lines changed

runtime/protocol/http-client/api/http-client.api

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksums
341341
public class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsResponseInterceptor : aws/smithy/kotlin/runtime/client/Interceptor {
342342
public static final field Companion Laws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsResponseInterceptor$Companion;
343343
public fun <init> (ZLaws/smithy/kotlin/runtime/client/config/ResponseHttpChecksumConfig;)V
344-
public fun ignoreChecksum (Ljava/lang/String;)Z
344+
public fun ignoreChecksum (Ljava/lang/String;Laws/smithy/kotlin/runtime/telemetry/logging/Logger;)Z
345345
public fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
346346
public fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
347347
public fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;

runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptor.kt

Lines changed: 3 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55

66
package aws.smithy.kotlin.runtime.http.interceptors
77

8-
import aws.smithy.kotlin.runtime.ClientException
98
import aws.smithy.kotlin.runtime.InternalApi
10-
import aws.smithy.kotlin.runtime.businessmetrics.BusinessMetric
11-
import aws.smithy.kotlin.runtime.businessmetrics.SmithyBusinessMetric
129
import aws.smithy.kotlin.runtime.businessmetrics.emitBusinessMetric
1310
import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext
1411
import aws.smithy.kotlin.runtime.client.config.RequestHttpChecksumConfig
@@ -84,7 +81,10 @@ public class FlexibleChecksumsRequestInterceptor(
8481

8582
request.headers.append("x-amz-trailer", checksumHeader)
8683
request.trailingHeaders.append(checksumHeader, deferredChecksum)
84+
8785
context.executionContext.emitBusinessMetric(checksumAlgorithm.toBusinessMetric())
86+
87+
return request.build()
8888
} else { // Delegate checksum calculation to super class, calculateChecksum, and applyChecksum
8989
return super.modifyBeforeSigning(context)
9090
}
@@ -157,25 +157,6 @@ public class FlexibleChecksumsRequestInterceptor(
157157
return request.build()
158158
}
159159

160-
// FIXME this duplicates the logic from aws-signing-common, but can't import from there due to circular import.
161-
private val HttpBody.isEligibleForAwsChunkedStreaming: Boolean
162-
get() = (this is HttpBody.SourceContent || this is HttpBody.ChannelContent) &&
163-
contentLength != null &&
164-
(isOneShot || contentLength!! > 65536 * 16)
165-
166-
/**
167-
* Compute the rolling hash of an [SdkByteReadChannel] using [hashFunction], reading up-to [bufferSize] bytes into memory
168-
* @return a ByteArray of the hash function's digest
169-
*/
170-
private suspend fun SdkByteReadChannel.rollingHash(hashFunction: HashFunction, bufferSize: Long = 8192): ByteArray {
171-
val buffer = SdkBuffer()
172-
while (!isClosedForRead) {
173-
read(buffer, bufferSize)
174-
hashFunction.update(buffer.readToByteArray())
175-
}
176-
return hashFunction.digest()
177-
}
178-
179160
/**
180161
* Checks if a user provided a checksum for a request via an HTTP header.
181162
* The header must start with "x-amz-checksum-" followed by the checksum algorithm's name.
@@ -192,17 +173,6 @@ public class FlexibleChecksumsRequestInterceptor(
192173
}
193174
}
194175

195-
/**
196-
* Maps supported hash functions to business metrics.
197-
*/
198-
private fun HashFunction.toBusinessMetric(): BusinessMetric = when (this) {
199-
is Crc32 -> SmithyBusinessMetric.FLEXIBLE_CHECKSUMS_REQ_CRC32
200-
is Crc32c -> SmithyBusinessMetric.FLEXIBLE_CHECKSUMS_REQ_CRC32C
201-
is Sha1 -> SmithyBusinessMetric.FLEXIBLE_CHECKSUMS_REQ_SHA1
202-
is Sha256 -> SmithyBusinessMetric.FLEXIBLE_CHECKSUMS_REQ_SHA256
203-
else -> throw IllegalStateException("Checksum was calculated using an unsupported hash function: ${this::class.simpleName}")
204-
}
205-
206176
/**
207177
* Removes all checksum headers except [headerName]
208178
* @param headerName the checksum header name to keep
@@ -211,44 +181,4 @@ public class FlexibleChecksumsRequestInterceptor(
211181
names()
212182
.filter { it.startsWith("x-amz-checksum-", ignoreCase = true) && !it.equals(headerName, ignoreCase = true) }
213183
.forEach { remove(it) }
214-
215-
/**
216-
* Convert an [HttpBody] with an underlying [HashingSource] or [HashingByteReadChannel]
217-
* to a [CompletingSource] or [CompletingByteReadChannel], respectively.
218-
*/
219-
private fun HttpBody.toCompletingBody(deferred: CompletableDeferred<String>): HttpBody = when (this) {
220-
is HttpBody.SourceContent -> CompletingSource(deferred, (readFrom() as HashingSource)).toHttpBody(contentLength)
221-
is HttpBody.ChannelContent -> CompletingByteReadChannel(deferred, (readFrom() as HashingByteReadChannel)).toHttpBody(contentLength)
222-
else -> throw ClientException("HttpBody type is not supported")
223-
}
224-
225-
/**
226-
* An [SdkSource] which uses the underlying [hashingSource]'s checksum to complete a [CompletableDeferred] value.
227-
*/
228-
internal class CompletingSource(
229-
private val deferred: CompletableDeferred<String>,
230-
private val hashingSource: HashingSource,
231-
) : SdkSource by hashingSource {
232-
override fun read(sink: SdkBuffer, limit: Long): Long = hashingSource.read(sink, limit)
233-
.also {
234-
if (it == -1L) {
235-
deferred.complete(hashingSource.digest().encodeBase64String())
236-
}
237-
}
238-
}
239-
240-
/**
241-
* An [SdkByteReadChannel] which uses the underlying [hashingChannel]'s checksum to complete a [CompletableDeferred] value.
242-
*/
243-
internal class CompletingByteReadChannel(
244-
private val deferred: CompletableDeferred<String>,
245-
private val hashingChannel: HashingByteReadChannel,
246-
) : SdkByteReadChannel by hashingChannel {
247-
override suspend fun read(sink: SdkBuffer, limit: Long): Long = hashingChannel.read(sink, limit)
248-
.also {
249-
if (it == -1L) {
250-
deferred.complete(hashingChannel.digest().encodeBase64String())
251-
}
252-
}
253-
}
254184
}

runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsResponseInterceptor.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import aws.smithy.kotlin.runtime.http.response.copy
1919
import aws.smithy.kotlin.runtime.http.toHashingBody
2020
import aws.smithy.kotlin.runtime.http.toHttpBody
2121
import aws.smithy.kotlin.runtime.io.*
22+
import aws.smithy.kotlin.runtime.telemetry.logging.Logger
2223
import aws.smithy.kotlin.runtime.telemetry.logging.logger
2324
import aws.smithy.kotlin.runtime.text.encoding.encodeBase64String
2425
import kotlin.coroutines.coroutineContext
@@ -55,20 +56,19 @@ public open class FlexibleChecksumsResponseInterceptor(
5556
}
5657

5758
override suspend fun modifyBeforeDeserialization(context: ProtocolResponseInterceptorContext<Any, HttpRequest, HttpResponse>): HttpResponse {
58-
val logger = coroutineContext.logger<FlexibleChecksumsResponseInterceptor>()
59-
6059
val configuredToVerifyChecksum = responseValidationRequired || responseChecksumValidation == ResponseHttpChecksumConfig.WHEN_SUPPORTED
6160
if (!configuredToVerifyChecksum) return context.protocolResponse
6261

62+
val logger = coroutineContext.logger<FlexibleChecksumsResponseInterceptor>()
63+
6364
val checksumHeader = CHECKSUM_HEADER_VALIDATION_PRIORITY_LIST
6465
.firstOrNull { context.protocolResponse.headers.contains(it) } ?: run {
6566
logger.warn { "Checksum validation was requested but the response headers didn't contain a valid checksum." }
6667
return context.protocolResponse
6768
}
6869

6970
val serviceChecksumValue = context.protocolResponse.headers[checksumHeader]!!
70-
if (ignoreChecksum(serviceChecksumValue)) {
71-
logger.info { "Checksum detected but validation was skipped." }
71+
if (ignoreChecksum(serviceChecksumValue, logger)) {
7272
return context.protocolResponse
7373
}
7474

@@ -110,7 +110,7 @@ public open class FlexibleChecksumsResponseInterceptor(
110110
/**
111111
* Additional check on the checksum itself to see if it should be validated
112112
*/
113-
public open fun ignoreChecksum(checksum: String): Boolean = false
113+
public open fun ignoreChecksum(checksum: String, logger: Logger): Boolean = false
114114
}
115115

116116
public class ChecksumMismatchException(message: String?) : ClientException(message)

runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/HttpChecksumRequiredInterceptor.kt

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,73 @@
66
package aws.smithy.kotlin.runtime.http.interceptors
77

88
import aws.smithy.kotlin.runtime.InternalApi
9+
import aws.smithy.kotlin.runtime.businessmetrics.emitBusinessMetric
910
import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext
1011
import aws.smithy.kotlin.runtime.hashing.*
1112
import aws.smithy.kotlin.runtime.http.*
1213
import aws.smithy.kotlin.runtime.http.request.HttpRequest
1314
import aws.smithy.kotlin.runtime.http.request.header
1415
import aws.smithy.kotlin.runtime.http.request.toBuilder
16+
import aws.smithy.kotlin.runtime.io.rollingHash
17+
import aws.smithy.kotlin.runtime.telemetry.logging.logger
1518
import aws.smithy.kotlin.runtime.text.encoding.encodeBase64String
19+
import kotlinx.coroutines.CompletableDeferred
20+
import kotlinx.coroutines.job
21+
import kotlin.coroutines.coroutineContext
1622

1723
/**
1824
* Handles checksum request calculation from the `httpChecksumRequired` trait.
1925
*/
2026
@InternalApi
2127
public class HttpChecksumRequiredInterceptor : AbstractChecksumInterceptor() {
22-
override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext<Any, HttpRequest>): HttpRequest =
28+
override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext<Any, HttpRequest>): HttpRequest {
2329
if (context.defaultChecksumAlgorithmName == null) {
2430
// Don't calculate checksum
25-
context.protocolRequest
26-
} else {
27-
// Delegate checksum calculation to super class, calculateChecksum, and applyChecksum
28-
super.modifyBeforeSigning(context)
31+
return context.protocolRequest
2932
}
3033

34+
val checksumAlgorithmName = context.defaultChecksumAlgorithmName!!
35+
val checksumAlgorithm = checksumAlgorithmName.toHashFunctionOrThrow()
36+
37+
val logger = coroutineContext.logger<HttpChecksumRequiredInterceptor>()
38+
39+
if (context.protocolRequest.body.isEligibleForAwsChunkedStreaming) { // Handle checksum calculation here
40+
logger.debug { "Calculating checksum during transmission using: ${checksumAlgorithm::class.simpleName}" }
41+
42+
val request = context.protocolRequest.toBuilder()
43+
val deferredChecksum = CompletableDeferred<String>(context.executionContext.coroutineContext.job)
44+
val checksumHeader = checksumAlgorithm.resolveChecksumAlgorithmHeaderName()
45+
46+
request.body = request.body
47+
.toHashingBody(checksumAlgorithm, request.body.contentLength)
48+
.toCompletingBody(deferredChecksum)
49+
50+
request.headers.append("x-amz-trailer", checksumHeader)
51+
request.trailingHeaders.append(checksumHeader, deferredChecksum)
52+
53+
context.executionContext.emitBusinessMetric(checksumAlgorithm.toBusinessMetric())
54+
55+
return request.build()
56+
} else { // Delegate checksum calculation to super class, calculateChecksum, and applyChecksum
57+
return super.modifyBeforeSigning(context)
58+
}
59+
}
60+
3161
public override suspend fun calculateChecksum(context: ProtocolRequestInterceptorContext<Any, HttpRequest>): String? {
62+
val req = context.protocolRequest.toBuilder()
3263
val checksumAlgorithmName = context.defaultChecksumAlgorithmName!!
3364
val checksumAlgorithm = checksumAlgorithmName.toHashFunctionOrThrow()
3465

35-
return when (val body = context.protocolRequest.body) {
36-
is HttpBody.Bytes -> {
37-
checksumAlgorithm.update(
38-
body.readAll() ?: byteArrayOf(),
39-
)
40-
checksumAlgorithm.digest().encodeBase64String()
66+
return when {
67+
req.body.contentLength == null && !req.body.isOneShot -> {
68+
val channel = req.body.toSdkByteReadChannel()!!
69+
channel.rollingHash(checksumAlgorithm).encodeBase64String()
70+
}
71+
else -> {
72+
val bodyBytes = req.body.readAll() ?: byteArrayOf()
73+
if (req.body.isOneShot) req.body = bodyBytes.toHttpBody()
74+
bodyBytes.hash(checksumAlgorithm).encodeBase64String()
4175
}
42-
else -> null // TODO: Support other body types
4376
}
4477
}
4578

runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptorTest.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class FlexibleChecksumsRequestInterceptorTest {
143143
val source = byteArray.source()
144144
val completableDeferred = CompletableDeferred<String>()
145145
val hashingSource = HashingSource(hashFunctionName.toHashFunction()!!, source)
146-
val completingSource = FlexibleChecksumsRequestInterceptor.CompletingSource(completableDeferred, hashingSource)
146+
val completingSource = CompletingSource(completableDeferred, hashingSource)
147147

148148
completingSource.read(SdkBuffer(), 1L)
149149
assertFalse(completableDeferred.isCompleted) // deferred value should not be completed because the source is not exhausted
@@ -165,7 +165,7 @@ class FlexibleChecksumsRequestInterceptorTest {
165165
val completableDeferred = CompletableDeferred<String>()
166166
val hashingChannel = HashingByteReadChannel(hashFunctionName.toHashFunction()!!, channel)
167167
val completingChannel =
168-
FlexibleChecksumsRequestInterceptor.CompletingByteReadChannel(completableDeferred, hashingChannel)
168+
CompletingByteReadChannel(completableDeferred, hashingChannel)
169169

170170
completingChannel.read(SdkBuffer(), 1L)
171171
assertFalse(completableDeferred.isCompleted)

runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/HttpChecksumRequiredInterceptorTest.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class HttpChecksumRequiredInterceptorTest {
6666
}
6767

6868
@Test
69-
fun itOnlySetsHeaderForBytesContent() = runTest {
69+
fun itSetsHeaderForNonBytesContent() = runTest {
7070
val req = HttpRequestBuilder().apply {
7171
body = object : HttpBody.ChannelContent() {
7272
override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel("fooey".encodeToByteArray())
@@ -79,9 +79,10 @@ class HttpChecksumRequiredInterceptorTest {
7979
HttpChecksumRequiredInterceptor(),
8080
)
8181

82+
val expected = "vJLiaOiNxaxdWfYAYzdzFQ=="
8283
op.roundTrip(client, Unit)
8384
val call = op.context.attributes[HttpOperationContext.HttpCallList].first()
84-
assertNull(call.request.headers["Content-MD5"])
85+
assertEquals(expected, call.request.headers["Content-MD5"])
8586
}
8687

8788
@Test

runtime/protocol/http/api/http.api

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
public final class aws/smithy/kotlin/runtime/http/CompletingByteReadChannel : aws/smithy/kotlin/runtime/io/SdkByteReadChannel {
2+
public fun <init> (Lkotlinx/coroutines/CompletableDeferred;Laws/smithy/kotlin/runtime/io/HashingByteReadChannel;)V
3+
public fun cancel (Ljava/lang/Throwable;)Z
4+
public fun getAvailableForRead ()I
5+
public fun getClosedCause ()Ljava/lang/Throwable;
6+
public fun isClosedForRead ()Z
7+
public fun isClosedForWrite ()Z
8+
public fun read (Laws/smithy/kotlin/runtime/io/SdkBuffer;JLkotlin/coroutines/Continuation;)Ljava/lang/Object;
9+
}
10+
11+
public final class aws/smithy/kotlin/runtime/http/CompletingSource : aws/smithy/kotlin/runtime/io/SdkSource {
12+
public fun <init> (Lkotlinx/coroutines/CompletableDeferred;Laws/smithy/kotlin/runtime/io/HashingSource;)V
13+
public fun close ()V
14+
public fun read (Laws/smithy/kotlin/runtime/io/SdkBuffer;J)J
15+
}
16+
117
public abstract interface class aws/smithy/kotlin/runtime/http/DeferredHeaders : aws/smithy/kotlin/runtime/collections/ValuesMap {
218
public static final field Companion Laws/smithy/kotlin/runtime/http/DeferredHeaders$Companion;
319
}
@@ -86,8 +102,10 @@ public abstract class aws/smithy/kotlin/runtime/http/HttpBody$SourceContent : aw
86102
}
87103

88104
public final class aws/smithy/kotlin/runtime/http/HttpBodyKt {
105+
public static final fun isEligibleForAwsChunkedStreaming (Laws/smithy/kotlin/runtime/http/HttpBody;)Z
89106
public static final fun readAll (Laws/smithy/kotlin/runtime/http/HttpBody;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
90107
public static final fun toByteStream (Laws/smithy/kotlin/runtime/http/HttpBody;)Laws/smithy/kotlin/runtime/content/ByteStream;
108+
public static final fun toCompletingBody (Laws/smithy/kotlin/runtime/http/HttpBody;Lkotlinx/coroutines/CompletableDeferred;)Laws/smithy/kotlin/runtime/http/HttpBody;
91109
public static final fun toHashingBody (Laws/smithy/kotlin/runtime/http/HttpBody;Laws/smithy/kotlin/runtime/hashing/HashFunction;Ljava/lang/Long;)Laws/smithy/kotlin/runtime/http/HttpBody;
92110
public static final fun toHttpBody (Laws/smithy/kotlin/runtime/content/ByteStream;)Laws/smithy/kotlin/runtime/http/HttpBody;
93111
public static final fun toHttpBody (Laws/smithy/kotlin/runtime/io/SdkByteReadChannel;Ljava/lang/Long;)Laws/smithy/kotlin/runtime/http/HttpBody;

runtime/protocol/http/common/src/aws/smithy/kotlin/runtime/http/HttpBody.kt

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import aws.smithy.kotlin.runtime.content.ByteStream
1010
import aws.smithy.kotlin.runtime.hashing.HashFunction
1111
import aws.smithy.kotlin.runtime.http.content.ByteArrayContent
1212
import aws.smithy.kotlin.runtime.io.*
13+
import aws.smithy.kotlin.runtime.text.encoding.encodeBase64String
14+
import kotlinx.coroutines.CompletableDeferred
1315
import kotlinx.coroutines.CoroutineScope
1416

1517
/**
@@ -191,6 +193,49 @@ public fun HttpBody.toHashingBody(
191193
else -> throw ClientException("HttpBody type is not supported")
192194
}
193195

196+
/**
197+
* Convert an [HttpBody] with an underlying [HashingSource] or [HashingByteReadChannel]
198+
* to a [CompletingSource] or [CompletingByteReadChannel], respectively.
199+
*/
200+
@InternalApi
201+
public fun HttpBody.toCompletingBody(deferred: CompletableDeferred<String>): HttpBody = when (this) {
202+
is HttpBody.SourceContent -> CompletingSource(deferred, (readFrom() as HashingSource)).toHttpBody(contentLength)
203+
is HttpBody.ChannelContent -> CompletingByteReadChannel(deferred, (readFrom() as HashingByteReadChannel)).toHttpBody(contentLength)
204+
else -> throw ClientException("HttpBody type is not supported")
205+
}
206+
207+
/**
208+
* An [SdkSource] which uses the underlying [hashingSource]'s checksum to complete a [CompletableDeferred] value.
209+
*/
210+
@InternalApi
211+
public class CompletingSource(
212+
private val deferred: CompletableDeferred<String>,
213+
private val hashingSource: HashingSource,
214+
) : SdkSource by hashingSource {
215+
override fun read(sink: SdkBuffer, limit: Long): Long = hashingSource.read(sink, limit)
216+
.also {
217+
if (it == -1L) {
218+
deferred.complete(hashingSource.digest().encodeBase64String())
219+
}
220+
}
221+
}
222+
223+
/**
224+
* An [SdkByteReadChannel] which uses the underlying [hashingChannel]'s checksum to complete a [CompletableDeferred] value.
225+
*/
226+
@InternalApi
227+
public class CompletingByteReadChannel(
228+
private val deferred: CompletableDeferred<String>,
229+
private val hashingChannel: HashingByteReadChannel,
230+
) : SdkByteReadChannel by hashingChannel {
231+
override suspend fun read(sink: SdkBuffer, limit: Long): Long = hashingChannel.read(sink, limit)
232+
.also {
233+
if (it == -1L) {
234+
deferred.complete(hashingChannel.digest().encodeBase64String())
235+
}
236+
}
237+
}
238+
194239
// FIXME - replace/move to reading to SdkBuffer instead
195240
/**
196241
* Consume the [HttpBody] and pull the entire contents into memory as a [ByteArray].
@@ -244,3 +289,10 @@ public fun HttpBody.toSdkByteReadChannel(scope: CoroutineScope? = null): SdkByte
244289
is HttpBody.ChannelContent -> body.readFrom()
245290
is HttpBody.SourceContent -> body.readFrom().toSdkByteReadChannel(scope)
246291
}
292+
293+
// FIXME this duplicates the logic from aws-signing-common
294+
@InternalApi
295+
public val HttpBody.isEligibleForAwsChunkedStreaming: Boolean
296+
get() = (this is HttpBody.SourceContent || this is HttpBody.ChannelContent) &&
297+
contentLength != null &&
298+
(isOneShot || contentLength!! > 65536 * 16)

0 commit comments

Comments
 (0)