Skip to content

Commit e66e3c8

Browse files
jturner314bluss
authored andcommitted
Fix safety of SliceInfo::new
1 parent 615113e commit e66e3c8

File tree

3 files changed

+51
-34
lines changed

3 files changed

+51
-34
lines changed

src/slice.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ impl From<NewAxis> for AxisSliceInfo {
307307
///
308308
/// This trait is unsafe to implement because the implementation must ensure
309309
/// that `D`, `Self::OutDim`, `self.in_dim()`, and `self.out_ndim()` are
310-
/// consistent with the `&[AxisSliceInfo]` returned by `self.as_ref()`.
310+
/// consistent with the `&[AxisSliceInfo]` returned by `self.as_ref()` and that
311+
/// `self.as_ref()` always returns the same value when called multiple times.
311312
pub unsafe trait CanSlice<D: Dimension>: AsRef<[AxisSliceInfo]> {
312313
type OutDim: Dimension;
313314

@@ -409,10 +410,13 @@ where
409410
{
410411
/// Returns a new `SliceInfo` instance.
411412
///
412-
/// If you call this method, you are guaranteeing that `in_dim` and
413-
/// `out_dim` are consistent with `indices`.
414-
///
415413
/// **Note:** only unchecked for non-debug builds of `ndarray`.
414+
///
415+
/// # Safety
416+
///
417+
/// The caller must ensure that `in_dim` and `out_dim` are consistent with
418+
/// `indices` and that `indices.as_ref()` always returns the same value
419+
/// when called multiple times.
416420
#[doc(hidden)]
417421
pub unsafe fn new_unchecked(
418422
indices: T,
@@ -444,7 +448,12 @@ where
444448
/// Returns a new `SliceInfo` instance.
445449
///
446450
/// Errors if `Din` or `Dout` is not consistent with `indices`.
447-
pub fn new(indices: T) -> Result<SliceInfo<T, Din, Dout>, ShapeError> {
451+
///
452+
/// # Safety
453+
///
454+
/// The caller must ensure `indices.as_ref()` always returns the same value
455+
/// when called multiple times.
456+
pub unsafe fn new(indices: T) -> Result<SliceInfo<T, Din, Dout>, ShapeError> {
448457
if let Some(in_ndim) = Din::NDIM {
449458
if in_ndim != indices.as_ref().in_ndim() {
450459
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));

tests/array.rs

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,15 @@ fn test_slice_dyninput_array_fixed() {
216216
#[test]
217217
fn test_slice_array_dyn() {
218218
let mut arr = Array3::<f64>::zeros((5, 2, 5));
219-
let info = &SliceInfo::<_, Ix3, IxDyn>::new([
220-
AxisSliceInfo::from(1..),
221-
AxisSliceInfo::from(1),
222-
AxisSliceInfo::from(NewAxis),
223-
AxisSliceInfo::from(..).step_by(2),
224-
])
225-
.unwrap();
219+
let info = &unsafe {
220+
SliceInfo::<_, Ix3, IxDyn>::new([
221+
AxisSliceInfo::from(1..),
222+
AxisSliceInfo::from(1),
223+
AxisSliceInfo::from(NewAxis),
224+
AxisSliceInfo::from(..).step_by(2),
225+
])
226+
.unwrap()
227+
};
226228
arr.slice(info);
227229
arr.slice_mut(info);
228230
arr.view().slice_move(info);
@@ -232,13 +234,15 @@ fn test_slice_array_dyn() {
232234
#[test]
233235
fn test_slice_dyninput_array_dyn() {
234236
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
235-
let info = &SliceInfo::<_, Ix3, IxDyn>::new([
236-
AxisSliceInfo::from(1..),
237-
AxisSliceInfo::from(1),
238-
AxisSliceInfo::from(NewAxis),
239-
AxisSliceInfo::from(..).step_by(2),
240-
])
241-
.unwrap();
237+
let info = &unsafe {
238+
SliceInfo::<_, Ix3, IxDyn>::new([
239+
AxisSliceInfo::from(1..),
240+
AxisSliceInfo::from(1),
241+
AxisSliceInfo::from(NewAxis),
242+
AxisSliceInfo::from(..).step_by(2),
243+
])
244+
.unwrap()
245+
};
242246
arr.slice(info);
243247
arr.slice_mut(info);
244248
arr.view().slice_move(info);
@@ -248,13 +252,15 @@ fn test_slice_dyninput_array_dyn() {
248252
#[test]
249253
fn test_slice_dyninput_vec_fixed() {
250254
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
251-
let info = &SliceInfo::<_, Ix3, Ix3>::new(vec![
252-
AxisSliceInfo::from(1..),
253-
AxisSliceInfo::from(1),
254-
AxisSliceInfo::from(NewAxis),
255-
AxisSliceInfo::from(..).step_by(2),
256-
])
257-
.unwrap();
255+
let info = &unsafe {
256+
SliceInfo::<_, Ix3, Ix3>::new(vec![
257+
AxisSliceInfo::from(1..),
258+
AxisSliceInfo::from(1),
259+
AxisSliceInfo::from(NewAxis),
260+
AxisSliceInfo::from(..).step_by(2),
261+
])
262+
.unwrap()
263+
};
258264
arr.slice(info);
259265
arr.slice_mut(info);
260266
arr.view().slice_move(info);
@@ -264,13 +270,15 @@ fn test_slice_dyninput_vec_fixed() {
264270
#[test]
265271
fn test_slice_dyninput_vec_dyn() {
266272
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
267-
let info = &SliceInfo::<_, Ix3, IxDyn>::new(vec![
268-
AxisSliceInfo::from(1..),
269-
AxisSliceInfo::from(1),
270-
AxisSliceInfo::from(NewAxis),
271-
AxisSliceInfo::from(..).step_by(2),
272-
])
273-
.unwrap();
273+
let info = &unsafe {
274+
SliceInfo::<_, Ix3, IxDyn>::new(vec![
275+
AxisSliceInfo::from(1..),
276+
AxisSliceInfo::from(1),
277+
AxisSliceInfo::from(NewAxis),
278+
AxisSliceInfo::from(..).step_by(2),
279+
])
280+
.unwrap()
281+
};
274282
arr.slice(info);
275283
arr.slice_mut(info);
276284
arr.view().slice_move(info);

tests/oper.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ fn scaled_add_3() {
595595

596596
{
597597
let mut av = a.slice_mut(s![..;s1, ..;s2]);
598-
let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap());
598+
let c = c.slice(&unsafe { SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap() });
599599

600600
let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
601601
answerv += &(beta * &c);

0 commit comments

Comments
 (0)