Skip to content

Commit 246848c

Browse files
committed
shape: Add method .to_shape()
1 parent 294b1b1 commit 246848c

File tree

1 file changed

+106
-1
lines changed

1 file changed

+106
-1
lines changed

src/impl_methods.rs

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@ use crate::dimension::broadcast::co_broadcast;
2626
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2727
use crate::math_cell::MathCell;
2828
use crate::itertools::zip;
29-
use crate::zip::{IntoNdProducer, Zip};
3029
use crate::AxisDescription;
30+
use crate::Layout;
31+
use crate::order::Order;
32+
use crate::shape_builder::{Contiguous, ShapeArg};
33+
use crate::zip::{IntoNdProducer, Zip};
3134

3235
use crate::iter::{
3336
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
@@ -1577,6 +1580,108 @@ where
15771580
}
15781581
}
15791582

1583+
/// Transform the array into `new_shape`; any shape with the same number of elements is
1584+
/// accepted.
1585+
///
1586+
/// `order` specifies the *logical* order in which the array is to be read and reshaped.
1587+
/// The array is returned as a `CowArray`; a view if possible, otherwise an owned array.
1588+
///
1589+
/// For example, when starting from the one-dimensional sequence 1 2 3 4 5 6, it would be
1590+
/// understood as a 2 x 3 array in row major ("C") order this way:
1591+
///
1592+
/// ```text
1593+
/// 1 2 3
1594+
/// 4 5 6
1595+
/// ```
1596+
///
1597+
/// and as 2 x 3 in column major ("F") order this way:
1598+
///
1599+
/// ```text
1600+
/// 1 3 5
1601+
/// 2 4 6
1602+
/// ```
1603+
///
1604+
/// This example should show that any time we "reflow" the elements in the array to a different
1605+
/// number of rows and columns (or more axes if applicable), it is important to pick an index
1606+
/// ordering, and that's the reason for the function parameter for `order`.
1607+
///
1608+
/// **Errors** if the new shape doesn't have the same number of elements as the array's current
1609+
/// shape.
1610+
///
1611+
/// ```
1612+
/// use ndarray::array;
1613+
/// use ndarray::Order;
1614+
///
1615+
/// assert!(
1616+
/// array![1., 2., 3., 4., 5., 6.].to_shape(((2, 3), Order::RowMajor)).unwrap()
1617+
/// == array![[1., 2., 3.],
1618+
/// [4., 5., 6.]]
1619+
/// );
1620+
///
1621+
/// assert!(
1622+
/// array![1., 2., 3., 4., 5., 6.].to_shape(((2, 3), Order::ColumnMajor)).unwrap()
1623+
/// == array![[1., 3., 5.],
1624+
/// [2., 4., 6.]]
1625+
/// );
1626+
/// ```
1627+
pub fn to_shape<E>(&self, new_shape: E) -> Result<CowArray<'_, A, E::Dim>, ShapeError>
1628+
where
1629+
E: ShapeArg<StrideType = Contiguous>,
1630+
A: Clone,
1631+
S: Data,
1632+
{
1633+
let (shape, order) = new_shape.into_shape_and_order(Order::RowMajor);
1634+
self.to_shape_order(shape, order)
1635+
}
1636+
1637+
fn to_shape_order<E>(&self, shape: E, mut order: Order)
1638+
-> Result<CowArray<'_, A, E>, ShapeError>
1639+
where
1640+
E: Dimension,
1641+
A: Clone,
1642+
S: Data,
1643+
{
1644+
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
1645+
return Err(error::incompatible_shapes(&self.dim, &shape));
1646+
}
1647+
let layout = self.layout_impl();
1648+
if order == Order::Automatic {
1649+
// the order of the conditionals is significant for preference
1650+
if layout.is(Layout::CORDER) {
1651+
order = Order::RowMajor;
1652+
} else if layout.is(Layout::FORDER) {
1653+
order = Order::ColumnMajor;
1654+
} else if layout.is(Layout::CPREFER) {
1655+
order = Order::RowMajor;
1656+
} else if layout.is(Layout::FPREFER) {
1657+
order = Order::ColumnMajor;
1658+
} else {
1659+
order = Order::RowMajor;
1660+
}
1661+
}
1662+
1663+
unsafe {
1664+
if layout.is(Layout::CORDER) && order == Order::RowMajor {
1665+
let strides = shape.default_strides();
1666+
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
1667+
} else if layout.is(Layout::FORDER) && order == Order::ColumnMajor {
1668+
let strides = shape.fortran_strides();
1669+
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
1670+
} else {
1671+
let (shape, view) = if order == Order::RowMajor {
1672+
(shape.set_f(false), self.view())
1673+
} else if order == Order::ColumnMajor {
1674+
(shape.set_f(true), self.t())
1675+
} else {
1676+
// Order::Automatic is already resolved
1677+
unreachable!()
1678+
};
1679+
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
1680+
shape, view.into_iter(), A::clone)))
1681+
}
1682+
}
1683+
}
1684+
15801685
/// Transform the array into `shape`; any shape with the same number of
15811686
/// elements is accepted, but the source array or view must be in standard
15821687
/// or column-major (Fortran) layout.

0 commit comments

Comments
 (0)