Skip to content

Commit baa1018

Browse files
committed
FIX: Solve axis iteration order problem by sorting axes
1 parent a6b7033 commit baa1018

File tree

3 files changed

+110
-17
lines changed

3 files changed

+110
-17
lines changed

src/impl_owned_array.rs

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use crate::dimension;
66
use crate::error::{ErrorKind, ShapeError};
77
use crate::OwnedRepr;
88
use crate::Zip;
9-
use crate::NdProducer;
109

1110
/// Methods specific to `Array0`.
1211
///
@@ -251,15 +250,12 @@ impl<A, D> Array<A, D>
251250
/// [1., 1., 1., 1.],
252251
/// [1., 1., 1., 1.]]);
253252
/// ```
254-
pub fn try_append_array(&mut self, axis: Axis, array: ArrayView<A, D>)
253+
pub fn try_append_array(&mut self, axis: Axis, mut array: ArrayView<A, D>)
255254
-> Result<(), ShapeError>
256255
where
257256
A: Clone,
258257
D: RemoveAxis,
259258
{
260-
let self_axis_len = self.len_of(axis);
261-
let array_axis_len = array.len_of(axis);
262-
263259
let remaining_shape = self.raw_dim().remove_axis(axis);
264260
let array_rem_shape = array.raw_dim().remove_axis(axis);
265261

@@ -311,7 +307,7 @@ impl<A, D> Array<A, D>
311307
// make a raw view with the new row
312308
// safe because the data was "full"
313309
let tail_ptr = self.data.as_end_nonnull();
314-
let tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
310+
let mut tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
315311

316312
struct SetLenOnDrop<'a, A: 'a> {
317313
len: usize,
@@ -331,37 +327,86 @@ impl<A, D> Array<A, D>
331327
}
332328
}
333329

334-
// we have a problem here XXX
335-
//
336330
// To be robust for panics and drop the right elements, we want
337331
// to fill the tail in-order, so that we can drop the right elements on
338-
// panic. Don't know how to achieve that.
332+
// panic.
339333
//
340-
// It might be easier to retrace our steps in a scope guard to drop the right
341-
// elements.. (PartialArray style).
334+
// We have: Zip::from(tail_view).and(array)
335+
// Transform tail_view into standard order by inverting and moving its axes.
336+
// Keep the Zip traversal unchanged by applying the same axis transformations to
337+
// `array`. This ensures the Zip traverses the underlying memory in order.
342338
//
343-
// assign the new elements
339+
// XXX It would be possible to skip this transformation if the element
340+
// doesn't have drop. However, in the interest of code coverage, all elements
341+
// use this code initially.
342+
343+
if tail_view.ndim() > 1 {
344+
for i in 0..tail_view.ndim() {
345+
if tail_view.stride_of(Axis(i)) < 0 {
346+
tail_view.invert_axis(Axis(i));
347+
array.invert_axis(Axis(i));
348+
}
349+
}
350+
sort_axes_to_standard_order(&mut tail_view, &mut array);
351+
}
344352
Zip::from(tail_view).and(array)
353+
.debug_assert_c_order()
345354
.for_each(|to, from| {
346355
to.write(from.clone());
347356
length_guard.len += 1;
348357
});
349358

350-
//length_guard.len += len_to_append;
351-
dbg!(len_to_append);
352359
drop(length_guard);
353360

354361
// update array dimension
355362
self.strides = strides;
356363
self.dim = res_dim;
357-
dbg!(&self.dim);
358-
359364
}
360365
// multiple assertions after pointer & dimension update
361366
debug_assert_eq!(self.data.len(), self.len());
362367
debug_assert_eq!(self.len(), new_len);
363-
debug_assert!(self.is_standard_layout());
364368

365369
Ok(())
366370
}
367371
}
372+
373+
fn sort_axes_to_standard_order<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
374+
where
375+
S: RawData,
376+
S2: RawData,
377+
D: Dimension,
378+
{
379+
if a.ndim() <= 1 {
380+
return;
381+
}
382+
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
383+
debug_assert!(a.is_standard_layout());
384+
}
385+
386+
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)
387+
where
388+
D: Dimension,
389+
{
390+
debug_assert!(adim.ndim() > 1);
391+
debug_assert_eq!(adim.ndim(), bdim.ndim());
392+
// bubble sort axes
393+
let mut changed = true;
394+
while changed {
395+
changed = false;
396+
for i in 0..adim.ndim() - 1 {
397+
let axis_i = i;
398+
let next_axis = i + 1;
399+
400+
// make sure higher stride axes sort before.
401+
debug_assert!(astrides.slice()[axis_i] as isize >= 0);
402+
if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize {
403+
changed = true;
404+
adim.slice_mut().swap(axis_i, next_axis);
405+
astrides.slice_mut().swap(axis_i, next_axis);
406+
bdim.slice_mut().swap(axis_i, next_axis);
407+
bstrides.slice_mut().swap(axis_i, next_axis);
408+
}
409+
}
410+
}
411+
}
412+

src/zip/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,13 @@ macro_rules! map_impl {
673673
self.build_and(part)
674674
}
675675

676+
#[allow(unused)]
677+
#[inline]
678+
pub(crate) fn debug_assert_c_order(self) -> Self {
679+
debug_assert!(self.layout.is(CORDER) || self.layout_tendency >= 0);
680+
self
681+
}
682+
676683
fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
677684
where P: NdProducer<Dim=D>,
678685
{

tests/append.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,44 @@ fn append_array1() {
8787
[5., 5., 4., 4.],
8888
[3., 3., 2., 2.]]);
8989
}
90+
91+
#[test]
92+
fn append_array_3d() {
93+
let mut a = Array::zeros((0, 2, 2));
94+
a.try_append_array(Axis(0), array![[[0, 1], [2, 3]]].view()).unwrap();
95+
println!("{:?}", a);
96+
97+
let mut aa = array![[[51, 52], [53, 54]], [[55, 56], [57, 58]]];
98+
let mut av = aa.view();
99+
println!("Send {:?} to append", av);
100+
a.try_append_array(Axis(0), av.clone()).unwrap();
101+
102+
a.swap_axes(0, 1);
103+
let mut aa = array![[[71, 72], [73, 74]], [[75, 76], [77, 78]]];
104+
let mut av = aa.view();
105+
av.swap_axes(0, 1);
106+
println!("Send {:?} to append", av);
107+
a.try_append_array(Axis(1), av.clone()).unwrap();
108+
println!("{:?}", a);
109+
let mut aa = array![[[81, 82], [83, 84]], [[85, 86], [87, 88]]];
110+
let mut av = aa.view();
111+
av.swap_axes(0, 1);
112+
println!("Send {:?} to append", av);
113+
a.try_append_array(Axis(1), av).unwrap();
114+
println!("{:?}", a);
115+
assert_eq!(a,
116+
array![[[0, 1],
117+
[51, 52],
118+
[55, 56],
119+
[71, 72],
120+
[75, 76],
121+
[81, 82],
122+
[85, 86]],
123+
[[2, 3],
124+
[53, 54],
125+
[57, 58],
126+
[73, 74],
127+
[77, 78],
128+
[83, 84],
129+
[87, 88]]]);
130+
}

0 commit comments

Comments
 (0)