Skip to content

Commit 3d9a9ba

Browse files
committed
move_into: New method .move_into() for moving all array elements
.move_into() lets all elements move out of an Array, into an uninitialized array. We use a DropCounter to check duplication/drops of elements rigorously. The DropCounter code is taken from rayon collect tests, where I wrote it.
1 parent aad9c74 commit 3d9a9ba

File tree

4 files changed

+423
-36
lines changed

4 files changed

+423
-36
lines changed

src/impl_owned_array.rs

Lines changed: 185 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11

22
use alloc::vec::Vec;
3+
use std::mem::MaybeUninit;
34

45
use crate::imp_prelude::*;
6+
57
use crate::dimension;
68
use crate::error::{ErrorKind, ShapeError};
9+
use crate::iterators::Baseiter;
10+
use crate::low_level_util::AbortIfPanic;
711
use crate::OwnedRepr;
812
use crate::Zip;
913

@@ -137,6 +141,145 @@ impl<A> Array<A, Ix2> {
137141
impl<A, D> Array<A, D>
138142
where D: Dimension
139143
{
144+
/// Move all elements from self into `new_array`, which must be of the same shape but
145+
/// can have a different memory layout. The destination is overwritten completely.
146+
///
147+
/// The destination should be a mut reference to an array or an `ArrayViewMut` with
148+
/// `MaybeUninit<A>` elements (which are overwritten without dropping any existing value).
149+
///
150+
/// Minor implementation note: Owned arrays like `self` may be sliced in place and own elements
151+
/// that are not part of their active view; these are dropped at the end of this function,
152+
/// after all elements in the "active view" are moved into `new_array`. If there is a panic in
153+
/// drop of any such element, other elements may be leaked.
154+
///
155+
/// ***Panics*** if the shapes don't agree.
156+
pub fn move_into<'a, AM>(self, new_array: AM)
157+
where
158+
AM: Into<ArrayViewMut<'a, MaybeUninit<A>, D>>,
159+
A: 'a,
160+
{
161+
// Remove generic parameter P and call the implementation
162+
self.move_into_impl(new_array.into())
163+
}
164+
165+
fn move_into_impl(mut self, new_array: ArrayViewMut<MaybeUninit<A>, D>) {
166+
unsafe {
167+
// Safety: copy_to_nonoverlapping cannot panic
168+
let guard = AbortIfPanic(&"move_into: moving out of owned value");
169+
// Move all reachable elements
170+
Zip::from(self.raw_view_mut())
171+
.and(new_array)
172+
.for_each(|src, dst| {
173+
src.copy_to_nonoverlapping(dst.as_mut_ptr(), 1);
174+
});
175+
guard.defuse();
176+
// Drop all unreachable elements
177+
self.drop_unreachable_elements();
178+
}
179+
}
180+
181+
/// This drops all "unreachable" elements in the data storage of self.
182+
///
183+
/// That means those elements that are not visible in the slicing of the array.
184+
/// *Reachable elements are assumed to already have been moved from.*
185+
///
186+
/// # Safety
187+
///
188+
/// This is a panic critical section since `self` is already moved-from.
189+
fn drop_unreachable_elements(mut self) -> OwnedRepr<A> {
190+
let self_len = self.len();
191+
192+
// "deconstruct" self; the owned repr releases ownership of all elements and we
193+
// and carry on with raw view methods
194+
let data_len = self.data.len();
195+
196+
let has_unreachable_elements = self_len != data_len;
197+
if !has_unreachable_elements || std::mem::size_of::<A>() == 0 {
198+
unsafe {
199+
self.data.set_len(0);
200+
}
201+
self.data
202+
} else {
203+
self.drop_unreachable_elements_slow()
204+
}
205+
}
206+
207+
#[inline(never)]
208+
#[cold]
209+
fn drop_unreachable_elements_slow(mut self) -> OwnedRepr<A> {
210+
// "deconstruct" self; the owned repr releases ownership of all elements and we
211+
// and carry on with raw view methods
212+
let self_len = self.len();
213+
let data_len = self.data.len();
214+
let data_ptr = self.data.as_nonnull_mut().as_ptr();
215+
216+
let mut self_;
217+
218+
unsafe {
219+
// Safety: self.data releases ownership of the elements
220+
self_ = self.raw_view_mut();
221+
self.data.set_len(0);
222+
}
223+
224+
225+
// uninvert axes where needed, so that stride > 0
226+
for i in 0..self_.ndim() {
227+
if self_.stride_of(Axis(i)) < 0 {
228+
self_.invert_axis(Axis(i));
229+
}
230+
}
231+
232+
// Sort axes to standard order, Axis(0) has biggest stride and Axis(n - 1) least stride
233+
// Note that self_ has holes, so self_ is not C-contiguous
234+
sort_axes_in_default_order(&mut self_);
235+
236+
unsafe {
237+
// with uninverted axes this is now the element with lowest address
238+
let array_memory_head_ptr = self_.ptr.as_ptr();
239+
let data_end_ptr = data_ptr.add(data_len);
240+
debug_assert!(data_ptr <= array_memory_head_ptr);
241+
debug_assert!(array_memory_head_ptr <= data_end_ptr);
242+
243+
// iter is a raw pointer iterator traversing self_ in its standard order
244+
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
245+
let mut dropped_elements = 0;
246+
247+
// The idea is simply this: the iterator will yield the elements of self_ in
248+
// increasing address order.
249+
//
250+
// The pointers produced by the iterator are those that we *do not* touch.
251+
// The pointers *not mentioned* by the iterator are those we have to drop.
252+
//
253+
// We have to drop elements in the range from `data_ptr` until (not including)
254+
// `data_end_ptr`, except those that are produced by `iter`.
255+
let mut last_ptr = data_ptr;
256+
257+
while let Some(elem_ptr) = iter.next() {
258+
// The interval from last_ptr up until (not including) elem_ptr
259+
// should now be dropped. This interval may be empty, then we just skip this loop.
260+
while last_ptr != elem_ptr {
261+
debug_assert!(last_ptr < data_end_ptr);
262+
std::ptr::drop_in_place(last_ptr);
263+
last_ptr = last_ptr.add(1);
264+
dropped_elements += 1;
265+
}
266+
// Next interval will continue one past the current element
267+
last_ptr = elem_ptr.add(1);
268+
}
269+
270+
while last_ptr < data_end_ptr {
271+
std::ptr::drop_in_place(last_ptr);
272+
last_ptr = last_ptr.add(1);
273+
dropped_elements += 1;
274+
}
275+
276+
assert_eq!(data_len, dropped_elements + self_len,
277+
"Internal error: inconsistency in move_into");
278+
}
279+
self.data
280+
}
281+
282+
140283
/// Append an array to the array
141284
///
142285
/// The axis-to-append-to `axis` must be the array's "growing axis" for this operation
@@ -312,7 +455,7 @@ impl<A, D> Array<A, D>
312455
array.invert_axis(Axis(i));
313456
}
314457
}
315-
sort_axes_to_standard_order(&mut tail_view, &mut array);
458+
sort_axes_to_standard_order_tandem(&mut tail_view, &mut array);
316459
}
317460
Zip::from(tail_view).and(array)
318461
.debug_assert_c_order()
@@ -335,7 +478,21 @@ impl<A, D> Array<A, D>
335478
}
336479
}
337480

338-
fn sort_axes_to_standard_order<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
481+
/// Sort axes to standard order, i.e Axis(0) has biggest stride and Axis(n - 1) least stride
482+
///
483+
/// The axes should have stride >= 0 before calling this method.
484+
fn sort_axes_in_default_order<S, D>(a: &mut ArrayBase<S, D>)
485+
where
486+
S: RawData,
487+
D: Dimension,
488+
{
489+
if a.ndim() <= 1 {
490+
return;
491+
}
492+
sort_axes1_impl(&mut a.dim, &mut a.strides);
493+
}
494+
495+
fn sort_axes_to_standard_order_tandem<S, S2, D>(a: &mut ArrayBase<S, D>, b: &mut ArrayBase<S2, D>)
339496
where
340497
S: RawData,
341498
S2: RawData,
@@ -349,6 +506,32 @@ where
349506
a.shape(), a.strides());
350507
}
351508

509+
fn sort_axes1_impl<D>(adim: &mut D, astrides: &mut D)
510+
where
511+
D: Dimension,
512+
{
513+
debug_assert!(adim.ndim() > 1);
514+
debug_assert_eq!(adim.ndim(), astrides.ndim());
515+
// bubble sort axes
516+
let mut changed = true;
517+
while changed {
518+
changed = false;
519+
for i in 0..adim.ndim() - 1 {
520+
let axis_i = i;
521+
let next_axis = i + 1;
522+
523+
// make sure higher stride axes sort before.
524+
debug_assert!(astrides.slice()[axis_i] as isize >= 0);
525+
if (astrides.slice()[axis_i] as isize) < astrides.slice()[next_axis] as isize {
526+
changed = true;
527+
adim.slice_mut().swap(axis_i, next_axis);
528+
astrides.slice_mut().swap(axis_i, next_axis);
529+
}
530+
}
531+
}
532+
}
533+
534+
352535
fn sort_axes_impl<D>(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D)
353536
where
354537
D: Dimension,

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
clippy::deref_addrof,
1313
clippy::unreadable_literal,
1414
clippy::manual_map, // is not an error
15+
clippy::while_let_on_iterator, // is not an error
1516
)]
1617
#![cfg_attr(not(feature = "std"), no_std)]
1718

0 commit comments

Comments
 (0)