Skip to content

Commit 1157763

Browse files
committed
Improve fold for 2-D arrays
1 parent d4d8088 commit 1157763

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

src/impl_methods.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,13 +1854,25 @@ where
18541854
} else {
18551855
let mut v = self.view();
18561856
// put the narrowest axis at the last position
1857-
if v.ndim() > 1 {
1858-
let last = v.ndim() - 1;
1859-
let narrow_axis = v.axes()
1860-
.filter(|ax| ax.len() > 1)
1861-
.min_by_key(|ax| ax.stride().abs())
1862-
.map_or(last, |ax| ax.axis().index());
1863-
v.swap_axes(last, narrow_axis);
1857+
match v.ndim() {
1858+
0 | 1 => {}
1859+
2 => {
1860+
if self.len_of(Axis(1)) <= 1
1861+
|| self.len_of(Axis(0)) > 1
1862+
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
1863+
{
1864+
v.swap_axes(0, 1);
1865+
}
1866+
}
1867+
n => {
1868+
let last = n - 1;
1869+
let narrow_axis = v
1870+
.axes()
1871+
.filter(|ax| ax.len() > 1)
1872+
.min_by_key(|ax| ax.stride().abs())
1873+
.map_or(last, |ax| ax.axis().index());
1874+
v.swap_axes(last, narrow_axis);
1875+
}
18641876
}
18651877
v.into_elements_base().fold(init, f)
18661878
}

0 commit comments

Comments
 (0)