Skip to content

Commit 153c1a1

Browse files
authored
implement Buf for cursors (#12)
Mostly done by Claude with some supervision from me. Seems to work great!
1 parent 973d723 commit 153c1a1

File tree

3 files changed

+335
-2
lines changed

3 files changed

+335
-2
lines changed

.claude/settings.local.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"permissions": {
3+
"allow": [
4+
"Bash(cargo test)",
5+
"Bash(cargo test:*)",
6+
"Bash(cargo clippy:*)",
7+
"Bash(cargo fix:*)"
8+
],
9+
"deny": [],
10+
"ask": []
11+
}
12+
}

src/cursor/mod.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ mod tests;
99
mod tokio_imp;
1010

1111
use crate::{BufList, errors::ReadExactError};
12-
use bytes::Bytes;
12+
use bytes::{Buf, Bytes};
1313
use std::{
1414
cmp::Ordering,
15-
io::{self, IoSliceMut, SeekFrom},
15+
io::{self, IoSlice, IoSliceMut, SeekFrom},
1616
};
1717

1818
/// A `Cursor` wraps an in-memory `BufList` and provides it with a [`Seek`] implementation.
@@ -195,6 +195,58 @@ impl<T: AsRef<BufList>> io::BufRead for Cursor<T> {
195195
}
196196
}
197197

198+
impl<T: AsRef<BufList>> Buf for Cursor<T> {
199+
fn remaining(&self) -> usize {
200+
let total = self.data.num_bytes(self.inner.as_ref());
201+
total.saturating_sub(self.data.pos) as usize
202+
}
203+
204+
fn chunk(&self) -> &[u8] {
205+
self.data.fill_buf_impl(self.inner.as_ref())
206+
}
207+
208+
fn advance(&mut self, amt: usize) {
209+
self.data.consume_impl(self.inner.as_ref(), amt);
210+
}
211+
212+
fn chunks_vectored<'iovs>(&'iovs self, iovs: &mut [IoSlice<'iovs>]) -> usize {
213+
if iovs.is_empty() {
214+
return 0;
215+
}
216+
217+
let list = self.inner.as_ref();
218+
let mut filled = 0;
219+
let mut current_chunk = self.data.chunk;
220+
let mut current_pos = self.data.pos;
221+
222+
// Iterate through chunks starting from the current position
223+
while filled < iovs.len() && current_chunk < list.num_chunks() {
224+
if let Some(chunk) = list.get_chunk(current_chunk) {
225+
let chunk_start_pos = list.get_start_pos()[current_chunk];
226+
let offset_in_chunk = (current_pos - chunk_start_pos) as usize;
227+
228+
if offset_in_chunk < chunk.len() {
229+
let chunk_slice = &chunk.as_ref()[offset_in_chunk..];
230+
iovs[filled] = IoSlice::new(chunk_slice);
231+
filled += 1;
232+
}
233+
234+
current_chunk += 1;
235+
// Move to the start of the next chunk
236+
if let Some(&next_start_pos) = list.get_start_pos().get(current_chunk) {
237+
current_pos = next_start_pos;
238+
} else {
239+
break;
240+
}
241+
} else {
242+
break;
243+
}
244+
}
245+
246+
filled
247+
}
248+
}
249+
198250
#[derive(Clone, Debug)]
199251
struct CursorData {
200252
/// The chunk number the cursor is pointing to. Kept in sync with pos.

src/cursor/tests.rs

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ enum CursorOp {
5656
// fill_buf can't be tested here because oracle is a contiguous block. Instead, we check its
5757
// return value separately.
5858
Consume(prop::sample::Index),
59+
// Buf trait operations
60+
BufRemaining,
61+
BufChunk,
62+
BufAdvance(prop::sample::Index),
63+
BufChunksVectored(prop::sample::Index),
64+
BufCopyToBytes(prop::sample::Index),
65+
BufGetU8,
66+
BufGetU64,
67+
BufGetU64Le,
5968
// No need to test futures03 imps since they're simple wrappers around the main imps.
6069
#[cfg(feature = "tokio1")]
6170
PollRead {
@@ -183,6 +192,204 @@ impl CursorOp {
183192
buf_list.consume(amt);
184193
oracle.consume(amt);
185194
}
195+
Self::BufRemaining => {
196+
eprintln!("buf_remaining");
197+
198+
let buf_list_remaining = buf_list.remaining();
199+
let oracle_remaining = oracle.remaining();
200+
ensure!(
201+
buf_list_remaining == oracle_remaining,
202+
"remaining didn't match: buf_list {} == oracle {}",
203+
buf_list_remaining,
204+
oracle_remaining
205+
);
206+
}
207+
Self::BufChunk => {
208+
eprintln!("buf_chunk");
209+
210+
let buf_list_chunk = buf_list.chunk();
211+
let oracle_chunk = oracle.chunk();
212+
213+
// We can't directly compare chunks because BufList returns one
214+
// segment at a time while oracle returns the entire remaining
215+
// buffer. Instead, verify that:
216+
//
217+
// 1. is_empty matches for both chunks.
218+
// 2. Both start with the same data (buf_list's chunk is a prefix of oracle's)
219+
ensure!(
220+
buf_list_chunk.is_empty() == oracle_chunk.is_empty(),
221+
"chunk emptiness didn't match: buf_list is_empty {} == oracle is_empty {}",
222+
buf_list_chunk.is_empty(),
223+
oracle_chunk.is_empty()
224+
);
225+
226+
if !buf_list_chunk.is_empty() {
227+
// Verify buf_list's chunk is a prefix of oracle's chunk
228+
ensure!(
229+
oracle_chunk.starts_with(buf_list_chunk),
230+
"buf_list chunk is not a prefix of oracle chunk"
231+
);
232+
}
233+
}
234+
Self::BufAdvance(index) => {
235+
let amt = index.index(1 + num_bytes * 5 / 4);
236+
eprintln!("buf_advance: {}", amt);
237+
238+
// Skip if already past the end, as the oracle's Buf impl has a debug assertion
239+
// that checks position even when advancing by 0
240+
if buf_list.remaining() > 0 || amt == 0 && oracle.remaining() > 0 {
241+
// Cap the advance amount to the remaining bytes to avoid
242+
// hitting the debug assertion in std::io::Cursor's Buf
243+
// impl. While the Buf trait doesn't require this, the
244+
// oracle has a debug_assert that panics if we try to
245+
// advance past the end.
246+
let amt = amt.min(buf_list.remaining());
247+
buf_list.advance(amt);
248+
oracle.advance(amt);
249+
} else {
250+
eprintln!(" skipping: cursor past end");
251+
}
252+
}
253+
Self::BufChunksVectored(index) => {
254+
let num_iovs = index.index(1 + num_bytes);
255+
eprintln!("buf_chunks_vectored: {} iovs", num_iovs);
256+
257+
// First verify remaining() matches
258+
let buf_list_remaining = buf_list.remaining();
259+
let oracle_remaining = oracle.remaining();
260+
ensure!(
261+
buf_list_remaining == oracle_remaining,
262+
"chunks_vectored: remaining didn't match before \
263+
calling chunks_vectored: buf_list {} == oracle {}",
264+
buf_list_remaining,
265+
oracle_remaining
266+
);
267+
268+
let mut buf_list_iovs = vec![io::IoSlice::new(&[]); num_iovs];
269+
let mut oracle_iovs = vec![io::IoSlice::new(&[]); num_iovs];
270+
271+
let buf_list_filled = buf_list.chunks_vectored(&mut buf_list_iovs);
272+
let oracle_filled = oracle.chunks_vectored(&mut oracle_iovs);
273+
274+
// We can't directly compare filled counts or total bytes
275+
// because BufList may have multiple chunks while the oracle
276+
// (std::io::Cursor) is contiguous. When there are fewer iovs
277+
// than chunks, BufList will only fill what it can, while oracle
278+
// fills everything into one iov.
279+
//
280+
// Instead, we verify that:
281+
// 1. Both returned at least some data if there are bytes
282+
// remaining
283+
// 2. The data that was returned matches (buf_list's data is a
284+
// prefix of oracle's data)
285+
let buf_list_bytes: Vec<u8> = buf_list_iovs[..buf_list_filled]
286+
.iter()
287+
.flat_map(|iov| iov.as_ref().iter().copied())
288+
.collect();
289+
let oracle_bytes: Vec<u8> = oracle_iovs[..oracle_filled]
290+
.iter()
291+
.flat_map(|iov| iov.as_ref().iter().copied())
292+
.collect();
293+
294+
if buf_list_remaining > 0 && num_iovs > 0 {
295+
// If there are bytes remaining and iovs available, should
296+
// return some data.
297+
ensure!(
298+
!buf_list_bytes.is_empty(),
299+
"chunks_vectored should return some data \
300+
when remaining > 0 and num_iovs > 0"
301+
);
302+
ensure!(
303+
!oracle_bytes.is_empty(),
304+
"oracle chunks_vectored should return some data \
305+
when remaining > 0 and num_iovs > 0"
306+
);
307+
308+
// Verify that buf_list's data matches the beginning of
309+
// oracle's data.
310+
ensure!(
311+
oracle_bytes.starts_with(&buf_list_bytes),
312+
"buf_list chunks_vectored data should match beginning \
313+
of oracle data"
314+
);
315+
} else if buf_list_remaining == 0 {
316+
// If no bytes remaining, should return no data
317+
ensure!(
318+
buf_list_bytes.is_empty() && oracle_bytes.is_empty(),
319+
"chunks_vectored should return no data when \
320+
remaining == 0"
321+
);
322+
}
323+
// If num_iovs == 0, we can't check anything since no iovs were
324+
// provided. All we're doing is ensuring that buf_list doesn't
325+
// panic.
326+
}
327+
Self::BufCopyToBytes(index) => {
328+
let len = index.index(1 + num_bytes * 5 / 4);
329+
eprintln!("buf_copy_to_bytes: {}", len);
330+
331+
// copy_to_bytes can panic if len > remaining, so check first
332+
let buf_list_remaining = buf_list.remaining();
333+
let oracle_remaining = oracle.remaining();
334+
335+
if len <= buf_list_remaining && len <= oracle_remaining {
336+
let buf_list_bytes = buf_list.copy_to_bytes(len);
337+
let oracle_bytes = oracle.copy_to_bytes(len);
338+
339+
ensure!(buf_list_bytes == oracle_bytes, "copy_to_bytes didn't match");
340+
} else {
341+
// Both should panic, so just skip this operation
342+
eprintln!(" skipping: len {} > remaining {}", len, buf_list_remaining);
343+
}
344+
}
345+
Self::BufGetU8 => {
346+
eprintln!("buf_get_u8");
347+
348+
if buf_list.remaining() >= 1 && oracle.remaining() >= 1 {
349+
let buf_list_val = buf_list.get_u8();
350+
let oracle_val = oracle.get_u8();
351+
ensure!(
352+
buf_list_val == oracle_val,
353+
"get_u8 didn't match: buf_list {} == oracle {}",
354+
buf_list_val,
355+
oracle_val
356+
);
357+
} else {
358+
eprintln!(" skipping: not enough bytes remaining");
359+
}
360+
}
361+
Self::BufGetU64 => {
362+
eprintln!("buf_get_u64");
363+
364+
if buf_list.remaining() >= 8 && oracle.remaining() >= 8 {
365+
let buf_list_val = buf_list.get_u64();
366+
let oracle_val = oracle.get_u64();
367+
ensure!(
368+
buf_list_val == oracle_val,
369+
"get_u64 didn't match: buf_list {} == oracle {}",
370+
buf_list_val,
371+
oracle_val
372+
);
373+
} else {
374+
eprintln!(" skipping: not enough bytes remaining");
375+
}
376+
}
377+
Self::BufGetU64Le => {
378+
eprintln!("buf_get_u64_le");
379+
380+
if buf_list.remaining() >= 8 && oracle.remaining() >= 8 {
381+
let buf_list_val = buf_list.get_u64_le();
382+
let oracle_val = oracle.get_u64_le();
383+
ensure!(
384+
buf_list_val == oracle_val,
385+
"get_u64_le didn't match: buf_list {} == oracle {}",
386+
buf_list_val,
387+
oracle_val
388+
);
389+
} else {
390+
eprintln!(" skipping: not enough bytes remaining");
391+
}
392+
}
186393
#[cfg(feature = "tokio1")]
187394
Self::PollRead { capacity, filled } => {
188395
use std::{mem::MaybeUninit, pin::Pin, task::Poll};
@@ -322,3 +529,65 @@ impl CursorOp {
322529
fn cursor_ops_strategy() -> impl Strategy<Value = Vec<CursorOp>> {
323530
prop::collection::vec(any::<CursorOp>(), 0..256)
324531
}
532+
533+
#[test]
fn test_cursor_buf_trait() {
    // Three chunks, 12 bytes in total: "hello " (6) + "world" (5) + "!" (1).
    let mut buf_list = BufList::new();
    buf_list.push_chunk(&b"hello "[..]);
    buf_list.push_chunk(&b"world"[..]);
    buf_list.push_chunk(&b"!"[..]);

    let mut cursor = crate::Cursor::new(buf_list.clone());

    // A fresh cursor sees every byte and exposes the first chunk.
    assert_eq!(cursor.remaining(), 12);
    assert_eq!(cursor.chunk(), b"hello ");

    // Each step is (bytes to advance, expected remaining, expected chunk). The steps cover:
    // crossing into the second chunk, moving within it, crossing into the last chunk, and
    // reaching the end.
    let steps: [(usize, usize, &[u8]); 4] = [
        (6, 6, b"world"),
        (3, 3, b"ld"),
        (2, 1, b"!"),
        (1, 0, b""),
    ];
    for &(amt, left, front) in steps.iter() {
        cursor.advance(amt);
        assert_eq!(cursor.remaining(), left);
        assert_eq!(cursor.chunk(), front);
    }

    // chunks_vectored from the start: one iov per chunk.
    let mut cursor = crate::Cursor::new(buf_list.clone());
    let mut iovs = [io::IoSlice::new(&[]); 3];
    let filled = cursor.chunks_vectored(&mut iovs);
    assert_eq!(filled, 3);
    let expected: [&[u8]; 3] = [b"hello ", b"world", b"!"];
    for (iov, want) in iovs.iter().zip(expected.iter()) {
        assert_eq!(&**iov, *want);
    }

    // chunks_vectored after skipping the first chunk: only two iovs are filled.
    cursor.advance(6);
    let mut iovs = [io::IoSlice::new(&[]); 3];
    let filled = cursor.chunks_vectored(&mut iovs);
    assert_eq!(filled, 2);
    let expected: [&[u8]; 2] = [b"world", b"!"];
    for (iov, want) in iovs.iter().zip(expected.iter()) {
        assert_eq!(&**iov, *want);
    }

    // More iovs than chunks: only as many iovs as there are chunks get filled.
    let cursor2 = crate::Cursor::new(&buf_list);
    let mut iovs2 = [io::IoSlice::new(&[]); 10];
    let filled2 = cursor2.chunks_vectored(&mut iovs2);
    assert_eq!(filled2, 3, "Should only fill 3 iovs for 3 chunks");
    let total_bytes = iovs2
        .iter()
        .take(filled2)
        .map(|iov| iov.len())
        .sum::<usize>();
    assert_eq!(total_bytes, 12, "Total bytes should be 12");
}

0 commit comments

Comments
 (0)