Skip to content

Commit 29896d8

Browse files
committed
append: Fix situations where we need to recompute stride
When the axis has length 0, or 1, we need to carefully compute new strides.
1 parent f003aaa commit 29896d8

File tree

2 files changed

+85
-10
lines changed

2 files changed

+85
-10
lines changed

src/impl_owned_array.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ impl<A, D> Array<A, D>
261261
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
262262
}
263263

264+
let current_axis_len = self.len_of(axis);
264265
let remaining_shape = self.raw_dim().remove_axis(axis);
265266
let array_rem_shape = array.raw_dim().remove_axis(axis);
266267

@@ -280,22 +281,46 @@ impl<A, D> Array<A, D>
280281

281282
let self_is_empty = self.is_empty();
282283

283-
// array must be empty or have `axis` as the outermost (longest stride)
284-
// axis
285-
if !(self_is_empty ||
286-
self.axes().max_by_key(|ax| ax.stride).map(|ax| ax.axis) == Some(axis))
287-
{
288-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
284+
// array must be empty or have `axis` as the outermost (longest stride) axis
285+
if !self_is_empty && current_axis_len > 1 {
286+
// `axis` must be max stride axis or equal to its stride
287+
let max_stride_axis = self.axes().max_by_key(|ax| ax.stride).unwrap();
288+
if max_stride_axis.axis != axis && max_stride_axis.stride > self.stride_of(axis) {
289+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
290+
}
289291
}
290292

291293
// array must be be "full" (have no exterior holes)
292294
if self.len() != self.data.len() {
293295
return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout));
294296
}
297+
295298
let strides = if self_is_empty {
296-
// recompute strides - if the array was previously empty, it could have
297-
// zeros in strides.
298-
res_dim.default_strides()
299+
// recompute strides - if the array was previously empty, it could have zeros in
300+
// strides.
301+
// The new order is based on c/f-contig but must have `axis` as outermost axis.
302+
if axis == Axis(self.ndim() - 1) {
303+
// prefer f-contig when appending to the last axis
304+
// Axis n - 1 is outermost axis
305+
res_dim.fortran_strides()
306+
} else {
307+
// Default with modification
308+
res_dim.slice_mut().swap(0, axis.index());
309+
let mut strides = res_dim.default_strides();
310+
res_dim.slice_mut().swap(0, axis.index());
311+
strides.slice_mut().swap(0, axis.index());
312+
strides
313+
}
314+
} else if current_axis_len == 1 {
315+
// This is the outermost/longest stride axis; so we find the max across the other axes
316+
let new_stride = self.axes().fold(1, |acc, ax| {
317+
if ax.axis == axis { acc } else {
318+
Ord::max(acc, ax.len as isize * ax.stride)
319+
}
320+
});
321+
let mut strides = self.strides.clone();
322+
strides[axis.index()] = new_stride as usize;
323+
strides
299324
} else {
300325
self.strides.clone()
301326
};
@@ -383,7 +408,8 @@ where
383408
return;
384409
}
385410
sort_axes_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides);
386-
debug_assert!(a.is_standard_layout());
411+
debug_assert!(a.is_standard_layout(), "not std layout dim: {:?}, strides: {:?}",
412+
a.shape(), a.strides());
387413
}
388414

389415
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)

tests/append.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,52 @@ fn append_array_3d() {
146146
[83, 84],
147147
[87, 88]]]);
148148
}
149+
150+
#[test]
151+
fn test_append_2d() {
152+
// create an empty array and append
153+
let mut a = Array::zeros((0, 4));
154+
let ones = ArrayView::from(&[1.; 12]).into_shape((3, 4)).unwrap();
155+
let zeros = ArrayView::from(&[0.; 8]).into_shape((2, 4)).unwrap();
156+
a.try_append_array(Axis(0), ones).unwrap();
157+
a.try_append_array(Axis(0), zeros).unwrap();
158+
a.try_append_array(Axis(0), ones).unwrap();
159+
println!("{:?}", a);
160+
assert_eq!(a.shape(), &[8, 4]);
161+
for (i, row) in a.rows().into_iter().enumerate() {
162+
let ones = i < 3 || i >= 5;
163+
assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i);
164+
}
165+
166+
let mut a = Array::zeros((0, 4));
167+
a = a.reversed_axes();
168+
let ones = ones.reversed_axes();
169+
let zeros = zeros.reversed_axes();
170+
a.try_append_array(Axis(1), ones).unwrap();
171+
a.try_append_array(Axis(1), zeros).unwrap();
172+
a.try_append_array(Axis(1), ones).unwrap();
173+
println!("{:?}", a);
174+
assert_eq!(a.shape(), &[4, 8]);
175+
176+
for (i, row) in a.columns().into_iter().enumerate() {
177+
let ones = i < 3 || i >= 5;
178+
assert!(row.iter().all(|&x| x == ones as i32 as f64), "failed on lane {}", i);
179+
}
180+
}
181+
182+
#[test]
183+
fn test_append_middle_axis() {
184+
// ensure we can append to Axis(1) by letting it become outermost
185+
let mut a = Array::<i32, _>::zeros((3, 0, 2));
186+
a.try_append_array(Axis(1), Array::from_iter(0..12).into_shape((3, 2, 2)).unwrap().view()).unwrap();
187+
println!("{:?}", a);
188+
a.try_append_array(Axis(1), Array::from_iter(12..24).into_shape((3, 2, 2)).unwrap().view()).unwrap();
189+
println!("{:?}", a);
190+
191+
// ensure we can append to Axis(1) by letting it become outermost
192+
let mut a = Array::<i32, _>::zeros((3, 1, 2));
193+
a.try_append_array(Axis(1), Array::from_iter(0..12).into_shape((3, 2, 2)).unwrap().view()).unwrap();
194+
println!("{:?}", a);
195+
a.try_append_array(Axis(1), Array::from_iter(12..24).into_shape((3, 2, 2)).unwrap().view()).unwrap();
196+
println!("{:?}", a);
197+
}

0 commit comments

Comments
 (0)