Skip to content

Commit 522fe81

Browse files
authored
Fix poor buffering case for MultipartReader (#8665) (#8676)
* Demonstrate poor buffering case * Fix for repeated reads of small byteCount from large part (cherry picked from commit 3cc87c3)
1 parent b2f22c2 commit 522fe81

File tree

4 files changed

+140
-5
lines changed

4 files changed

+140
-5
lines changed
Submodule hpack-test-case updated 510 files

okhttp/src/main/kotlin/okhttp3/MultipartReader.kt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import java.io.Closeable
1919
import java.io.IOException
2020
import java.net.ProtocolException
2121
import okhttp3.internal.http1.HeadersReader
22+
import okhttp3.internal.limit
2223
import okio.Buffer
2324
import okio.BufferedSource
2425
import okio.ByteString.Companion.encodeUtf8
@@ -175,10 +176,14 @@ class MultipartReader @Throws(IOException::class) constructor(
175176
* one byte left to read.
176177
*/
177178
private fun currentPartBytesRemaining(maxResult: Long): Long {
178-
source.require(crlfDashDashBoundary.size.toLong())
179-
180-
return when (val delimiterIndex = source.buffer.indexOf(crlfDashDashBoundary)) {
181-
-1L -> minOf(maxResult, source.buffer.size - crlfDashDashBoundary.size + 1)
179+
// Avoid indexOf scanning repeatedly over the entire source by using limit
180+
// Since maxResult could be midway through the boundary, read further to be safe.
181+
val limitSource = source.peek().limit(maxResult + crlfDashDashBoundary.size).buffer()
182+
limitSource.require(crlfDashDashBoundary.size.toLong())
183+
184+
val delimiterIndex = limitSource.buffer.indexOf(crlfDashDashBoundary)
185+
return when (delimiterIndex) {
186+
-1L -> minOf(maxResult, limitSource.buffer.size - crlfDashDashBoundary.size + 1)
182187
else -> minOf(maxResult, delimiterIndex)
183188
}
184189
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (C) 2024 Square, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package okhttp3.internal
18+
19+
import kotlin.jvm.JvmOverloads
20+
import okio.Buffer
21+
import okio.ForwardingSource
22+
import okio.Source
23+
24+
/**
25+
* Return a new [Source] whose [read function][Source.read] returns -1 after [byteCount]
26+
* bytes have been read.
27+
*
28+
* @param onReadExhausted Callback invoked once when the end of bytes has been reached. It receives
29+
* `true` if the end of bytes was because the underlying stream did not contain enough bytes and
30+
* `false` if [byteCount] bytes were successfully read.
31+
*/
32+
@JvmOverloads
33+
internal fun Source.limit(
34+
byteCount: Long,
35+
onReadExhausted: (eof: Boolean) -> Unit = {},
36+
): Source {
37+
require(byteCount >= 0) { "byteCount < 0: $byteCount" }
38+
return FixedLengthSource(this, byteCount, onReadExhausted, truncate = true)
39+
}
40+
41+
internal class FixedLengthSource(
42+
delegate: Source,
43+
private var bytesRemaining: Long,
44+
onReadExhausted: (eof: Boolean) -> Unit,
45+
private val truncate: Boolean,
46+
) : ForwardingSource(delegate) {
47+
/** `null` once invoked. */
48+
private var onReadExhausted: ((eof: Boolean) -> Unit)? = onReadExhausted
49+
50+
override fun read(
51+
sink: Buffer,
52+
byteCount: Long,
53+
): Long {
54+
val requestBytes =
55+
if (truncate) {
56+
if (bytesRemaining == 0L) {
57+
// If the limit was 0 we want to wait until the first call to this function before
58+
// triggering the callback.
59+
onReadExhausted?.invoke(false)
60+
onReadExhausted = null
61+
return -1L
62+
}
63+
minOf(bytesRemaining, byteCount)
64+
} else {
65+
byteCount
66+
}
67+
val readBytes = super.read(sink, requestBytes)
68+
if (readBytes == -1L) {
69+
onReadExhausted!!(true)
70+
onReadExhausted = null
71+
return -1L
72+
}
73+
bytesRemaining -= readBytes
74+
if (bytesRemaining == 0L) {
75+
onReadExhausted!!(false)
76+
onReadExhausted = null
77+
}
78+
return readBytes
79+
}
80+
}

okhttp/src/test/java/okhttp3/MultipartReaderTest.kt

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import java.io.EOFException
1919
import java.net.ProtocolException
2020
import okhttp3.Headers.Companion.headersOf
2121
import okhttp3.MediaType.Companion.toMediaType
22+
import okhttp3.MediaType.Companion.toMediaTypeOrNull
2223
import okhttp3.RequestBody.Companion.toRequestBody
2324
import okhttp3.ResponseBody.Companion.toResponseBody
2425
import okio.Buffer
@@ -538,4 +539,53 @@ class MultipartReaderTest {
538539

539540
assertThat(reader.nextPart()).isNull()
540541
}
542+
543+
@Test
544+
fun `reading a large part with small byteCount`() {
545+
val multipartBody: RequestBody =
546+
MultipartBody.Builder("foo").addPart(
547+
headersOf("header-name", "header-value"),
548+
object : RequestBody() {
549+
override fun contentType(): MediaType? {
550+
return "application/octet-stream".toMediaTypeOrNull()
551+
}
552+
553+
override fun contentLength(): Long {
554+
return (1024 * 1024 * 100).toLong()
555+
}
556+
557+
override fun writeTo(sink: okio.BufferedSink) {
558+
repeat(100) {
559+
sink.writeUtf8(
560+
"a".repeat(1024 * 1024),
561+
)
562+
}
563+
}
564+
},
565+
).build()
566+
val buffer =
567+
Buffer().apply {
568+
multipartBody.writeTo(this)
569+
}
570+
571+
val multipartReader = MultipartReader(buffer, "foo")
572+
while (true) {
573+
val part = multipartReader.nextPart()
574+
575+
if (part == null) break
576+
577+
assertThat(part.headers["header-name"]).isEqualTo("header-value")
578+
while (true) {
579+
val readBuff = Buffer()
580+
val read = part.body.read(readBuff, (1024).toLong())
581+
if (read == -1L) {
582+
break
583+
} else {
584+
assertThat(readBuff.readUtf8()).isEqualTo(
585+
"a".repeat(read.toInt()),
586+
)
587+
}
588+
}
589+
}
590+
}
541591
}

0 commit comments

Comments
 (0)