Skip to content

Commit 5c2da21

Browse files
Andrewjturner314
authored andcommitted
As standard layout method (#616)
This adds an `.as_standard_layout()` method which returns a standard-layout array containing the data, cloning if necessary.
1 parent 5b5d898 commit 5c2da21

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

src/impl_methods.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,48 @@ where
12181218
D::is_contiguous(&self.dim, &self.strides)
12191219
}
12201220

1221+
/// Return a standard-layout array containing the data, cloning if
1222+
/// necessary.
1223+
///
1224+
/// If `self` is in standard layout, a COW view of the data is returned
1225+
/// without cloning. Otherwise, the data is cloned, and the returned array
1226+
/// owns the cloned data.
1227+
///
1228+
/// ```
1229+
/// use ndarray::Array2;
1230+
///
1231+
/// let standard = Array2::<f64>::zeros((3, 4));
1232+
/// assert!(standard.is_standard_layout());
1233+
/// let cow_view = standard.as_standard_layout();
1234+
/// assert!(cow_view.is_view());
1235+
/// assert!(cow_view.is_standard_layout());
1236+
///
1237+
/// let fortran = standard.reversed_axes();
1238+
/// assert!(!fortran.is_standard_layout());
1239+
/// let cow_owned = fortran.as_standard_layout();
1240+
/// assert!(cow_owned.is_owned());
1241+
/// assert!(cow_owned.is_standard_layout());
1242+
/// ```
1243+
pub fn as_standard_layout(&self) -> CowArray<'_, A, D>
1244+
where
1245+
S: Data<Elem = A>,
1246+
A: Clone,
1247+
{
1248+
if self.is_standard_layout() {
1249+
CowArray::from(self.view())
1250+
} else {
1251+
let v: Vec<A> = self.iter().cloned().collect();
1252+
let dim = self.dim.clone();
1253+
assert_eq!(v.len(), dim.size());
1254+
let owned_array: Array<A, D> = unsafe {
1255+
// Safe because the shape and element type are from the existing array
1256+
// and the strides are the default strides.
1257+
Array::from_shape_vec_unchecked(dim, v)
1258+
};
1259+
CowArray::from(owned_array)
1260+
}
1261+
}
1262+
12211263
/// Return a pointer to the first element in the array.
12221264
///
12231265
/// Raw access to array elements needs to follow the strided indexing

tests/array.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,74 @@ fn array_macros() {
19701970
assert_eq!(empty2, array![[]]);
19711971
}
19721972

1973+
#[cfg(test)]
1974+
mod as_standard_layout_tests {
1975+
use super::*;
1976+
use ndarray::Data;
1977+
use std::fmt::Debug;
1978+
1979+
fn test_as_standard_layout_for<S, D>(orig: ArrayBase<S, D>)
1980+
where
1981+
S: Data,
1982+
S::Elem: Clone + Debug + PartialEq,
1983+
D: Dimension,
1984+
{
1985+
let orig_is_standard = orig.is_standard_layout();
1986+
let out = orig.as_standard_layout();
1987+
assert!(out.is_standard_layout());
1988+
assert_eq!(out, orig);
1989+
assert_eq!(orig_is_standard, out.is_view());
1990+
}
1991+
1992+
#[test]
1993+
fn test_f_layout() {
1994+
let shape = (2, 2).f();
1995+
let arr = Array::<i32, Ix2>::from_shape_vec(shape, vec![1, 2, 3, 4]).unwrap();
1996+
assert!(!arr.is_standard_layout());
1997+
test_as_standard_layout_for(arr);
1998+
}
1999+
2000+
#[test]
2001+
fn test_c_layout() {
2002+
let arr = Array::<i32, Ix2>::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap();
2003+
assert!(arr.is_standard_layout());
2004+
test_as_standard_layout_for(arr);
2005+
}
2006+
2007+
#[test]
2008+
fn test_f_layout_view() {
2009+
let shape = (2, 2).f();
2010+
let arr = Array::<i32, Ix2>::from_shape_vec(shape, vec![1, 2, 3, 4]).unwrap();
2011+
let arr_view = arr.view();
2012+
assert!(!arr_view.is_standard_layout());
2013+
test_as_standard_layout_for(arr);
2014+
}
2015+
2016+
#[test]
2017+
fn test_c_layout_view() {
2018+
let arr = Array::<i32, Ix2>::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap();
2019+
let arr_view = arr.view();
2020+
assert!(arr_view.is_standard_layout());
2021+
test_as_standard_layout_for(arr_view);
2022+
}
2023+
2024+
#[test]
2025+
fn test_zero_dimensional_array() {
2026+
let arr_view = ArrayView1::<i32>::from(&[]);
2027+
assert!(arr_view.is_standard_layout());
2028+
test_as_standard_layout_for(arr_view);
2029+
}
2030+
2031+
#[test]
2032+
fn test_custom_layout() {
2033+
let shape = (1, 2, 3, 2).strides((12, 1, 2, 6));
2034+
let arr_data: Vec<i32> = (0..12).collect();
2035+
let arr = Array::<i32, Ix4>::from_shape_vec(shape, arr_data).unwrap();
2036+
assert!(!arr.is_standard_layout());
2037+
test_as_standard_layout_for(arr);
2038+
}
2039+
}
2040+
19732041
#[cfg(test)]
19742042
mod array_cow_tests {
19752043
use super::*;

0 commit comments

Comments
 (0)