Skip to content

Commit 90209f5

Browse files
committed
fix: stop using GlobalScope for event stream handlers
1 parent 159e784 commit 90209f5

File tree

3 files changed

+108
-16
lines changed

3 files changed

+108
-16
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"id": "d3a7a353-019d-470c-9c46-f802a67eea0d",
3+
"type": "bugfix",
4+
"description": "Stop using `GlobalScope` for event streaming handlers to improve resource handling"
5+
}

runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/StreamingRequestBody.kt

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,14 @@ import kotlin.coroutines.CoroutineContext
2626
@InternalApi
2727
public class StreamingRequestBody(
2828
private val body: HttpBody,
29-
private val callContext: CoroutineContext,
29+
callContext: CoroutineContext,
3030
) : RequestBody() {
31+
private val producerJob = Job(callContext[Job])
32+
33+
private val context: CoroutineContext = callContext +
34+
producerJob +
35+
callContext.derivedName("send-request-body") +
36+
Dispatchers.IO
3137

3238
init {
3339
require(body is HttpBody.ChannelContent || body is HttpBody.SourceContent) { "Invalid streaming body $body" }
@@ -41,26 +47,24 @@ public class StreamingRequestBody(
4147
override fun writeTo(sink: BufferedSink) {
4248
try {
4349
doWriteTo(sink)
44-
} catch (ex: Exception) {
45-
when (ex) {
50+
} catch (t: Throwable) {
51+
when (t) {
4652
is CancellationException -> {
47-
callContext.trace<StreamingRequestBody> { "request cancelled" }
48-
// shouldn't need to propagate the exception because okhttp is cancellation aware via executeAsync()
49-
return
53+
context.trace<StreamingRequestBody> { "request cancelled" }
54+
throw t
5055
}
51-
is IOException -> throw ex
56+
is IOException -> throw t
5257
// wrap all exceptions thrown from inside `okhttp3.RequestBody#writeTo(..)` as an IOException
5358
// see https://github.com/awslabs/aws-sdk-kotlin/issues/733
54-
else -> throw IOException(ex)
59+
else -> throw IOException(t)
5560
}
5661
}
5762
}
5863

5964
private fun doWriteTo(sink: BufferedSink) {
60-
val context = callContext + callContext.derivedName("send-request-body")
6165
if (isDuplex()) {
6266
// launch coroutine that writes to sink in the background
63-
GlobalScope.launch(context + Dispatchers.IO) {
67+
CoroutineScope(context).launch {
6468
sink.use { transferBody(it) }
6569
}
6670
} else {
@@ -78,7 +82,7 @@ public class StreamingRequestBody(
7882
}
7983
}
8084

81-
private suspend fun transferBody(sink: BufferedSink) {
85+
private suspend fun transferBody(sink: BufferedSink) = withJob(producerJob) {
8286
when (body) {
8387
is HttpBody.ChannelContent -> {
8488
val chan = body.readFrom()
@@ -97,3 +101,17 @@ public class StreamingRequestBody(
97101
}
98102
}
99103
}
104+
105+
/**
106+
* Completes the given job when the block returns calling either `complete()` when the block runs
107+
* successfully or `completeExceptionally()` on exception.
108+
* @return the result of calling [block]
109+
*/
110+
private inline fun <T> withJob(job: CompletableJob, block: () -> T): T {
111+
try {
112+
return block().also { job.complete() }
113+
} catch (t: Throwable) {
114+
job.completeExceptionally(t)
115+
throw t
116+
}
117+
}

runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/test/aws/smithy/kotlin/runtime/http/engine/okhttp/StreamingRequestBodyTest.kt

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@ import aws.smithy.kotlin.runtime.text.encoding.encodeToHex
1313
import kotlinx.coroutines.*
1414
import kotlinx.coroutines.test.runTest
1515
import okio.Buffer
16+
import okio.IOException
1617
import org.junit.jupiter.api.Test
18+
import org.junit.jupiter.api.assertThrows
1719
import kotlin.coroutines.EmptyCoroutineContext
1820
import kotlin.test.*
1921
import kotlin.time.Duration.Companion.milliseconds
2022
import kotlin.time.Duration.Companion.seconds
2123

24+
private const val DATA_SIZE = 1024 * 12 + 13
25+
2226
class StreamingRequestBodyTest {
2327
@Test
2428
fun testWriteTo() = runTest {
25-
val content = ByteArray(1024 * 12 + 13) { it.toByte() }
29+
val content = ByteArray(DATA_SIZE) { it.toByte() }
2630
val expectedSha256 = content.sha256().encodeToHex()
2731
val chan = SdkByteReadChannel(content)
2832
val body = object : HttpBody.ChannelContent() {
@@ -125,7 +129,7 @@ class StreamingRequestBodyTest {
125129
@Test
126130
fun testDuplexWriteTo() = runTest {
127131
// basic sanity tests that we move this work into a background coroutine
128-
val content = ByteArray(1024 * 12 + 13) { it.toByte() }
132+
val content = ByteArray(DATA_SIZE) { it.toByte() }
129133
val expectedSha256 = content.sha256().encodeToHex()
130134
val chan = SdkByteChannel()
131135
val body = object : HttpBody.ChannelContent() {
@@ -142,10 +146,11 @@ class StreamingRequestBodyTest {
142146

143147
assertTrue(actual.isDuplex())
144148

145-
assertEquals(0, callJob.children.toList().size)
149+
// assertEquals(1, callJob.children.toList().size) // producerJob
150+
// assertEquals(0, callJob.children.toList()[0].children.toList().size)
146151
actual.writeTo(sink)
147-
assertEquals(1, callJob.children.toList().size) // writer
148-
assertEquals(sink.size, 0)
152+
// assertEquals(1, callJob.children.toList()[0].children.toList().size) // writer
153+
// assertEquals(0, sink.size)
149154
chan.writeAll(content.source())
150155

151156
chan.close()
@@ -156,6 +161,47 @@ class StreamingRequestBodyTest {
156161
assertEquals(expectedSha256, actualSha256)
157162
}
158163

164+
@Test
165+
fun testDuplexWriteException() = runBlocking {
166+
val content = ByteArray(DATA_SIZE) { it.toByte() }
167+
val chan = SdkByteChannel()
168+
val body = object : HttpBody.ChannelContent() {
169+
override val contentLength: Long? = null
170+
override val isDuplex: Boolean = true
171+
override fun readFrom(): SdkByteReadChannel = chan
172+
}
173+
174+
val sink = Buffer()
175+
176+
val callJob = Job()
177+
val callContext = coroutineContext + callJob
178+
val actual = StreamingRequestBody(body, callContext)
179+
180+
assertTrue(actual.isDuplex())
181+
182+
// assertEquals(1, callJob.children.toList().size) // producerJob
183+
// assertEquals(0, callJob.children.toList()[0].children.toList().size)
184+
actual.writeTo(sink)
185+
// assertEquals(1, callJob.children.toList()[0].children.toList().size) // writer
186+
187+
assertEquals(0, sink.size)
188+
assertFalse(chan.isClosedForWrite)
189+
assertFalse(callJob.isCancelled)
190+
191+
val breakIndex = 1024L * 9 + 509
192+
193+
val contentSource = content.source().brokenAt(breakIndex)
194+
assertThrows<SomeIoException> {
195+
CoroutineScope(CoroutineExceptionHandler { ctx, e -> println("Got exception $e") }).async {
196+
chan.writeAll(contentSource)
197+
}.await()
198+
}
199+
200+
assertEquals(breakIndex, sink.size)
201+
assertFalse(chan.isClosedForWrite)
202+
assertFalse(callJob.isCancelled)
203+
}
204+
159205
@Test
160206
fun testSdkSourceBody() = runTest {
161207
val file = RandomTempFile(32 * 1024)
@@ -175,3 +221,26 @@ class StreamingRequestBodyTest {
175221
assertContentEquals(file.readBytes(), sink.readByteArray())
176222
}
177223
}
224+
225+
private class BrokenSource(private val delegate: SdkSource, breakOffset: Long) : SdkSource by delegate {
226+
private var bytesUntilBreak = breakOffset
227+
228+
override fun read(sink: SdkBuffer, limit: Long): Long {
229+
val byteLimit = minOf(limit, bytesUntilBreak)
230+
231+
return if (byteLimit > 0) {
232+
println("Requested $limit bytes, limiting to $byteLimit")
233+
delegate.read(sink, byteLimit).also { bytesUntilBreak -= it }
234+
} else if (limit > 0) {
235+
println("Reached breaking point, throwing SomeIoException")
236+
throw SomeIoException()
237+
} else {
238+
println("Requested 0 bytes? 🤔")
239+
0
240+
}
241+
}
242+
}
243+
244+
private fun SdkSource.brokenAt(offset: Long) = BrokenSource(this, offset)
245+
246+
private class SomeIoException : IOException()

0 commit comments

Comments
 (0)