Skip to content

Commit 6a65ef7

Browse files
committed
FEAT: Implement order optimizations for Baseiter
Implement axis merging - this preserves order of elements in the iteration but might simplify iteration. For example, in a contiguous matrix, a shape like [3, 4] can be merged into [1, 12]. Also allow arbitrary order optimization - we then try to iterate in memory order by sorting all axes, currently.
1 parent fe5aff9 commit 6a65ef7

File tree

1 file changed

+217
-7
lines changed

1 file changed

+217
-7
lines changed

src/iterators/mod.rs

Lines changed: 217 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,57 @@ mod windows;
1616
use std::iter::FromIterator;
1717
use std::marker::PhantomData;
1818
use std::ptr;
19+
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
1920
use alloc::vec::Vec;
2021

22+
use crate::imp_prelude::*;
2123
use crate::Ix1;
2224

23-
use super::{ArrayBase, ArrayView, ArrayViewMut, Axis, Data, NdProducer, RemoveAxis};
24-
use super::{Dimension, Ix, Ixs};
25+
use super::{NdProducer, RemoveAxis};
2526

2627
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
2728
pub use self::lanes::{Lanes, LanesMut};
2829
pub use self::windows::Windows;
2930

30-
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
31+
use crate::dimension;
32+
33+
/// No traversal optmizations that would change element order or axis dimensions are permitted.
34+
///
35+
/// This option is suitable for example for the indexed iterator.
36+
pub(crate) enum NoOptimization { }
37+
38+
/// Preserve element iteration order, but modify dimensions if profitable; for example we can
39+
/// change from shape [10, 1] to [1, 10], because that axis has len == 1, without consequence here.
40+
///
41+
/// This option is suitable for example for the default .iter() iterator.
42+
pub(crate) enum PreserveOrder { }
43+
44+
/// Allow use of arbitrary element iteration order
45+
///
46+
/// This option is suitable for example for an arbitrary order iterator.
47+
pub(crate) enum ArbitraryOrder { }
48+
49+
pub(crate) trait OrderOption {
50+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = false;
51+
const ALLOW_ARBITRARY_ORDER: bool = false;
52+
}
53+
54+
impl OrderOption for NoOptimization { }
55+
56+
impl OrderOption for PreserveOrder {
57+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = true;
58+
}
59+
60+
impl OrderOption for ArbitraryOrder {
61+
const ALLOW_REMOVE_REDUNDANT_AXES: bool = true;
62+
const ALLOW_ARBITRARY_ORDER: bool = true;
63+
}
3164

3265
/// Base for iterators over all axes.
3366
///
3467
/// Iterator element type is `*mut A`.
68+
///
69+
/// `F` is for layout/iteration order flags
3570
pub(crate) struct Baseiter<A, D> {
3671
ptr: *mut A,
3772
dim: D,
@@ -44,12 +79,43 @@ impl<A, D: Dimension> Baseiter<A, D> {
4479
/// to be correct to avoid performing an unsafe pointer offset while
4580
/// iterating.
4681
#[inline]
47-
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D> {
82+
pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter<A, D> {
83+
Self::new_with_order::<NoOptimization>(ptr, dim, strides)
84+
}
85+
}
86+
87+
impl<A, D: Dimension> Baseiter<A, D> {
88+
/// Creating a Baseiter is unsafe because shape and stride parameters need
89+
/// to be correct to avoid performing an unsafe pointer offset while
90+
/// iterating.
91+
#[inline]
92+
pub unsafe fn new_with_order<Flags: OrderOption>(mut ptr: *mut A, mut dim: D, mut strides: D)
93+
-> Baseiter<A, D>
94+
{
95+
debug_assert_eq!(dim.ndim(), strides.ndim());
96+
if Flags::ALLOW_ARBITRARY_ORDER {
97+
// iterate in memory order; merge axes if possible
98+
// make all axes positive and put the pointer back to the first element in memory
99+
let offset = dimension::offset_from_ptr_to_memory(&dim, &strides);
100+
ptr = ptr.offset(offset);
101+
for i in 0..strides.ndim() {
102+
let s = strides.get_stride(Axis(i));
103+
if s < 0 {
104+
strides.set_stride(Axis(i), -s);
105+
}
106+
}
107+
dimension::sort_axes_to_standard(&mut dim, &mut strides);
108+
}
109+
if Flags::ALLOW_REMOVE_REDUNDANT_AXES {
110+
// preserve element order but shift dimensions
111+
dimension::merge_axes_from_the_back(&mut dim, &mut strides);
112+
dimension::squeeze(&mut dim, &mut strides);
113+
}
48114
Baseiter {
49115
ptr,
50-
index: len.first_index(),
51-
dim: len,
52-
strides: stride,
116+
index: dim.first_index(),
117+
dim,
118+
strides,
53119
}
54120
}
55121
}
@@ -1496,3 +1562,147 @@ where
14961562
debug_assert_eq!(size, result.len());
14971563
result
14981564
}
1565+
1566+
#[cfg(test)]
1567+
#[cfg(feature = "std")]
1568+
mod tests {
1569+
use crate::prelude::*;
1570+
use super::Baseiter;
1571+
use super::{ArbitraryOrder, PreserveOrder, NoOptimization};
1572+
use itertools::assert_equal;
1573+
use itertools::Itertools;
1574+
1575+
// 3-d axis swaps
1576+
fn swaps() -> impl Iterator<Item=Vec<(usize, usize)>> {
1577+
vec![
1578+
vec![],
1579+
vec![(0, 1)],
1580+
vec![(0, 2)],
1581+
vec![(1, 2)],
1582+
vec![(0, 1), (1, 2)],
1583+
vec![(0, 1), (0, 2)],
1584+
].into_iter()
1585+
}
1586+
1587+
// 3-d axis inverts
1588+
fn inverts() -> impl Iterator<Item=Vec<Axis>> {
1589+
vec![
1590+
vec![],
1591+
vec![Axis(0)],
1592+
vec![Axis(1)],
1593+
vec![Axis(2)],
1594+
vec![Axis(0), Axis(1)],
1595+
vec![Axis(0), Axis(2)],
1596+
vec![Axis(1), Axis(2)],
1597+
vec![Axis(0), Axis(1), Axis(2)],
1598+
].into_iter()
1599+
}
1600+
1601+
#[test]
1602+
fn test_arbitrary_order() {
1603+
for swap in swaps() {
1604+
for invert in inverts() {
1605+
for &slice in &[false, true] {
1606+
// pattern is 0, 1; 4, 5; 8, 9; etc..
1607+
let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap();
1608+
if slice {
1609+
a.slice_collapse(s![.., ..;2, ..]);
1610+
}
1611+
for &(i, j) in &swap {
1612+
a.swap_axes(i, j);
1613+
}
1614+
for &i in &invert {
1615+
a.invert_axis(i);
1616+
}
1617+
unsafe {
1618+
// Should have in-memory order for arbitrary order
1619+
let iter = Baseiter::new_with_order::<ArbitraryOrder>(a.as_mut_ptr(),
1620+
a.dim, a.strides);
1621+
if !slice {
1622+
assert_equal(iter.map(|ptr| *ptr), 0..a.len());
1623+
} else {
1624+
assert_eq!(iter.map(|ptr| *ptr).collect_vec(),
1625+
(0..a.len() * 2).filter(|&x| (x / 2) % 2 == 0).collect_vec());
1626+
}
1627+
}
1628+
}
1629+
}
1630+
}
1631+
}
1632+
1633+
#[test]
1634+
fn test_logical_order() {
1635+
for swap in swaps() {
1636+
for invert in inverts() {
1637+
for &slice in &[false, true] {
1638+
let mut a = Array::from_iter(0..24).into_shape((3, 4, 2)).unwrap();
1639+
for &(i, j) in &swap {
1640+
a.swap_axes(i, j);
1641+
}
1642+
for &i in &invert {
1643+
a.invert_axis(i);
1644+
}
1645+
if slice {
1646+
a.slice_collapse(s![.., ..;2, ..]);
1647+
}
1648+
1649+
unsafe {
1650+
let mut iter = Baseiter::new_with_order::<NoOptimization>(a.as_mut_ptr(),
1651+
a.dim, a.strides);
1652+
let mut index = Dim([0, 0, 0]);
1653+
let mut elts = 0;
1654+
while let Some(elt) = iter.next() {
1655+
assert_eq!(*elt, a[index]);
1656+
if let Some(index_) = a.raw_dim().next_for(index) {
1657+
index = index_;
1658+
}
1659+
elts += 1;
1660+
}
1661+
assert_eq!(elts, a.len());
1662+
}
1663+
}
1664+
}
1665+
}
1666+
}
1667+
1668+
#[test]
1669+
fn test_preserve_order() {
1670+
for swap in swaps() {
1671+
for invert in inverts() {
1672+
for &slice in &[false, true] {
1673+
let mut a = Array::from_iter(0..20).into_shape((2, 10, 1)).unwrap();
1674+
for &(i, j) in &swap {
1675+
a.swap_axes(i, j);
1676+
}
1677+
for &i in &invert {
1678+
a.invert_axis(i);
1679+
}
1680+
if slice {
1681+
a.slice_collapse(s![.., ..;2, ..]);
1682+
}
1683+
1684+
unsafe {
1685+
let mut iter = Baseiter::new_with_order::<PreserveOrder>(
1686+
a.as_mut_ptr(), a.dim, a.strides);
1687+
1688+
// check that axes have been merged (when it's easy to check)
1689+
if a.shape() == &[2, 10, 1] && invert.is_empty() {
1690+
assert_eq!(iter.dim, Dim([1, 1, 20]));
1691+
}
1692+
1693+
let mut index = Dim([0, 0, 0]);
1694+
let mut elts = 0;
1695+
while let Some(elt) = iter.next() {
1696+
assert_eq!(*elt, a[index]);
1697+
if let Some(index_) = a.raw_dim().next_for(index) {
1698+
index = index_;
1699+
}
1700+
elts += 1;
1701+
}
1702+
assert_eq!(elts, a.len());
1703+
}
1704+
}
1705+
}
1706+
}
1707+
}
1708+
}

0 commit comments

Comments
 (0)