Skip to content

Commit abfdf4d

Browse files
jturner314 authored and bluss committed
Improve layout heuristic for sum_axis
1 parent acc3bfd commit abfdf4d

File tree

1 file changed

+7
-14
lines changed

1 file changed

+7
-14
lines changed

src/numeric/impl_numeric.rs

Lines changed: 7 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@ use num_traits::{self, FromPrimitive, Zero};
1212
use std::ops::{Add, Div, Mul};
1313

1414
use crate::imp_prelude::*;
15-
use crate::itertools::enumerate;
1615
use crate::numeric_util;
1716

1817
/// # Numerical Methods for Arrays
@@ -246,22 +245,16 @@ where
246245
A: Clone + Zero + Add<Output = A>,
247246
D: RemoveAxis,
248247
{
249-
let n = self.len_of(axis);
250-
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
251-
let stride = self.strides()[axis.index()];
252-
if self.ndim() == 2 && stride == 1 {
253-
// contiguous along the axis we are summing
254-
let ax = axis.index();
255-
for (i, elt) in enumerate(&mut res) {
256-
*elt = self.index_axis(Axis(1 - ax), i).sum();
257-
}
248+
let min_stride_axis = self.dim.min_stride_axis(&self.strides);
249+
if axis == min_stride_axis {
250+
crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.sum())
258251
} else {
259-
for i in 0..n {
260-
let view = self.index_axis(axis, i);
261-
res = res + &view;
252+
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
253+
for subview in self.axis_iter(axis) {
254+
res = res + &subview;
262255
}
256+
res
263257
}
264-
res
265258
}
266259

267260
/// Return mean along `axis`.

0 commit comments

Comments
 (0)