Skip to content

Commit 85a83d1

Browse files
Implemented map_mut and map_axis_mut
1 parent 3de28af commit 85a83d1

File tree

1 file changed

+60
-4
lines changed

1 file changed

+60
-4
lines changed

src/impl_methods.rs

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
892892
/// **Panics** if any dimension of `chunk_size` is zero<br>
893893
/// (**Panics** if `D` is `IxDyn` and `chunk_size` does not match the
894894
/// number of array axes.)
895-
pub fn exact_chunks<E>(&self, chunk_size: E) -> ExactChunks<A, D>
895+
pub fn exact_chunks<E>(&self, chunk_size: E) -> ExactChunks<A, D>
896896
where E: IntoDimension<Dim=D>,
897897
{
898898
exact_chunks_of(self.view(), chunk_size)
@@ -930,7 +930,7 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
930930
/// [6, 6, 7, 7, 8, 8, 0],
931931
/// [6, 6, 7, 7, 8, 8, 0]]));
932932
/// ```
933-
pub fn exact_chunks_mut<E>(&mut self, chunk_size: E) -> ExactChunksMut<A, D>
933+
pub fn exact_chunks_mut<E>(&mut self, chunk_size: E) -> ExactChunksMut<A, D>
934934
where E: IntoDimension<Dim=D>,
935935
S: DataMut
936936
{
@@ -941,13 +941,13 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
941941
///
942942
/// The windows are all distinct overlapping views of size `window_size`
943943
/// that fit into the array's shape.
944-
///
944+
///
945945
/// Will yield over no elements if window size is larger
946946
/// than the actual array size of any dimension.
947947
///
948948
/// The produced element is an `ArrayView<A, D>` with exactly the dimension
949949
/// `window_size`.
950-
///
950+
///
951951
/// **Panics** if any dimension of `window_size` is zero.<br>
952952
/// (**Panics** if `D` is `IxDyn` and `window_size` does not match the
953953
/// number of array axes.)
@@ -1694,6 +1694,34 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
16941694
}
16951695
}
16961696

1697+
/// Call `f` on a mutable reference of each element and create a new array
1698+
/// with the new values.
1699+
///
1700+
/// Elements are visited in arbitrary order.
1701+
///
1702+
/// Return an array with the same shape as `self`.
1703+
pub fn map_mut<'a, B, F>(&'a mut self, f: F) -> Array<B, D>
1704+
where F: FnMut(&mut A) -> B,
1705+
A: 'a,
1706+
S: DataMut
1707+
{
1708+
let dim = self.dim.clone();
1709+
let strides = self.strides.clone();
1710+
if self.is_contiguous() {
1711+
let slc = self.as_slice_memory_order_mut().unwrap();
1712+
let v = ::iterators::to_vec_mapped(slc.iter_mut(), f);
1713+
unsafe {
1714+
ArrayBase::from_shape_vec_unchecked(
1715+
dim.strides(strides), v)
1716+
}
1717+
} else {
1718+
let v = ::iterators::to_vec_mapped(self.iter_mut(), f);
1719+
unsafe {
1720+
ArrayBase::from_shape_vec_unchecked(dim, v)
1721+
}
1722+
}
1723+
}
1724+
16971725
/// Call `f` by **v**alue on each element and create a new array
16981726
/// with the new values.
16991727
///
@@ -1819,4 +1847,32 @@ impl<A, S, D> ArrayBase<S, D> where S: Data<Elem=A>, D: Dimension
18191847
}
18201848
})
18211849
}
1850+
1851+
/// Reduce the values along an axis into just one value, producing a new
1852+
/// array with one less dimension.
1853+
/// 1-dimensional lanes are passed as mutable references to the reducer,
1854+
/// allowing for side-effects.
1855+
///
1856+
/// Elements are visited in arbitrary order.
1857+
///
1858+
/// Return the result as an `Array`.
1859+
///
1860+
/// **Panics** if `axis` is out of bounds.
1861+
pub fn map_axis_mut<'a, B, F>(&'a mut self, axis: Axis, mut mapping: F)
1862+
-> Array<B, D::Smaller>
1863+
where D: RemoveAxis,
1864+
F: FnMut(ArrayViewMut1<'a, A>) -> B,
1865+
A: 'a + Clone,
1866+
S: DataMut,
1867+
{
1868+
let view_len = self.len_of(axis);
1869+
let view_stride = self.strides.axis(axis);
1870+
// use the 0th subview as a map to each 1d array view extended from
1871+
// the 0th element.
1872+
self.subview_mut(axis, 0).map_mut(|first_elt: &mut A| {
1873+
unsafe {
1874+
mapping(ArrayViewMut::new_(first_elt, Ix1(view_len), Ix1(view_stride)))
1875+
}
1876+
})
1877+
}
18221878
}

0 commit comments

Comments
 (0)