Skip to content

Commit 7d5c3d3

Browse files
committed
FIX: Solve axis iteration order problem by sorting axes
1 parent da10c42 commit 7d5c3d3

File tree

3 files changed

+110
-17
lines changed

3 files changed

+110
-17
lines changed

src/impl_owned_array.rs

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use crate::dimension;
66
use crate::error::{ErrorKind, ShapeError};
77
use crate::OwnedRepr;
88
use crate::Zip;
9-
use crate::NdProducer;
109

1110
/// Methods specific to `Array0`.
1211
///
@@ -251,15 +250,12 @@ impl<A, D> Array<A, D>
251250
/// [1., 1., 1., 1.],
252251
/// [1., 1., 1., 1.]]);
253252
/// ```
254-
pub fn try_append_array(&mut self, axis: Axis, array: ArrayView<A, D>)
253+
pub fn try_append_array(&mut self, axis: Axis, mut array: ArrayView<A, D>)
255254
-> Result<(), ShapeError>
256255
where
257256
A: Clone,
258257
D: RemoveAxis,
259258
{
260-
let self_axis_len = self.len_of(axis);
261-
let array_axis_len = array.len_of(axis);
262-
263259
let remaining_shape = self.raw_dim().remove_axis(axis);
264260
let array_rem_shape = array.raw_dim().remove_axis(axis);
265261

@@ -310,7 +306,7 @@ impl<A, D> Array<A, D>
310306
// make a raw view with the new row
311307
// safe because the data was "full"
312308
let tail_ptr = self.data.as_end_nonnull();
313-
let tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
309+
let mut tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
314310

315311
struct SetLenOnDrop<'a, A: 'a> {
316312
len: usize,
@@ -330,37 +326,86 @@ impl<A, D> Array<A, D>
330326
}
331327
}
332328

333-
// we have a problem here XXX
334-
//
335329
// To be robust for panics and drop the right elements, we want
336330
// to fill the tail in-order, so that we can drop the right elements on
337-
// panic. Don't know how to achieve that.
331+
// panic.
338332
//
339-
// It might be easier to retrace our steps in a scope guard to drop the right
340-
// elements.. (PartialArray style).
333+
// We have: Zip::from(tail_view).and(array)
334+
// Transform tail_view into standard order by inverting and moving its axes.
335+
// Keep the Zip traversal unchanged by applying the same axis transformations to
336+
// `array`. This ensures the Zip traverses the underlying memory in order.
341337
//
342-
// assign the new elements
338+
// XXX It would be possible to skip this transformation if the element
339+
// doesn't have drop. However, in the interest of code coverage, all elements
340+
// use this code initially.
341+
342+
if tail_view.ndim() > 1 {
343+
for i in 0..tail_view.ndim() {
344+
if tail_view.stride_of(Axis(i)) < 0 {
345+
tail_view.invert_axis(Axis(i));
346+
array.invert_axis(Axis(i));
347+
}
348+
}
349+
sort_axes_to_standard_order(&mut tail_view, &mut array);
350+
}
343351
Zip::from(tail_view).and(array)
352+
.debug_assert_c_order()
344353
.for_each(|to, from| {
345354
to.write(from.clone());
346355
length_guard.len += 1;
347356
});
348357

349-
//length_guard.len += len_to_append;
350-
dbg!(len_to_append);
351358
drop(length_guard);
352359

353360
// update array dimension
354361
self.strides = strides;
355362
self.dim = res_dim;
356-
dbg!(&self.dim);
357-
358363
}
359364
// multiple assertions after pointer & dimension update
360365
debug_assert_eq!(self.data.len(), self.len());
361366
debug_assert_eq!(self.len(), new_len);
362-
debug_assert!(self.is_standard_layout());
363367

364368
Ok(())
365369
}
366370
}
371+
372+
fn sort_axes_to_standard_order<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
373+
where
374+
S: RawData,
375+
S2: RawData,
376+
D: Dimension,
377+
{
378+
if a.ndim() <= 1 {
379+
return;
380+
}
381+
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
382+
debug_assert!(a.is_standard_layout());
383+
}
384+
385+
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)
386+
where
387+
D: Dimension,
388+
{
389+
debug_assert!(adim.ndim() > 1);
390+
debug_assert_eq!(adim.ndim(), bdim.ndim());
391+
// bubble sort axes
392+
let mut changed = true;
393+
while changed {
394+
changed = false;
395+
for i in 0..adim.ndim() - 1 {
396+
let axis_i = i;
397+
let next_axis = i + 1;
398+
399+
// make sure higher stride axes sort before.
400+
debug_assert!(astrides.slice()[axis_i] as isize >= 0);
401+
if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize {
402+
changed = true;
403+
adim.slice_mut().swap(axis_i, next_axis);
404+
astrides.slice_mut().swap(axis_i, next_axis);
405+
bdim.slice_mut().swap(axis_i, next_axis);
406+
bstrides.slice_mut().swap(axis_i, next_axis);
407+
}
408+
}
409+
}
410+
}
411+

src/zip/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,13 @@ macro_rules! map_impl {
673673
self.build_and(part)
674674
}
675675

676+
#[allow(unused)]
677+
#[inline]
678+
pub(crate) fn debug_assert_c_order(self) -> Self {
679+
debug_assert!(self.layout.is(CORDER) || self.layout_tendency >= 0);
680+
self
681+
}
682+
676683
fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
677684
where P: NdProducer<Dim=D>,
678685
{

tests/append.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,44 @@ fn append_array1() {
8787
[5., 5., 4., 4.],
8888
[3., 3., 2., 2.]]);
8989
}
90+
91+
#[test]
92+
fn append_array_3d() {
93+
let mut a = Array::zeros((0, 2, 2));
94+
a.try_append_array(Axis(0), array![[[0, 1], [2, 3]]].view()).unwrap();
95+
println!("{:?}", a);
96+
97+
let aa = array![[[51, 52], [53, 54]], [[55, 56], [57, 58]]];
98+
let av = aa.view();
99+
println!("Send {:?} to append", av);
100+
a.try_append_array(Axis(0), av.clone()).unwrap();
101+
102+
a.swap_axes(0, 1);
103+
let aa = array![[[71, 72], [73, 74]], [[75, 76], [77, 78]]];
104+
let mut av = aa.view();
105+
av.swap_axes(0, 1);
106+
println!("Send {:?} to append", av);
107+
a.try_append_array(Axis(1), av.clone()).unwrap();
108+
println!("{:?}", a);
109+
let aa = array![[[81, 82], [83, 84]], [[85, 86], [87, 88]]];
110+
let mut av = aa.view();
111+
av.swap_axes(0, 1);
112+
println!("Send {:?} to append", av);
113+
a.try_append_array(Axis(1), av).unwrap();
114+
println!("{:?}", a);
115+
assert_eq!(a,
116+
array![[[0, 1],
117+
[51, 52],
118+
[55, 56],
119+
[71, 72],
120+
[75, 76],
121+
[81, 82],
122+
[85, 86]],
123+
[[2, 3],
124+
[53, 54],
125+
[57, 58],
126+
[73, 74],
127+
[77, 78],
128+
[83, 84],
129+
[87, 88]]]);
130+
}

0 commit comments

Comments
 (0)