Skip to content

Commit d534e33

Browse files
committed
fix AsyncMessageReader to handle partial reads
Signed-off-by: Connor Tsui <[email protected]>
1 parent a7d5768 commit d534e33

File tree

3 files changed

+130
-29
lines changed

3 files changed

+130
-29
lines changed

vortex-ipc/src/messages/decoder.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,21 @@ enum State {
4646

4747
#[derive(Debug)]
4848
pub enum PollRead {
49+
/// A complete message was decoded.
4950
Some(DecoderMessage),
50-
/// Returns the _total_ number of bytes needed to make progress.
51-
/// Note this is _not_ the incremental number of bytes needed to make progress.
51+
/// The decoder needs more data to make progress.
52+
///
53+
/// The inner value is the **total*k number of bytes the buffer should contain, not the
54+
/// incremental amount needed. Callers should:
55+
///
56+
/// 1. Resize the buffer to this length.
57+
/// 2. Fill the buffer completely (handling partial reads as needed).
58+
/// 3. Only then call [`MessageDecoder::read_next`] again.
59+
///
60+
/// The decoder checks [`bytes::Buf::remaining`] to determine available data, which for
61+
/// [`bytes::BytesMut`] returns the buffer length regardless of how many bytes were actually
62+
/// written. Calling `read_next` before the buffer is fully populated will cause the decoder
63+
/// to read garbage data.
5264
NeedMore(usize),
5365
}
5466

vortex-ipc/src/messages/reader_async.rs

Lines changed: 83 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use futures::Stream;
1212
use pin_project_lite::pin_project;
1313
use vortex_array::session::ArrayRegistry;
1414
use vortex_error::VortexResult;
15+
use vortex_error::vortex_err;
1516

1617
use crate::messages::DecoderMessage;
1718
use crate::messages::MessageDecoder;
@@ -24,7 +25,7 @@ pin_project! {
2425
read: R,
2526
buffer: BytesMut,
2627
decoder: MessageDecoder,
27-
bytes_read: usize,
28+
state: ReadState,
2829
}
2930
}
3031

@@ -34,40 +35,97 @@ impl<R> AsyncMessageReader<R> {
3435
read,
3536
buffer: BytesMut::new(),
3637
decoder: MessageDecoder::new(registry),
37-
bytes_read: 0,
38+
state: ReadState::default(),
3839
}
3940
}
4041
}
4142

43+
/// The state of an in-progress read operation.
44+
#[derive(Default)]
45+
enum ReadState {
46+
/// Ready to consult the decoder for the next operation.
47+
#[default]
48+
AwaitingDecoder,
49+
/// Filling the buffer with data from the underlying reader.
50+
///
51+
/// Async readers may return fewer bytes than requested (partial reads), especially over network
52+
/// connections. This state persists across multiple `poll_next` calls until the buffer is
53+
/// completely filled, at which point we transition back to [`Self::AwaitingDecoder`].
54+
Filling {
55+
/// The number of bytes read into the buffer so far.
56+
total_bytes_read: usize,
57+
},
58+
}
59+
60+
/// Result of polling the reader to fill the buffer.
61+
enum FillResult {
62+
/// The buffer has been completely filled.
63+
Filled,
64+
/// Need more data (partial read occurred).
65+
Pending,
66+
/// Clean EOF at a message boundary.
67+
Eof,
68+
}
69+
70+
/// Polls the reader to fill the buffer, handling partial reads.
71+
fn poll_fill_buffer<R: AsyncRead>(
72+
read: Pin<&mut R>,
73+
buffer: &mut [u8],
74+
total_bytes_read: &mut usize,
75+
cx: &mut Context<'_>,
76+
) -> Poll<VortexResult<FillResult>> {
77+
let unfilled = &mut buffer[*total_bytes_read..];
78+
79+
let bytes_read = ready!(read.poll_read(cx, unfilled))?;
80+
81+
// `0` bytes read indicates an EOF.
82+
Poll::Ready(if bytes_read == 0 {
83+
if *total_bytes_read > 0 {
84+
Err(vortex_err!(
85+
"unexpected EOF during partial read: read {total_bytes_read} of {} expected bytes",
86+
buffer.len()
87+
))
88+
} else {
89+
Ok(FillResult::Eof)
90+
}
91+
} else {
92+
*total_bytes_read += bytes_read;
93+
if *total_bytes_read == buffer.len() {
94+
Ok(FillResult::Filled)
95+
} else {
96+
debug_assert!(*total_bytes_read < buffer.len());
97+
Ok(FillResult::Pending)
98+
}
99+
})
100+
}
101+
42102
impl<R: AsyncRead> Stream for AsyncMessageReader<R> {
43103
type Item = VortexResult<DecoderMessage>;
44104

45105
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
46106
let mut this = self.project();
47107
loop {
48-
match this.decoder.read_next(this.buffer)? {
49-
PollRead::Some(msg) => return Poll::Ready(Some(Ok(msg))),
50-
PollRead::NeedMore(nbytes) => {
51-
this.buffer.resize(nbytes, 0x00);
52-
53-
match ready!(
54-
this.read
55-
.as_mut()
56-
.poll_read(cx, &mut this.buffer.as_mut()[*this.bytes_read..])
57-
) {
58-
Ok(0) => {
59-
// End of file
60-
return Poll::Ready(None);
61-
}
62-
Ok(nbytes) => {
63-
*this.bytes_read += nbytes;
64-
// If we've finished the read operation, then we continue the loop
65-
// and the decoder should present us with a new response.
66-
if *this.bytes_read == nbytes {
67-
*this.bytes_read = 0;
68-
}
69-
}
70-
Err(e) => return Poll::Ready(Some(Err(e.into()))),
108+
match this.state {
109+
ReadState::AwaitingDecoder => match this.decoder.read_next(this.buffer)? {
110+
PollRead::Some(msg) => return Poll::Ready(Some(Ok(msg))),
111+
PollRead::NeedMore(new_len) => {
112+
this.buffer.resize(new_len, 0x00);
113+
*this.state = ReadState::Filling {
114+
total_bytes_read: 0,
115+
};
116+
}
117+
},
118+
ReadState::Filling { total_bytes_read } => {
119+
match ready!(poll_fill_buffer(
120+
this.read.as_mut(),
121+
this.buffer,
122+
total_bytes_read,
123+
cx
124+
)) {
125+
Err(e) => return Poll::Ready(Some(Err(e))),
126+
Ok(FillResult::Eof) => return Poll::Ready(None),
127+
Ok(FillResult::Filled) => *this.state = ReadState::AwaitingDecoder,
128+
Ok(FillResult::Pending) => {}
71129
}
72130
}
73131
}

vortex-ipc/src/stream.rs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,38 @@ mod test {
275275
.await
276276
.unwrap();
277277

278-
let result = reader.read_all().await.unwrap();
279-
assert_eq!(result.len(), 10);
278+
let result = reader.read_all().await.unwrap().to_primitive();
279+
assert_eq!(
280+
&[1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10],
281+
result.as_slice::<i32>()
282+
);
283+
}
284+
285+
/// Test with 1-byte chunks to stress-test partial read handling.
286+
#[tokio::test]
287+
async fn test_async_stream_single_byte_chunks() {
288+
let session = ArraySession::default();
289+
let array = buffer![42i64, -1, 0, i64::MAX, i64::MIN].into_array();
290+
let ipc_buffer = array
291+
.to_array_stream()
292+
.into_ipc()
293+
.collect_to_buffer()
294+
.await
295+
.unwrap();
296+
297+
let chunked = ChunkedReader {
298+
inner: Cursor::new(ipc_buffer),
299+
chunk_size: 1,
300+
};
301+
302+
let reader = AsyncIPCReader::try_new(chunked, session.registry().clone())
303+
.await
304+
.unwrap();
305+
306+
let result = reader.read_all().await.unwrap().to_primitive();
307+
assert_eq!(
308+
&[42i64, -1, 0, i64::MAX, i64::MIN],
309+
result.as_slice::<i64>()
310+
);
280311
}
281312
}

0 commit comments

Comments
 (0)