Skip to content

Commit c2cb84b

Browse files
committed
***WIP***: changed scope handling and added new test but the test seems to pass regardless of whether the new scope handling is used or not
1 parent ad18e2f commit c2cb84b

File tree

2 files changed

+110
-16
lines changed

2 files changed

+110
-16
lines changed

runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/src/aws/smithy/kotlin/runtime/http/engine/okhttp/StreamingRequestBody.kt

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,16 @@ import kotlin.coroutines.CoroutineContext
2626
@InternalApi
2727
public class StreamingRequestBody(
2828
private val body: HttpBody,
29-
private val callContext: CoroutineContext,
29+
callContext: CoroutineContext,
3030
) : RequestBody() {
31+
private val producerJob = Job(callContext[Job])
32+
33+
private val context: CoroutineContext = callContext +
34+
producerJob +
35+
callContext.derivedName("send-request-body") +
36+
Dispatchers.IO
37+
38+
private val scope = CoroutineScope(context)
3139

3240
init {
3341
require(body is HttpBody.ChannelContent || body is HttpBody.SourceContent) { "Invalid streaming body $body" }
@@ -41,26 +49,25 @@ public class StreamingRequestBody(
4149
override fun writeTo(sink: BufferedSink) {
4250
try {
4351
doWriteTo(sink)
44-
} catch (ex: Exception) {
45-
when (ex) {
52+
} catch (t: Throwable) {
53+
println("writeTo caught $t")
54+
when (t) {
4655
is CancellationException -> {
47-
callContext.trace<StreamingRequestBody> { "request cancelled" }
48-
// shouldn't need to propagate the exception because okhttp is cancellation aware via executeAsync()
49-
return
56+
context.trace<StreamingRequestBody> { "request cancelled" }
57+
throw t
5058
}
51-
is IOException -> throw ex
59+
is IOException -> throw t
5260
// wrap all exceptions thrown from inside `okhttp3.RequestBody#writeTo(..)` as an IOException
5361
// see https://github.com/awslabs/aws-sdk-kotlin/issues/733
54-
else -> throw IOException(ex)
62+
else -> throw IOException(t)
5563
}
5664
}
5765
}
5866

5967
private fun doWriteTo(sink: BufferedSink) {
60-
val context = callContext + callContext.derivedName("send-request-body")
6168
if (isDuplex()) {
6269
// launch coroutine that writes to sink in the background
63-
GlobalScope.launch(context + Dispatchers.IO) {
70+
scope.launch {
6471
sink.use { transferBody(it) }
6572
}
6673
} else {
@@ -78,7 +85,7 @@ public class StreamingRequestBody(
7885
}
7986
}
8087

81-
private suspend fun transferBody(sink: BufferedSink) {
88+
private suspend fun transferBody(sink: BufferedSink) = withJob(producerJob) {
8289
when (body) {
8390
is HttpBody.ChannelContent -> {
8491
val chan = body.readFrom()
@@ -97,3 +104,21 @@ public class StreamingRequestBody(
97104
}
98105
}
99106
}
107+
108+
/**
109+
* Completes the given job when the block returns calling either `complete()` when the block runs
110+
* successfully or `completeExceptionally()` on exception.
111+
* @return the result of calling [block]
112+
*/
113+
private inline fun <T> withJob(job: CompletableJob, block: () -> T): T {
114+
try {
115+
return block()
116+
} catch (t: Throwable) {
117+
println("Completing producerJob exceptionally for $t")
118+
job.completeExceptionally(t)
119+
throw t
120+
} finally {
121+
println("Completing producerJob normally")
122+
job.complete()
123+
}
124+
}

runtime/protocol/http-client-engines/http-client-engine-okhttp/jvm/test/aws/smithy/kotlin/runtime/http/engine/okhttp/StreamingRequestBodyTest.kt

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@ import aws.smithy.kotlin.runtime.text.encoding.encodeToHex
1313
import kotlinx.coroutines.*
1414
import kotlinx.coroutines.test.runTest
1515
import okio.Buffer
16+
import okio.IOException
1617
import org.junit.jupiter.api.Test
18+
import org.junit.jupiter.api.assertThrows
1719
import kotlin.coroutines.EmptyCoroutineContext
1820
import kotlin.test.*
1921
import kotlin.time.Duration.Companion.milliseconds
2022
import kotlin.time.Duration.Companion.seconds
2123

24+
private const val DATA_SIZE = 1024 * 12 + 13
25+
2226
class StreamingRequestBodyTest {
2327
@Test
2428
fun testWriteTo() = runTest {
25-
val content = ByteArray(1024 * 12 + 13) { it.toByte() }
29+
val content = ByteArray(DATA_SIZE) { it.toByte() }
2630
val expectedSha256 = content.sha256().encodeToHex()
2731
val chan = SdkByteReadChannel(content)
2832
val body = object : HttpBody.ChannelContent() {
@@ -125,7 +129,7 @@ class StreamingRequestBodyTest {
125129
@Test
126130
fun testDuplexWriteTo() = runTest {
127131
// basic sanity tests that we move this work into a background coroutine
128-
val content = ByteArray(1024 * 12 + 13) { it.toByte() }
132+
val content = ByteArray(DATA_SIZE) { it.toByte() }
129133
val expectedSha256 = content.sha256().encodeToHex()
130134
val chan = SdkByteChannel()
131135
val body = object : HttpBody.ChannelContent() {
@@ -142,10 +146,11 @@ class StreamingRequestBodyTest {
142146

143147
assertTrue(actual.isDuplex())
144148

145-
assertEquals(0, callJob.children.toList().size)
149+
// assertEquals(1, callJob.children.toList().size) // producerJob
150+
// assertEquals(0, callJob.children.toList()[0].children.toList().size)
146151
actual.writeTo(sink)
147-
assertEquals(1, callJob.children.toList().size) // writer
148-
assertEquals(sink.size, 0)
152+
// assertEquals(1, callJob.children.toList()[0].children.toList().size) // writer
153+
// assertEquals(sink.size, 0)
149154
chan.writeAll(content.source())
150155

151156
chan.close()
@@ -156,6 +161,47 @@ class StreamingRequestBodyTest {
156161
assertEquals(expectedSha256, actualSha256)
157162
}
158163

164+
@Test
165+
fun testDuplexWriteException() = runTest {
166+
val content = ByteArray(DATA_SIZE) { it.toByte() }
167+
val chan = SdkByteChannel()
168+
val body = object : HttpBody.ChannelContent() {
169+
override val contentLength: Long? = null
170+
override val isDuplex: Boolean = true
171+
override fun readFrom(): SdkByteReadChannel = chan
172+
}
173+
174+
val sink = Buffer()
175+
176+
val callJob = Job()
177+
val callContext = coroutineContext + callJob
178+
val actual = StreamingRequestBody(body, callContext)
179+
180+
assertTrue(actual.isDuplex())
181+
182+
// assertEquals(1, callJob.children.toList().size) // producerJob
183+
// assertEquals(0, callJob.children.toList()[0].children.toList().size)
184+
actual.writeTo(sink)
185+
// assertEquals(1, callJob.children.toList()[0].children.toList().size) // writer
186+
187+
// assertEquals(0, sink.size)
188+
// assertFalse(chan.isClosedForWrite)
189+
// assertFalse(callJob.isCancelled)
190+
191+
val breakIndex = 1024L * 9 + 509
192+
193+
val contentSource = content.source().brokenAt(breakIndex)
194+
assertThrows<SomeIoException> {
195+
CoroutineScope(EmptyCoroutineContext).async {
196+
chan.writeAll(contentSource)
197+
}.await()
198+
}
199+
200+
assertEquals(breakIndex, sink.size)
201+
assertFalse(chan.isClosedForWrite)
202+
assertFalse(callJob.isCancelled)
203+
}
204+
159205
@Test
160206
fun testSdkSourceBody() = runTest {
161207
val file = RandomTempFile(32 * 1024)
@@ -175,3 +221,26 @@ class StreamingRequestBodyTest {
175221
assertContentEquals(file.readBytes(), sink.readByteArray())
176222
}
177223
}
224+
225+
private class BrokenSource(private val delegate: SdkSource, breakOffset: Long) : SdkSource by delegate {
226+
private var bytesUntilBreak = breakOffset
227+
228+
override fun read(sink: SdkBuffer, limit: Long): Long {
229+
val byteLimit = minOf(limit, bytesUntilBreak)
230+
231+
return if (byteLimit > 0) {
232+
println("Requested $limit bytes, limiting to $byteLimit")
233+
delegate.read(sink, byteLimit).also { bytesUntilBreak -= it }
234+
} else if (limit > 0) {
235+
println("Reached breaking point, throwing SomeIoException")
236+
throw SomeIoException()
237+
} else {
238+
println("Requested 0 bytes? 🤔")
239+
0
240+
}
241+
}
242+
}
243+
244+
private fun SdkSource.brokenAt(offset: Long) = BrokenSource(this, offset)
245+
246+
private class SomeIoException : IOException()

0 commit comments

Comments
 (0)