Skip to content

Commit 6b8df0c

Browse files
authored
Fix: incorrect FromIterator<bool> implementation (#5056)
It incorrectly assumed that the upper bound for iterators is true. Also adds a `set_len` method. Signed-off-by: Connor Tsui <[email protected]>
1 parent 469b801 commit 6b8df0c

File tree

1 file changed

+94
-16
lines changed

1 file changed

+94
-16
lines changed

vortex-buffer/src/bit/buf_mut.rs

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,20 @@ impl BitBufferMut {
221221
unsafe { unset_bit_unchecked(self.buffer.as_mut_ptr(), self.offset + index) }
222222
}
223223

224+
/// Foces the length of the `BitBufferMut` to `new_len`.
225+
///
226+
/// # Safety
227+
///
228+
/// - `new_len` must be less than or equal to [`capacity()`](Self::capacity)
229+
/// - The elements at `old_len..new_len` must be initialized
230+
pub unsafe fn set_len(&mut self, new_len: usize) {
231+
debug_assert!(
232+
new_len <= self.capacity(),
233+
"`set_len` requires that new_len <= capacity()"
234+
);
235+
self.len = new_len;
236+
}
237+
224238
/// Truncate the buffer to the given length.
225239
pub fn truncate(&mut self, len: usize) {
226240
if len > self.len {
@@ -483,24 +497,35 @@ impl From<Vec<bool>> for BitBufferMut {
483497

484498
impl FromIterator<bool> for BitBufferMut {
485499
fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
486-
let iter = iter.into_iter();
487-
let (low, high) = iter.size_hint();
488-
if let Some(len) = high {
489-
let mut buf = BitBufferMut::new_unset(len);
490-
for (i, v) in iter.enumerate() {
491-
if v {
492-
// SAFETY: i is in bounds
493-
unsafe { buf.set_unchecked(i) }
494-
}
495-
}
496-
buf
497-
} else {
498-
let mut buf = BitBufferMut::with_capacity(low);
499-
for v in iter {
500-
buf.append(v);
500+
let mut iter = iter.into_iter();
501+
502+
// Note that these hints might be incorrect.
503+
let (lower_bound, upper_bound_opt) = iter.size_hint();
504+
let capacity = upper_bound_opt.unwrap_or(lower_bound);
505+
506+
let mut buf = BitBufferMut::new_unset(capacity);
507+
508+
// Directly write within our known capacity.
509+
for i in 0..capacity {
510+
let Some(v) = iter.next() else {
511+
// SAFETY: We are definitely under the capacity and all values are already
512+
// initialized from `new_unset`.
513+
unsafe { buf.set_len(i) };
514+
return buf;
515+
};
516+
517+
if v {
518+
// SAFETY: We have ensured that we are within the capacity.
519+
unsafe { buf.set_unchecked(i) }
501520
}
502-
buf
503521
}
522+
523+
// Append the remaining items (as we do not know how many more there are).
524+
for v in iter {
525+
buf.append(v);
526+
}
527+
528+
buf
504529
}
505530
}
506531

@@ -918,4 +943,57 @@ mod tests {
918943
assert_eq!(frozen.offset(), 3);
919944
assert_eq!(frozen.len(), 6);
920945
}
946+
947+
#[test]
948+
fn test_from_iterator_with_incorrect_size_hint() {
949+
// This test catches a bug where FromIterator assumed the upper bound
950+
// from size_hint was accurate. The iterator contract allows the actual
951+
// count to exceed the upper bound, which could cause UB if we used
952+
// append_unchecked beyond the allocated capacity.
953+
954+
// Custom iterator that lies about its size hint.
955+
struct LyingIterator {
956+
values: Vec<bool>,
957+
index: usize,
958+
}
959+
960+
impl Iterator for LyingIterator {
961+
type Item = bool;
962+
963+
fn next(&mut self) -> Option<Self::Item> {
964+
(self.index < self.values.len()).then(|| {
965+
let val = self.values[self.index];
966+
self.index += 1;
967+
val
968+
})
969+
}
970+
971+
fn size_hint(&self) -> (usize, Option<usize>) {
972+
// Deliberately return an incorrect upper bound that's smaller
973+
// than the actual number of elements we'll yield.
974+
let remaining = self.values.len() - self.index;
975+
let lower = remaining.min(5); // Correct lower bound (but capped).
976+
let upper = Some(5); // Incorrect upper bound - we actually have more!
977+
(lower, upper)
978+
}
979+
}
980+
981+
// Create an iterator that claims to have at most 5 elements but actually has 10.
982+
let lying_iter = LyingIterator {
983+
values: vec![
984+
true, false, true, false, true, false, true, false, true, false,
985+
],
986+
index: 0,
987+
};
988+
989+
// Collect the iterator. This would cause UB in the old implementation
990+
// if it trusted the upper bound and used append_unchecked beyond capacity.
991+
let bit_buf: BitBufferMut = lying_iter.collect();
992+
993+
// Verify all 10 elements were collected correctly.
994+
assert_eq!(bit_buf.len(), 10);
995+
for i in 0..10 {
996+
assert_eq!(bit_buf.value(i), i % 2 == 0);
997+
}
998+
}
921999
}

0 commit comments

Comments
 (0)