Skip to content

Commit d04d64f

Browse files
committed
Safer sort partition
1 parent e2c96cc commit d04d64f

File tree

3 files changed

+34
-37
lines changed

3 files changed

+34
-37
lines changed

library/core/src/slice/sort/select.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::cfg_select;
1010
use crate::mem::{self, SizedTypeProperties};
1111
#[cfg(not(feature = "optimize_for_size"))]
1212
use crate::slice::sort::shared::pivot::choose_pivot;
13-
use crate::slice::sort::shared::smallsort::insertion_sort_shift_left;
13+
use crate::slice::sort::shared::smallsort::{insertion_sort_shift_left, panic_on_ord_violation};
1414
use crate::slice::sort::unstable::quicksort::partition;
1515

1616
/// Reorders the slice such that the element at `index` is at its final sorted position.
@@ -104,7 +104,9 @@ fn partition_at_index_loop<'a, T, F>(
104104
let pivot = &v[pivot_pos];
105105

106106
if !is_less(p, pivot) {
107-
let num_lt = partition(v, pivot_pos, &mut |a, b| !is_less(b, a));
107+
let Some(num_lt) = partition(v, pivot_pos, &mut |a, b| !is_less(b, a)) else {
108+
panic_on_ord_violation();
109+
};
108110

109111
// Continue sorting elements greater than the pivot. We know that `mid` contains
110112
// the pivot. So we can continue after `mid`.
@@ -122,14 +124,15 @@ fn partition_at_index_loop<'a, T, F>(
122124
}
123125
}
124126

125-
let mid = partition(v, pivot_pos, is_less);
127+
let Some(mid) = partition(v, pivot_pos, is_less) else {
128+
panic_on_ord_violation();
129+
};
126130

127131
// Split the slice into `left`, `pivot`, and `right`.
128132
let (left, right) = v.split_at_mut(mid);
129-
let (pivot, right) = right.split_at_mut(1);
130-
let pivot = &pivot[0];
131133

132134
if mid < index {
135+
let (pivot, right) = right.split_first_mut().unwrap();
133136
v = right;
134137
index = index - mid - 1;
135138
ancestor_pivot = Some(pivot);
@@ -198,7 +201,9 @@ fn median_of_medians<T, F: FnMut(&T, &T) -> bool>(mut v: &mut [T], is_less: &mut
198201
return;
199202
}
200203

201-
let p = median_of_ninthers(v, is_less);
204+
let Some(p) = median_of_ninthers(v, is_less) else {
205+
panic_on_ord_violation();
206+
};
202207

203208
if p == k {
204209
return;
@@ -216,7 +221,7 @@ fn median_of_medians<T, F: FnMut(&T, &T) -> bool>(mut v: &mut [T], is_less: &mut
216221
// Optimized for when `k` lies somewhere in the middle of the slice. Selects a pivot
217222
// as close as possible to the median of the slice. For more details on how the algorithm
218223
// operates, refer to the paper <https://drops.dagstuhl.de/opus/volltexte/2017/7612/pdf/LIPIcs-SEA-2017-24.pdf>.
219-
fn median_of_ninthers<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], is_less: &mut F) -> usize {
224+
fn median_of_ninthers<T, F: FnMut(&T, &T) -> bool>(v: &mut [T], is_less: &mut F) -> Option<usize> {
220225
// use `saturating_mul` so the multiplication doesn't overflow on 16-bit platforms.
221226
let frac = if v.len() <= 1024 {
222227
v.len() / 12

library/core/src/slice/sort/shared/smallsort.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ unsafe fn bidirectional_merge<T: FreezeMarker, F: FnMut(&T, &T) -> bool>(
842842

843843
#[cfg_attr(not(panic = "immediate-abort"), inline(never), cold)]
844844
#[cfg_attr(panic = "immediate-abort", inline)]
845-
fn panic_on_ord_violation() -> ! {
845+
pub(crate) fn panic_on_ord_violation() -> ! {
846846
// This is indicative of a logic bug in the user-provided comparison function or Ord
847847
// implementation. They are expected to implement a total order as explained in the Ord
848848
// documentation.

library/core/src/slice/sort/unstable/quicksort.rs

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ use crate::mem::ManuallyDrop;
66
#[cfg(not(feature = "optimize_for_size"))]
77
use crate::slice::sort::shared::pivot::choose_pivot;
88
#[cfg(not(feature = "optimize_for_size"))]
9-
use crate::slice::sort::shared::smallsort::UnstableSmallSortTypeImpl;
9+
use crate::slice::sort::shared::smallsort::{UnstableSmallSortTypeImpl, panic_on_ord_violation};
1010
#[cfg(not(feature = "optimize_for_size"))]
1111
use crate::slice::sort::unstable::heapsort;
12-
use crate::{cfg_select, intrinsics, ptr};
12+
use crate::{cfg_select, ptr};
1313

1414
/// Sorts `v` recursively.
1515
///
@@ -53,16 +53,18 @@ pub(crate) fn quicksort<'a, T, F>(
5353

5454
// Continue sorting elements greater than the pivot. We know that `num_lt` contains
5555
// the pivot. So we can continue after `num_lt`.
56-
v = &mut v[(num_lt + 1)..];
56+
v = num_lt
57+
.and_then(|num_lt| v.get_mut(num_lt + 1..))
58+
.unwrap_or_else(|| panic_on_ord_violation());
5759
ancestor_pivot = None;
5860
continue;
5961
}
6062
}
6163

6264
// Partition the slice.
63-
let num_lt = partition(v, pivot_pos, is_less);
64-
// SAFETY: partition ensures that `num_lt` will be in-bounds.
65-
unsafe { intrinsics::assume(num_lt < v.len()) };
65+
let Some(num_lt) = partition(v, pivot_pos, is_less) else {
66+
panic_on_ord_violation();
67+
};
6668

6769
// Split the slice into `left`, `pivot`, and `right`.
6870
let (left, right) = v.split_at_mut(num_lt);
@@ -84,56 +86,46 @@ pub(crate) fn quicksort<'a, T, F>(
8486
/// on the left side of `v` followed by the other elements, notionally considered greater or
8587
/// equal to `pivot`.
8688
///
87-
/// Returns the number of elements that are compared true for `is_less(elem, pivot)`.
89+
/// Returns the number of elements that are compared true for `is_less(elem, pivot)`
90+
/// if `is_less` implements a total order.
8891
///
8992
/// If `is_less` does not implement a total order the resulting order and return value are
90-
/// unspecified. All original elements will remain in `v` and any possible modifications via
93+
/// unspecified, except that if `Some` is returned, the value will be in bounds of `v`.
94+
/// All original elements will remain in `v` and any possible modifications via
9195
/// interior mutability will be observable. Same is true if `is_less` panics or `v.len()`
9296
/// exceeds `scratch.len()`.
93-
pub(crate) fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> usize
97+
pub(crate) fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> Option<usize>
9498
where
9599
F: FnMut(&T, &T) -> bool,
96100
{
97101
let len = v.len();
98102

99103
// Allows for panic-free code-gen by proving this property to the compiler.
100-
if len == 0 {
101-
return 0;
104+
if len == 0 || pivot >= len {
105+
return None;
102106
}
103107

104-
if pivot >= len {
105-
intrinsics::abort();
106-
}
107-
108-
// SAFETY: We checked that `pivot` is in-bounds.
109-
unsafe {
110-
// Place the pivot at the beginning of slice.
111-
v.swap_unchecked(0, pivot);
112-
}
113-
let (pivot, v_without_pivot) = v.split_at_mut(1);
108+
v.swap(0, pivot);
109+
let (pivot, v_without_pivot) = v.split_first_mut()?;
114110

115111
// Assuming that Rust generates noalias LLVM IR we can be sure that a partition function
116112
// signature of the form `(v: &mut [T], pivot: &T)` guarantees that pivot and v can't alias.
117113
// Having this guarantee is crucial for optimizations. It's possible to copy the pivot value
118114
// into a stack value, but this creates issues for types with interior mutability mandating
119115
// a drop guard.
120-
let pivot = &mut pivot[0];
121116

122117
// This construct is used to limit the LLVM IR generated, which saves large amounts of
123118
// compile-time by only instantiating the code that is needed. Idea by Frank Steffahn.
124119
let num_lt = (const { inst_partition::<T, F>() })(v_without_pivot, pivot, is_less);
125120

126121
if num_lt >= len {
127-
intrinsics::abort();
122+
return None;
128123
}
129124

130-
// SAFETY: We checked that `num_lt` is in-bounds.
131-
unsafe {
132-
// Place the pivot between the two partitions.
133-
v.swap_unchecked(0, num_lt);
134-
}
125+
// Place the pivot between the two partitions.
126+
v.swap(0, num_lt);
135127

136-
num_lt
128+
Some(num_lt)
137129
}
138130

139131
const fn inst_partition<T, F: FnMut(&T, &T) -> bool>() -> fn(&mut [T], &T, &mut F) -> usize {

0 commit comments

Comments
 (0)