/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package aws.smithy.kotlin.runtime.auth.awssigning

import aws.smithy.kotlin.runtime.http.Headers
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
import aws.smithy.kotlin.runtime.util.InternalApi

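/**
 * The maximum number of bytes read from the delegate channel for each aws-chunked chunk body (64 KiB).
 */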
public const val CHUNK_SIZE_BYTES: Int = 65536

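/**
 * An [SdkByteReadChannel] decorator which returns the underlying channel's data in aws-chunked content encoding:
 * each chunk body is signed with [signer] using [signingConfig], prefixed with its hex-formatted size and chunk
 * signature, and terminated with a CRLF. Once the underlying data is exhausted, a zero-length final chunk (and,
 * if present, signed [trailingHeaders]) is emitted.
 *
 * @param chan the underlying channel of unsigned data
 * @param signer the [AwsSigner] used to sign each chunk
 * @param signingConfig the [AwsSigningConfig] applied when signing chunks
 * @param previousSignature the signature which seeds chunk signing, typically the request's seed signature
 * @param trailingHeaders the [Headers] to send after the final chunk, if any
 */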
@InternalApi
public abstract class AbstractAwsChunkedByteReadChannel(
    private val chan: SdkByteReadChannel,
    private val signer: AwsSigner,
    private val signingConfig: AwsSigningConfig,
    private var previousSignature: ByteArray,
    private val trailingHeaders: Headers = Headers.Empty,
) : SdkByteReadChannel by chan {
    override val isClosedForRead: Boolean
        get() = chan.isClosedForRead && (chunk == null || chunkOffset >= chunk!!.size) && hasLastChunkBeenSent

    internal var chunk: ByteArray? = null // the current aws-chunked encoded chunk being consumed, or null if none is loaded
    internal var chunkOffset: Int = 0 // the read position within [chunk]
    private var hasLastChunkBeenSent: Boolean = false // true once the zero-length final chunk has been sent

    /**
     * Returns the bytes remaining in the underlying data source, up to [limit].
     * @return a [ByteArray] containing at most [limit] bytes. It may contain fewer if there are fewer than [limit]
     * bytes remaining in the data source.
     */
    override suspend fun readRemaining(limit: Int): ByteArray {
        if (!ensureValidChunk()) {
            return byteArrayOf()
        }

        var bytesWritten = 0
        val bytes = ByteArray(limit)

        while (bytesWritten != limit) {
            val numBytesToWrite: Int = minOf(limit - bytesWritten, chunk!!.size - chunkOffset)

            chunk!!.copyInto(bytes, bytesWritten, chunkOffset, chunkOffset + numBytesToWrite)

            bytesWritten += numBytesToWrite
            chunkOffset += numBytesToWrite

            // read a new chunk. this handles the case where we consumed the whole chunk but still have not sent `limit` bytes
            if (!ensureValidChunk()) { break }
        }

        return bytes.sliceArray(0 until bytesWritten)
    }

    /**
     * Writes [length] bytes to [sink], starting [offset] bytes from the beginning. If the source data is exhausted
     * before [length] bytes have been written, the call fails with a [RuntimeException].
     *
     * @param sink the destination [ByteArray] to write to
     * @param offset the number of bytes in [sink] to skip before beginning to write
     * @param length the number of bytes to write to [sink]
     * @throws IllegalArgumentException when illegal [offset] and [length] arguments are passed
     * @throws RuntimeException when the source data is exhausted before [length] bytes are written to [sink]
     */
    override suspend fun readFully(sink: ByteArray, offset: Int, length: Int) {
        require(offset >= 0) { "Invalid read: offset must be non-negative: $offset" }
        require(offset + length <= sink.size) { "Invalid read: offset + length must not exceed the destination size: $offset + $length > ${sink.size}" }
        if (length == 0) return

        var bytesWritten = 0

        while (bytesWritten != length) {
            if (!ensureValidChunk()) {
                throw RuntimeException("Invalid read: unable to fully read $length bytes; missing ${length - bytesWritten} bytes.")
            }

            // write no more than the bytes remaining in the current chunk, and no more than the caller still needs
            val numBytesToWrite: Int = minOf(length - bytesWritten, chunk!!.size - chunkOffset)

            chunk!!.copyInto(sink, offset + bytesWritten, chunkOffset, chunkOffset + numBytesToWrite)

            bytesWritten += numBytesToWrite
            chunkOffset += numBytesToWrite
        }
    }

    /**
     * Writes up to [length] bytes to [sink], starting [offset] bytes from the beginning.
     * Returns after writing [length] bytes or all of the currently available bytes, whichever is fewer.
     *
     * This function will read *at most* one chunk of data into [sink]; successive calls are required to read
     * additional chunks. This is because the function promises not to suspend unless zero bytes are currently
     * available, and we are unable to poll the underlying data source to check whether a whole chunk is available.
     *
     * @param sink the [ByteArray] to write the data to
     * @param offset the number of bytes to skip from the beginning of [sink] before writing
     * @param length the maximum number of bytes to write to [sink]. The actual number of bytes written may be fewer
     * if there are fewer immediately available.
     * @throws IllegalArgumentException when illegal [offset] and [length] arguments are passed
     * @return an [Int] representing the number of bytes written
     */
    override suspend fun readAvailable(sink: ByteArray, offset: Int, length: Int): Int {
        require(offset >= 0) { "Invalid read: offset must be non-negative: $offset" }
        require(offset + length <= sink.size) { "Invalid read: offset + length must not exceed the destination size: $offset + $length > ${sink.size}" }
        if (length == 0 || !ensureValidChunk()) {
            return 0
        }

        var bytesWritten = 0

        while (bytesWritten != length) {
            val numBytesToWrite = minOf(length - bytesWritten, chunk!!.size - chunkOffset)

            chunk!!.copyInto(sink, offset + bytesWritten, chunkOffset, chunkOffset + numBytesToWrite)

            bytesWritten += numBytesToWrite
            chunkOffset += numBytesToWrite

            // if we've exhausted the current chunk, exit without suspending for a new one
            if (chunkOffset >= chunk!!.size) { break }
        }

        return bytesWritten
    }

    /**
     * Ensures that the internal [chunk] is valid for reading. If it's not, tries to load the next chunk.
     * Note that this function will suspend until the whole chunk has been loaded.
     *
     * @return true if the [chunk] is valid for reading, false if it's invalid (chunk data is exhausted)
     */
    internal suspend fun ensureValidChunk(): Boolean {
        // check if the current chunk is still valid
        if (chunk != null && chunkOffset < chunk!!.size) { return true }

        // if not, try to fetch a new chunk
        val nextChunk = when {
            // no data remains and the zero-length final chunk was already sent
            chan.isClosedForRead && hasLastChunkBeenSent -> null
            // source is exhausted: send the zero-length final chunk, plus any trailing headers
            chan.isClosedForRead -> {
                hasLastChunkBeenSent = true
                getChunk(byteArrayOf()) + if (!trailingHeaders.isEmpty()) getTrailingHeadersChunk(trailingHeaders) else byteArrayOf()
            }
            else -> getChunk()
        }

        chunkOffset = 0
        chunk = nextChunk?.plus("\r\n".encodeToByteArray()) // terminating CRLF to signal end of chunk
        return (chunk != null)
    }

    /**
     * Returns an aws-chunked encoding of [data].
     * If [data] is not set, reads the next chunk from [chan] and prepends the hex-formatted chunk size and chunk
     * signature. Note that this function will suspend until the whole chunk has been read.
     * The chunk structure is: `string(IntHexBase(chunk-size)) + ";chunk-signature=" + signature + \r\n + chunk-data`,
     * with the terminating `\r\n` appended by [ensureValidChunk].
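     * For example, a full chunk of [CHUNK_SIZE_BYTES] (65536) bytes would be encoded as (signature value elided):
     * `10000;chunk-signature=<hex signature>\r\n<65536 bytes of chunk data>`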
     *
     * @param data the ByteArray of data to encode as aws-chunked. If not provided, defaults to reading up to
     * [CHUNK_SIZE_BYTES] bytes from [chan].
     * @return a ByteArray containing the chunked data
     */
    private suspend fun getChunk(data: ByteArray? = null): ByteArray {
        val chunkBody = data ?: chan.readRemaining(CHUNK_SIZE_BYTES)

        val chunkSignature = signer.signChunk(chunkBody, previousSignature, signingConfig).signature
        previousSignature = chunkSignature

        val chunkHeader = buildString {
            append(chunkBody.size.toString(16)) // hex-formatted chunk size
            append(";chunk-signature=")
            append(chunkSignature.decodeToString())
            append("\r\n")
        }.encodeToByteArray()

        return chunkHeader + chunkBody
    }

    /**
     * Returns the trailing headers chunk. The grammar for trailing headers is:
     * trailing-header-A:value CRLF
     * trailing-header-B:value CRLF
     * ...
     * x-amz-trailer-signature:signature_value CRLF
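     *
     * For example, a single trailing header would be serialized as follows (illustrative values;
     * `x-amz-checksum-crc32` is just one possible trailer):
     * `x-amz-checksum-crc32:AAAAAA==\r\nx-amz-trailer-signature:<hex signature>\r\n`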
     *
     * @param trailingHeaders the [Headers] to send in the trailing chunk
     * @return a [ByteArray] containing the trailing headers in aws-chunked encoding, ready to send on the wire
     */
    private suspend fun getTrailingHeadersChunk(trailingHeaders: Headers): ByteArray {
        val trailerSignature = signer.signChunkTrailer(trailingHeaders, previousSignature, signingConfig).signature
        previousSignature = trailerSignature

        // headers are sorted by name and serialized as `name:comma-separated-values CRLF`
        val trailerBody = trailingHeaders.entries().sortedBy { entry -> entry.key.lowercase() }.map { entry ->
            buildString {
                append(entry.key)
                append(":")
                append(entry.value.joinToString(",") { v -> v.trim() })
                append("\r\n")
            }.encodeToByteArray()
        }.reduce { acc, bytes -> acc + bytes } +
            "x-amz-trailer-signature:${trailerSignature.decodeToString()}\r\n".encodeToByteArray()

        return trailerBody
    }
}
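
/*
 * A minimal usage sketch, not part of the runtime: `toAwsChunked`, `unsignedBody`, `config`, and
 * `seedSignature` are hypothetical names introduced here for illustration. Because the class is
 * abstract with no abstract members, an anonymous subclass suffices; real concrete subclasses in
 * the runtime may differ.
 *
 *     fun toAwsChunked(
 *         unsignedBody: SdkByteReadChannel,
 *         signer: AwsSigner,
 *         config: AwsSigningConfig,
 *         seedSignature: ByteArray, // signature of the initial request
 *     ): SdkByteReadChannel = object : AbstractAwsChunkedByteReadChannel(
 *         unsignedBody, signer, config, seedSignature,
 *     ) {}
 *
 *     // reads then yield aws-chunked encoded bytes, ending with the zero-length final chunk:
 *     // <size-hex>;chunk-signature=<sig>\r\n<data>\r\n ... 0;chunk-signature=<sig>\r\n\r\n
 */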