Skip to content

Commit f003aaa

Browse files
committed
append: Solve axis iteration order problem by sorting axes
1 parent a7e3aab commit f003aaa

File tree

3 files changed

+110
-13
lines changed

3 files changed

+110
-13
lines changed

src/impl_owned_array.rs

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ impl<A, D> Array<A, D>
251251
/// [1., 1., 1., 1.],
252252
/// [1., 1., 1., 1.]]);
253253
/// ```
254-
pub fn try_append_array(&mut self, axis: Axis, array: ArrayView<A, D>)
254+
pub fn try_append_array(&mut self, axis: Axis, mut array: ArrayView<A, D>)
255255
-> Result<(), ShapeError>
256256
where
257257
A: Clone,
@@ -310,7 +310,7 @@ impl<A, D> Array<A, D>
310310
// make a raw view with the new row
311311
// safe because the data was "full"
312312
let tail_ptr = self.data.as_end_nonnull();
313-
let tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
313+
let mut tail_view = RawArrayViewMut::new(tail_ptr, array_shape, strides.clone());
314314

315315
struct SetLenOnDrop<'a, A: 'a> {
316316
len: usize,
@@ -330,37 +330,86 @@ impl<A, D> Array<A, D>
330330
}
331331
}
332332

333-
// we have a problem here XXX
334-
//
335333
// To be robust for panics and drop the right elements, we want
336334
// to fill the tail in-order, so that we can drop the right elements on
337-
// panic. Don't know how to achieve that.
335+
// panic.
338336
//
339-
// It might be easier to retrace our steps in a scope guard to drop the right
340-
// elements.. (PartialArray style).
337+
// We have: Zip::from(tail_view).and(array)
338+
// Transform tail_view into standard order by inverting and moving its axes.
339+
// Keep the Zip traversal unchanged by applying the same axis transformations to
340+
// `array`. This ensures the Zip traverses the underlying memory in order.
341341
//
342-
// assign the new elements
342+
// XXX It would be possible to skip this transformation if the element
343+
// doesn't have drop. However, in the interest of code coverage, all elements
344+
// use this code initially.
345+
346+
if tail_view.ndim() > 1 {
347+
for i in 0..tail_view.ndim() {
348+
if tail_view.stride_of(Axis(i)) < 0 {
349+
tail_view.invert_axis(Axis(i));
350+
array.invert_axis(Axis(i));
351+
}
352+
}
353+
sort_axes_to_standard_order(&mut tail_view, &mut array);
354+
}
343355
Zip::from(tail_view).and(array)
356+
.debug_assert_c_order()
344357
.for_each(|to, from| {
345358
to.write(from.clone());
346359
length_guard.len += 1;
347360
});
348361

349-
//length_guard.len += len_to_append;
350-
dbg!(len_to_append);
351362
drop(length_guard);
352363

353364
// update array dimension
354365
self.strides = strides;
355366
self.dim = res_dim;
356-
dbg!(&self.dim);
357-
358367
}
359368
// multiple assertions after pointer & dimension update
360369
debug_assert_eq!(self.data.len(), self.len());
361370
debug_assert_eq!(self.len(), new_len);
362-
debug_assert!(self.is_standard_layout());
363371

364372
Ok(())
365373
}
366374
}
375+
376+
fn sort_axes_to_standard_order<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
377+
where
378+
S: RawData,
379+
S2: RawData,
380+
D: Dimension,
381+
{
382+
if a.ndim() <= 1 {
383+
return;
384+
}
385+
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
386+
debug_assert!(a.is_standard_layout());
387+
}
388+
389+
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)
390+
where
391+
D: Dimension,
392+
{
393+
debug_assert!(adim.ndim() > 1);
394+
debug_assert_eq!(adim.ndim(), bdim.ndim());
395+
// bubble sort axes
396+
let mut changed = true;
397+
while changed {
398+
changed = false;
399+
for i in 0..adim.ndim() - 1 {
400+
let axis_i = i;
401+
let next_axis = i + 1;
402+
403+
// make sure higher stride axes sort before.
404+
debug_assert!(astrides.slice()[axis_i] as isize >= 0);
405+
if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize {
406+
changed = true;
407+
adim.slice_mut().swap(axis_i, next_axis);
408+
astrides.slice_mut().swap(axis_i, next_axis);
409+
bdim.slice_mut().swap(axis_i, next_axis);
410+
bstrides.slice_mut().swap(axis_i, next_axis);
411+
}
412+
}
413+
}
414+
}
415+

src/zip/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,13 @@ macro_rules! map_impl {
673673
self.build_and(part)
674674
}
675675

676+
#[allow(unused)]
677+
#[inline]
678+
pub(crate) fn debug_assert_c_order(self) -> Self {
679+
debug_assert!(self.layout.is(CORDER) || self.layout_tendency >= 0);
680+
self
681+
}
682+
676683
fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
677684
where P: NdProducer<Dim=D>,
678685
{

tests/append.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,44 @@ fn append_array1() {
105105
[5., 5., 4., 4.],
106106
[3., 3., 2., 2.]]);
107107
}
108+
109+
#[test]
110+
fn append_array_3d() {
111+
let mut a = Array::zeros((0, 2, 2));
112+
a.try_append_array(Axis(0), array![[[0, 1], [2, 3]]].view()).unwrap();
113+
println!("{:?}", a);
114+
115+
let aa = array![[[51, 52], [53, 54]], [[55, 56], [57, 58]]];
116+
let av = aa.view();
117+
println!("Send {:?} to append", av);
118+
a.try_append_array(Axis(0), av.clone()).unwrap();
119+
120+
a.swap_axes(0, 1);
121+
let aa = array![[[71, 72], [73, 74]], [[75, 76], [77, 78]]];
122+
let mut av = aa.view();
123+
av.swap_axes(0, 1);
124+
println!("Send {:?} to append", av);
125+
a.try_append_array(Axis(1), av.clone()).unwrap();
126+
println!("{:?}", a);
127+
let aa = array![[[81, 82], [83, 84]], [[85, 86], [87, 88]]];
128+
let mut av = aa.view();
129+
av.swap_axes(0, 1);
130+
println!("Send {:?} to append", av);
131+
a.try_append_array(Axis(1), av).unwrap();
132+
println!("{:?}", a);
133+
assert_eq!(a,
134+
array![[[0, 1],
135+
[51, 52],
136+
[55, 56],
137+
[71, 72],
138+
[75, 76],
139+
[81, 82],
140+
[85, 86]],
141+
[[2, 3],
142+
[53, 54],
143+
[57, 58],
144+
[73, 74],
145+
[77, 78],
146+
[83, 84],
147+
[87, 88]]]);
148+
}

0 commit comments

Comments
 (0)