Skip to content

Commit 1667167

Browse files
authored
feat(rt): add conversions to and from ByteStream and Flow<ByteArray> (#947)
1 parent 71f20e9 commit 1667167

File tree

6 files changed

+305
-27
lines changed

6 files changed

+305
-27
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "f82c0433-30f9-4246-8f18-91402c5ac0ab",
3+
"type": "feature",
4+
"description": "Add conversions to and from `Flow<ByteArray>` and `ByteStream`",
5+
"issues": [
6+
"awslabs/aws-sdk-kotlin#612"
7+
]
8+
}

runtime/runtime-core/api/runtime-core.api

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ public final class aws/smithy/kotlin/runtime/content/ByteStreamKt {
131131
public static final fun cancel (Laws/smithy/kotlin/runtime/content/ByteStream;)V
132132
public static final fun decodeToString (Laws/smithy/kotlin/runtime/content/ByteStream;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
133133
public static final fun toByteArray (Laws/smithy/kotlin/runtime/content/ByteStream;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
134+
public static final fun toByteStream (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/CoroutineScope;Ljava/lang/Long;)Laws/smithy/kotlin/runtime/content/ByteStream;
135+
public static synthetic fun toByteStream$default (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/CoroutineScope;Ljava/lang/Long;ILjava/lang/Object;)Laws/smithy/kotlin/runtime/content/ByteStream;
136+
public static final fun toFlow (Laws/smithy/kotlin/runtime/content/ByteStream;J)Lkotlinx/coroutines/flow/Flow;
137+
public static synthetic fun toFlow$default (Laws/smithy/kotlin/runtime/content/ByteStream;JILjava/lang/Object;)Lkotlinx/coroutines/flow/Flow;
134138
}
135139

136140
public abstract class aws/smithy/kotlin/runtime/content/Document {

runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/content/ByteStream.kt

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
*/
55
package aws.smithy.kotlin.runtime.content
66

7-
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
8-
import aws.smithy.kotlin.runtime.io.SdkSource
9-
import aws.smithy.kotlin.runtime.io.readToBuffer
10-
import aws.smithy.kotlin.runtime.io.readToByteArray
7+
import aws.smithy.kotlin.runtime.io.*
8+
import aws.smithy.kotlin.runtime.io.internal.SdkDispatchers
9+
import kotlinx.coroutines.CoroutineScope
10+
import kotlinx.coroutines.flow.*
11+
import kotlinx.coroutines.launch
1112

1213
/**
1314
* Represents an abstract read-only stream of bytes
@@ -106,3 +107,92 @@ public fun ByteStream.cancel() {
106107
is ByteStream.SourceStream -> stream.readFrom().close()
107108
}
108109
}
110+
111+
/**
 * Expose the contents of this [ByteStream] as a [Flow] of byte arrays.
 *
 * @param bufferSize the chunk size to emit. Every emitted array has exactly this many bytes, except
 * the final one, which may be shorter. Ignored for [ByteStream.Buffer]: the in-memory bytes are emitted
 * as a single array of whatever size they already are.
 * @return a [Flow] that consumes the underlying stream when collected
 */
public fun ByteStream.toFlow(bufferSize: Long = 8192): Flow<ByteArray> {
    val stream = this
    return when (stream) {
        is ByteStream.Buffer -> flowOf(stream.bytes())
        is ByteStream.ChannelStream -> stream.readFrom().toFlow(bufferSize)
        is ByteStream.SourceStream -> {
            val chunks = stream.readFrom().toFlow(bufferSize)
            // Reads from the source are moved onto the SDK IO dispatcher.
            chunks.flowOn(SdkDispatchers.IO)
        }
    }
}
124+
125+
/**
 * Create a [ByteStream] from a [Flow] of byte arrays.
 *
 * The flow is collected in a coroutine launched in [scope]. Each emitted array is written to a
 * channel that backs the returned one-shot [ByteStream.ChannelStream]. When the collecting coroutine
 * completes, the channel is closed with the coroutine's completion cause (`null` on success).
 *
 * @param scope the [CoroutineScope] to use for launching a coroutine to do the collection in.
 * @param contentLength the overall content length of the [Flow] (if known). If set this will be
 * used as [ByteStream.contentLength]. Some APIs require a known `Content-Length` header and
 * since the total size of the flow can't be calculated without collecting it callers should set this
 * parameter appropriately in those cases. When set, the collecting coroutine fails with
 * [IllegalStateException] if the flow produces more or fewer bytes than advertised.
 * @return a one-shot [ByteStream.ChannelStream] that reads the bytes collected from this flow
 */
public fun Flow<ByteArray>.toByteStream(
    scope: CoroutineScope,
    contentLength: Long? = null,
): ByteStream {
    val ch = SdkByteChannel(true)

    val job = scope.launch {
        // Running total of bytes handed to the channel; only this coroutine touches it.
        var totalWritten = 0L

        collect { bytes ->
            // Validate BEFORE writing so that bytes beyond the advertised length never reach the reader.
            val newTotal = totalWritten + bytes.size
            check(contentLength == null || newTotal <= contentLength) {
                "$newTotal bytes collected from flow exceeds reported content length of $contentLength"
            }

            ch.write(bytes)
            totalWritten = newTotal
        }

        // The flow is exhausted: it must have produced exactly the advertised number of bytes.
        check(contentLength == null || totalWritten == contentLength) {
            "expected $contentLength bytes collected from flow, got $totalWritten"
        }

        ch.close()
    }

    // Covers failure and cancellation of the collecting coroutine: propagate the cause to the channel.
    job.invokeOnCompletion { cause ->
        ch.close(cause)
    }

    return object : ByteStream.ChannelStream() {
        override val contentLength: Long? = contentLength
        override val isOneShot: Boolean = true
        override fun readFrom(): SdkByteReadChannel = ch
    }
}
167+
168+
/**
 * Build a [Flow] that reads this channel and emits its contents in chunks of [bufferSize] bytes.
 * Every emitted array is exactly [bufferSize] bytes long except possibly the last one.
 *
 * @param bufferSize the chunk size in bytes; must be greater than zero
 * @throws IllegalArgumentException if [bufferSize] is not positive
 */
private fun SdkByteReadChannel.toFlow(bufferSize: Long): Flow<ByteArray> {
    // With a zero chunk size `sink.size >= bufferSize` holds on every pass, so the loop would emit
    // empty arrays without making progress; a negative size is meaningless. Fail fast instead.
    require(bufferSize > 0) { "bufferSize must be greater than zero, was $bufferSize" }

    val chan = this
    return flow {
        val sink = SdkBuffer()
        while (!chan.isClosedForRead) {
            val rc = chan.read(sink, bufferSize)
            if (rc == -1L) break

            // Only emit full chunks here; a short remainder stays in `sink` for the next pass.
            if (sink.size >= bufferSize) {
                emit(sink.readByteArray(bufferSize))
            }
        }

        // Flush whatever is left once the channel is exhausted.
        if (sink.size > 0L) {
            emit(sink.readByteArray())
        }
    }
}
183+
184+
/**
 * Build a [Flow] that reads this source and emits its contents in chunks of [bufferSize] bytes.
 * Every emitted array is exactly [bufferSize] bytes long except possibly the last one.
 *
 * @param bufferSize the chunk size in bytes; must be greater than zero
 * @throws IllegalArgumentException if [bufferSize] is not positive
 */
private fun SdkSource.toFlow(bufferSize: Long): Flow<ByteArray> {
    // With a zero chunk size `sink.size >= bufferSize` holds on every pass, so the loop would emit
    // empty arrays without making progress; a negative size is meaningless. Fail fast instead.
    require(bufferSize > 0) { "bufferSize must be greater than zero, was $bufferSize" }

    val source = this
    return flow {
        val sink = SdkBuffer()
        while (true) {
            val rc = source.read(sink, bufferSize)
            if (rc == -1L) break

            // Only emit full chunks here; a short remainder stays in `sink` for the next pass.
            if (sink.size >= bufferSize) {
                emit(sink.readByteArray(bufferSize))
            }
        }

        // Flush whatever is left once the source is exhausted.
        if (sink.size > 0L) {
            emit(sink.readByteArray())
        }
    }
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.smithy.kotlin.runtime.content
6+
7+
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
8+
import aws.smithy.kotlin.runtime.io.SdkSource
9+
import aws.smithy.kotlin.runtime.io.source
10+
11+
/**
 * Test fixture that wraps a [ByteArray] in one of the [ByteStream] variants so the same test
 * suite can run against each of them.
 */
fun interface ByteStreamFactory {
    /**
     * Create a [ByteStream] whose content is [input].
     */
    fun byteStream(input: ByteArray): ByteStream

    companion object {
        /** Produces the in-memory variant via [ByteStream.fromBytes]. */
        val BYTE_ARRAY: ByteStreamFactory = ByteStreamFactory { bytes ->
            ByteStream.fromBytes(bytes)
        }

        /** Produces a [ByteStream.SourceStream] reading from the array with a known content length. */
        val SDK_SOURCE: ByteStreamFactory = ByteStreamFactory { bytes ->
            object : ByteStream.SourceStream() {
                override val contentLength: Long = bytes.size.toLong()
                override fun readFrom(): SdkSource = bytes.source()
            }
        }

        /** Produces a [ByteStream.ChannelStream] reading from the array with a known content length. */
        val SDK_CHANNEL: ByteStreamFactory = ByteStreamFactory { bytes ->
            object : ByteStream.ChannelStream() {
                override val contentLength: Long = bytes.size.toLong()
                override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(bytes)
            }
        }
    }
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.smithy.kotlin.runtime.content
6+
7+
import io.kotest.matchers.string.shouldContain
8+
import kotlinx.coroutines.*
9+
import kotlinx.coroutines.channels.Channel
10+
import kotlinx.coroutines.flow.*
11+
import kotlinx.coroutines.test.runTest
12+
import java.lang.RuntimeException
13+
import kotlin.test.*
14+
15+
// Concrete instantiations of the shared ByteStreamFlowTest suite, one per ByteStreamFactory constant.
// This one runs the suite against ByteStreamFactory.BYTE_ARRAY.
class ByteStreamBufferFlowTest : ByteStreamFlowTest(ByteStreamFactory.BYTE_ARRAY)
16+
// Runs the suite against ByteStreamFactory.SDK_SOURCE.
class ByteStreamSourceStreamFlowTest : ByteStreamFlowTest(ByteStreamFactory.SDK_SOURCE)
17+
// Runs the suite against ByteStreamFactory.SDK_CHANNEL.
class ByteStreamChannelSourceFlowTest : ByteStreamFlowTest(ByteStreamFactory.SDK_CHANNEL)
18+
19+
/**
 * Shared tests for [ByteStream.toFlow], parameterized by the [ByteStreamFactory] that builds the
 * stream under test. Also hosts [FlowToByteStreamTest], which covers the reverse conversion.
 */
abstract class ByteStreamFlowTest(
    private val factory: ByteStreamFactory,
) {
    @Test
    fun testToFlowWithSizeHint() = runTest {
        val data = "a korf is a tiger".repeat(1024).encodeToByteArray()
        val bufferSize = 8182 * 2
        val byteStream = factory.byteStream(data)
        val flow = byteStream.toFlow(bufferSize.toLong())
        val buffers = mutableListOf<ByteArray>()
        flow.toList(buffers)

        val totalCollected = buffers.sumOf { it.size }
        assertEquals(data.size, totalCollected)

        if (byteStream is ByteStream.Buffer) {
            // The in-memory variant ignores the size hint and emits everything at once.
            assertEquals(1, buffers.size)
            assertContentEquals(data, buffers.first())
        } else {
            // Every buffer except the last must be exactly `bufferSize` bytes of the matching slice.
            val expectedFullBuffers = data.size / bufferSize
            for (i in 0 until expectedFullBuffers) {
                val b = buffers[i]
                val expected = data.sliceArray((i * bufferSize) until (i * bufferSize + bufferSize))
                assertEquals(bufferSize, b.size)
                assertContentEquals(expected, b)
            }

            val last = buffers.last()
            val expected = data.sliceArray(((buffers.size - 1) * bufferSize) until data.size)
            assertContentEquals(expected, last)
        }
    }

    /**
     * Tests for converting a `Flow<ByteArray>` into a [ByteStream] via `toByteStream`.
     */
    class FlowToByteStreamTest {
        private fun testByteArray(size: Int): ByteArray = ByteArray(size) { i -> i.toByte() }

        val data = listOf(
            testByteArray(576),
            testByteArray(9172),
            testByteArray(3278),
        )

        @Test
        fun testFlowToByteStreamReadAll() = runTest {
            val flow = data.asFlow()
            val scope = CoroutineScope(coroutineContext)
            val byteStream = flow.toByteStream(scope)

            assertNull(byteStream.contentLength)

            val actual = byteStream.toByteArray()
            val expected = data.reduce { acc, bytes -> acc + bytes }
            assertContentEquals(expected, actual)
        }

        @Test
        fun testContentLengthOverflow() = runTest {
            val advertisedContentLength = 1024L
            testInvalidContentLength(advertisedContentLength, "9748 bytes collected from flow exceeds reported content length of 1024")
        }

        @Test
        fun testContentLengthUnderflow() = runTest {
            val advertisedContentLength = data.sumOf { it.size } + 100L
            testInvalidContentLength(advertisedContentLength, "expected 13126 bytes collected from flow, got 13026")
        }

        /**
         * Converts [data] to a [ByteStream] advertising [advertisedContentLength] and asserts that
         * reading it fails with an [IllegalStateException] containing [expectedMessage], that the
         * scope's job is cancelled, and that exactly one exception reached the scope's handler.
         */
        private suspend fun testInvalidContentLength(advertisedContentLength: Long, expectedMessage: String) {
            val job = Job()
            // Records exceptions that propagate to the scope instead of being handled by a caller.
            val uncaughtExceptions = mutableListOf<Throwable>()
            val exHandler = CoroutineExceptionHandler { _, throwable -> uncaughtExceptions.add(throwable) }
            val scope = CoroutineScope(job + exHandler)
            val byteStream = data
                .asFlow()
                .toByteStream(scope, advertisedContentLength)

            assertEquals(advertisedContentLength, byteStream.contentLength)

            val ex = assertFailsWith<IllegalStateException> {
                byteStream.toByteArray()
            }

            // Use the nullable-receiver matcher: a safe call (`?.`) would silently skip the
            // assertion when the message is null and let the test pass vacuously.
            ex.message.shouldContain(expectedMessage)
            assertTrue(job.isCancelled)
            job.join()

            assertEquals(1, uncaughtExceptions.size)
        }

        @Test
        fun testScopeCancellation() = runTest {
            // cancelling the scope should close/cancel the channel
            val waiter = Channel<Unit>(1)
            val flow = flow {
                emit(testByteArray(128))
                emit(testByteArray(277))
                waiter.receive()
                emit(testByteArray(97))
            }

            val job = Job()
            val scope = CoroutineScope(job)
            val byteStream = flow.toByteStream(scope)
            assertIs<ByteStream.ChannelStream>(byteStream)
            assertNull(byteStream.contentLength)
            yield()

            job.cancel("scope cancelled")
            waiter.send(Unit)
            job.join()

            val ch = byteStream.readFrom()
            assertTrue(ch.isClosedForRead)
            assertTrue(ch.isClosedForWrite)
            assertIs<CancellationException>(ch.closedCause)
            ch.closedCause?.message.shouldContain("scope cancelled")
        }

        @Test
        fun testChannelCancellation() = runTest {
            // cancelling the channel should cancel the scope (via write failing)
            val waiter = Channel<Unit>(1)
            val flow = flow {
                emit(testByteArray(128))
                emit(testByteArray(277))
                waiter.receive()
                emit(testByteArray(97))
            }

            val uncaughtExceptions = mutableListOf<Throwable>()
            val exHandler = CoroutineExceptionHandler { _, throwable -> uncaughtExceptions.add(throwable) }
            val job = Job()
            val scope = CoroutineScope(job + exHandler)
            val byteStream = flow.toByteStream(scope)
            assertIs<ByteStream.ChannelStream>(byteStream)

            val ch = byteStream.readFrom()
            val cause = RuntimeException("chan cancelled")
            ch.cancel(cause)

            // unblock the flow
            waiter.send(Unit)

            job.join()
            assertTrue(job.isCancelled)
            assertEquals(1, uncaughtExceptions.size)
            uncaughtExceptions.first().message.shouldContain("chan cancelled")
        }
    }
}

runtime/runtime-core/jvm/test/aws/smithy/kotlin/runtime/content/ByteStreamInputStreamTest.kt

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,12 @@
44
*/
55
package aws.smithy.kotlin.runtime.content
66

7-
import aws.smithy.kotlin.runtime.io.SdkByteReadChannel
8-
import aws.smithy.kotlin.runtime.io.SdkSource
9-
import aws.smithy.kotlin.runtime.io.source
107
import java.io.InputStream
118
import kotlin.test.Test
129
import kotlin.test.assertContentEquals
1310
import kotlin.test.assertEquals
1411

15-
fun interface ByteStreamFactory {
16-
fun inputStream(input: ByteArray): InputStream
17-
companion object {
18-
val BYTE_ARRAY: ByteStreamFactory = ByteStreamFactory { input -> ByteStream.fromBytes(input).toInputStream() }
19-
20-
val SDK_SOURCE: ByteStreamFactory = ByteStreamFactory { input ->
21-
object : ByteStream.SourceStream() {
22-
override fun readFrom(): SdkSource = input.source()
23-
override val contentLength: Long = input.size.toLong()
24-
}.toInputStream()
25-
}
26-
27-
val SDK_CHANNEL: ByteStreamFactory = ByteStreamFactory { input ->
28-
object : ByteStream.ChannelStream() {
29-
override fun readFrom(): SdkByteReadChannel = SdkByteReadChannel(input)
30-
override val contentLength: Long = input.size.toLong()
31-
}.toInputStream()
32-
}
33-
}
34-
}
12+
/**
 * Build a [ByteStream] from [input] using this factory and expose it as an [InputStream].
 */
fun ByteStreamFactory.inputStream(input: ByteArray): InputStream {
    val stream = byteStream(input)
    return stream.toInputStream()
}
3513

3614
class ByteStreamBufferInputStreamTest : ByteStreamInputStreamTest(ByteStreamFactory.BYTE_ARRAY)
3715
class ByteStreamSourceStreamInputStreamTest : ByteStreamInputStreamTest(ByteStreamFactory.SDK_SOURCE)

0 commit comments

Comments
 (0)