Skip to content

Commit 1c14889

Browse files
committed
stacking: Improve .select() with special case for 1D arrays
The 1D case is a simpler gather operation since we only select 1 element per index. Special-case it. The performance increase for this benchmark is that the benchmark runtime changes by -92% with this change.
1 parent 8637a7a commit 1c14889

File tree

4 files changed

+53
-9
lines changed

4 files changed

+53
-9
lines changed

benches/append.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,14 @@ fn select_axis1(bench: &mut Bencher) {
2222
a.select(Axis(1), &selectable)
2323
});
2424
}
25+
26+
#[bench]
27+
fn select_1d(bench: &mut Bencher) {
28+
let a = Array::<f32, _>::zeros(1024);
29+
let mut selectable = (0..a.len()).step_by(17).collect::<Vec<_>>();
30+
selectable.extend(selectable.clone().iter().rev());
31+
32+
bench.iter(|| {
33+
a.select(Axis(0), &selectable)
34+
});
35+
}

src/impl_methods.rs

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -877,16 +877,35 @@ where
877877
S: Data,
878878
D: RemoveAxis,
879879
{
880-
let mut subs = vec![self.view(); indices.len()];
881-
for (&i, sub) in zip(indices, &mut subs[..]) {
882-
sub.collapse_axis(axis, i);
883-
}
884-
if subs.is_empty() {
885-
let mut dim = self.raw_dim();
886-
dim.set_axis(axis, 0);
887-
unsafe { Array::from_shape_vec_unchecked(dim, vec![]) }
880+
if self.ndim() == 1 {
881+
// using .len_of(axis) means that we check if `axis` is in bounds too.
882+
let axis_len = self.len_of(axis);
883+
// bounds check the indices first
884+
if let Some(max_index) = indices.iter().cloned().max() {
885+
if max_index >= axis_len {
886+
panic!("ndarray: index {} is out of bounds in array of len {}",
887+
max_index, self.len_of(axis));
888+
}
889+
} // else: indices empty is ok
890+
let view = self.view().into_dimensionality::<Ix1>().unwrap();
891+
Array::from_iter(indices.iter().map(move |&index| {
892+
// Safety: bounds checked indexes
893+
unsafe {
894+
view.uget(index).clone()
895+
}
896+
})).into_dimensionality::<D>().unwrap()
888897
} else {
889-
concatenate(axis, &subs).unwrap()
898+
let mut subs = vec![self.view(); indices.len()];
899+
for (&i, sub) in zip(indices, &mut subs[..]) {
900+
sub.collapse_axis(axis, i);
901+
}
902+
if subs.is_empty() {
903+
let mut dim = self.raw_dim();
904+
dim.set_axis(axis, 0);
905+
unsafe { Array::from_shape_vec_unchecked(dim, vec![]) }
906+
} else {
907+
concatenate(axis, &subs).unwrap()
908+
}
890909
}
891910
}
892911

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
clippy::unreadable_literal,
1414
clippy::manual_map, // is not an error
1515
clippy::while_let_on_iterator, // is not an error
16+
clippy::from_iter_instead_of_collect, // using from_iter is good style
1617
)]
1718
#![cfg_attr(not(feature = "std"), no_std)]
1819

tests/array.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,19 @@ fn test_select() {
709709
assert_abs_diff_eq!(c, c_target);
710710
}
711711

712+
#[test]
713+
fn test_select_1d() {
714+
let x = arr1(&[0, 1, 2, 3, 4, 5, 6]);
715+
let r1 = x.select(Axis(0), &[1, 3, 4, 2, 2, 5]);
716+
assert_eq!(r1, arr1(&[1, 3, 4, 2, 2, 5]));
717+
// select nothing
718+
let r2 = x.select(Axis(0), &[]);
719+
assert_eq!(r2, arr1(&[]));
720+
// select nothing from empty
721+
let r3 = r2.select(Axis(0), &[]);
722+
assert_eq!(r3, arr1(&[]));
723+
}
724+
712725
#[test]
713726
fn diag() {
714727
let d = arr2(&[[1., 2., 3.0f32]]).into_diag();

0 commit comments

Comments
 (0)