Skip to content

Commit bfdc6d3

Browse files
authored
Chore: Clean up extend impls for BufferMut (#5055)
_I'm moving this from #5049_ since things are becoming more complicated.

The main logical change here is at the top of `extend_iter`:

```rust
// Since we do not know the length of the iterator, we can only guess how much memory we
// need to reserve. Note that these hints may be inaccurate.
let (lower_bound, upper_bound_opt) = iter.size_hint();
self.reserve(upper_bound_opt.unwrap_or(lower_bound));
```

The other change is that we now use [`offset_from_unsigned`](rust-lang/rust#95892) (called `sub_ptr` while it was nightly-only) instead of `byte_offset_from`.

---------

Signed-off-by: Connor Tsui <[email protected]>
1 parent f5c1c70 commit bfdc6d3

File tree

1 file changed

+86
-33
lines changed

1 file changed

+86
-33
lines changed

vortex-buffer/src/buffer_mut.rs

Lines changed: 86 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -477,59 +477,113 @@ impl<T> AsMut<[T]> for BufferMut<T> {
477477
}
478478

479479
impl<T> BufferMut<T> {
480+
/// A helper method for the two [`Extend`] implementations.
481+
///
482+
/// We reserve space based on the iterator's size hint and write items directly into the spare
483+
/// capacity, and then we continue to push any remaining items normally once it is exhausted.
480484
fn extend_iter(&mut self, mut iter: impl Iterator<Item = T>) {
481-
// Attempt to reserve enough memory up-front, although this is only a lower bound.
482-
let (lower, _) = iter.size_hint();
483-
self.reserve(lower);
485+
// Since we do not know the length of the iterator, we can only guess how much memory we
486+
// need to reserve. Note that these hints may be inaccurate.
487+
let (lower_bound, upper_bound_opt) = iter.size_hint();
488+
489+
// In the case that the upper bound is adversarial, we put a hard limit on the amount of
490+
// memory we reserve (and the OS should handle the rest with zero pages).
491+
let reserve_amount = upper_bound_opt
492+
.unwrap_or(lower_bound)
493+
.min(i32::MAX as usize);
494+
self.reserve(reserve_amount);
484495

485-
let remaining = self.capacity() - self.len();
496+
let unwritten = self.capacity() - self.len();
486497

498+
// We keep `begin` so we can count the items actually written, as the size hint may be wrong.
487499
let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
488500
let mut dst: *mut T = begin.cast_mut();
489-
for _ in 0..remaining {
490-
if let Some(item) = iter.next() {
491-
unsafe {
492-
// SAFETY: We know we have enough capacity to write the item.
493-
dst.write(item);
494-
// Note. we used to have dst.add(iteration).write(item), here.
495-
// however this was much slower than just incrementing dst.
496-
dst = dst.add(1);
497-
}
498-
} else {
501+
502+
// As a first step, we manually iterate the iterator up to the known capacity.
503+
for _ in 0..unwritten {
504+
let Some(item) = iter.next() else {
505+
// The iterator may run out before the spare capacity is filled (hints may be inaccurate).
499506
break;
500-
}
507+
};
508+
509+
// SAFETY: We have reserved enough capacity to hold this item, and `dst` is a pointer
510+
// derived from a valid reference to byte data.
511+
unsafe { dst.write(item) };
512+
513+
// Note: We used to have `dst.add(iteration).write(item)`, here. However this was much
514+
// slower than just incrementing `dst`.
515+
// SAFETY: The offset fits in `isize`, and because we were able to reserve the memory
516+
// we know that `add` will not overflow.
517+
unsafe { dst = dst.add(1) };
501518
}
502519

503-
// TODO(joe): replace with ptr_sub when stable
504-
let length = self.len() + unsafe { dst.byte_offset_from(begin) as usize / size_of::<T>() };
520+
// SAFETY: `dst` was derived from `begin`, so both point into the same spare capacity, and
521+
// since the only operation applied to `dst` is `add`, we know that `dst >= begin`.
522+
let items_written = unsafe { dst.offset_from_unsigned(begin) };
523+
let length = self.len() + items_written;
524+
525+
// SAFETY: We have written valid items between the old length and the new length.
505526
unsafe { self.set_len(length) };
506527

507-
// Append remaining elements
528+
// Finally, since the iterator may still have arbitrarily many items to yield, we push the
529+
// remaining items normally.
508530
iter.for_each(|item| self.push(item));
509531
}
510532

511-
/// An unsafe variant of the `Extend` trait and its `extend` method that receives what the
512-
/// caller guarantees to be an iterator with a trusted upper bound.
533+
/// Extends the `BufferMut` with the items of a `TrustedLen` iterator.
534+
///
535+
/// The caller guarantees that the iterator will have a trusted upper bound, which allows the
536+
/// implementation to reserve all of the memory needed up front.
513537
pub fn extend_trusted<I: TrustedLen<Item = T>>(&mut self, iter: I) {
514-
// Reserve all memory upfront since it's an exact upper bound
515-
let (_, high) = iter.size_hint();
516-
self.reserve(high.vortex_expect("TrustedLen iterator didn't have valid upper bound"));
538+
// Since we know the exact upper bound (from `TrustedLen`), we can reserve all of the memory
539+
// for this operation up front.
540+
let (_, upper_bound) = iter.size_hint();
541+
self.reserve(
542+
upper_bound
543+
.vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
544+
);
517545

546+
// We keep `begin` so we can compute how many items were written once the iterator is drained.
518547
let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast();
519548
let mut dst: *mut T = begin.cast_mut();
549+
520550
iter.for_each(|item| {
521-
unsafe {
522-
// SAFETY: We know we have enough capacity to write the item.
523-
dst.write(item);
524-
// Note. we used to have dst.add(iteration).write(item), here.
525-
// however this was much slower than just incrementing dst.
526-
dst = dst.add(1);
527-
}
551+
// SAFETY: We have reserved enough capacity to hold this item, and `dst` is a pointer
552+
// derived from a valid reference to byte data.
553+
unsafe { dst.write(item) };
554+
555+
// Note: We used to have `dst.add(iteration).write(item)`, here. However this was much
556+
// slower than just incrementing `dst`.
557+
// SAFETY: The offset fits in `isize`, and because we were able to reserve the memory
558+
// we know that `add` will not overflow.
559+
unsafe { dst = dst.add(1) };
528560
});
529-
// TODO(joe): replace with ptr_sub when stable
530-
let length = self.len() + unsafe { dst.byte_offset_from(begin) as usize / size_of::<T>() };
561+
562+
// SAFETY: `dst` was derived from `begin`, so both point into the same spare capacity, and
563+
// since the only operation applied to `dst` is `add`, we know that `dst >= begin`.
564+
let items_written = unsafe { dst.offset_from_unsigned(begin) };
565+
let length = self.len() + items_written;
566+
567+
// SAFETY: We have written valid items between the old length and the new length.
531568
unsafe { self.set_len(length) };
532569
}
570+
571+
/// Creates a `BufferMut` from an iterator with a trusted length.
572+
///
573+
/// Internally, this calls [`extend_trusted()`](Self::extend_trusted).
574+
pub fn from_trusted_len_iter<I>(iter: I) -> Self
575+
where
576+
I: TrustedLen<Item = T>,
577+
{
578+
let (_, upper_bound) = iter.size_hint();
579+
let mut buffer = Self::with_capacity(
580+
upper_bound
581+
.vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"),
582+
);
583+
584+
buffer.extend_trusted(iter);
585+
buffer
586+
}
533587
}
534588

535589
impl<T> Extend<T> for BufferMut<T> {
@@ -554,7 +608,6 @@ impl<T> FromIterator<T> for BufferMut<T> {
554608
// We don't infer the capacity here and just let the first call to `extend` do it for us.
555609
let mut buffer = Self::with_capacity(0);
556610
buffer.extend(iter);
557-
debug_assert_eq!(buffer.alignment(), Alignment::of::<T>());
558611
buffer
559612
}
560613
}

0 commit comments

Comments
 (0)