Skip to content

Commit 516a504

Browse files
committed
Allow aliasing in ArrayView::from_shape
Changes the checks in the ArrayView::from_shape constructor so that it allows a few more cases: custom strides that lead to overlapping are allowed. Before, both ArrayViewMut and ArrayView applied the same check, that the dimensions and strides must be such that no elements can be reached by more than one index. However, this rule only applies for mutable data, for ArrayView we can allow this kind of aliasing. This is in fact how broadcasting works, where we use strides to repeat the same array data multiple times.
1 parent e578d58 commit 516a504

File tree

4 files changed

+99
-53
lines changed

4 files changed

+99
-53
lines changed

src/dimension/mod.rs

Lines changed: 76 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,21 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
100100
}
101101
}
102102

103+
/// Select how aliasing is checked
104+
///
105+
/// For owned or mutable data:
106+
///
107+
/// The strides must not allow any element to be referenced by two different indices.
108+
///
109+
#[derive(Copy, Clone, PartialEq)]
110+
pub(crate) enum CanIndexCheckMode
111+
{
112+
/// Owned or mutable: No aliasing
113+
OwnedMutable,
114+
/// Aliasing
115+
ReadOnly,
116+
}
117+
103118
/// Checks whether the given data and dimension meet the invariants of the
104119
/// `ArrayBase` type, assuming the strides are created using
105120
/// `dim.default_strides()` or `dim.fortran_strides()`.
@@ -125,12 +140,13 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
125140
/// `A` and in units of bytes between the least address and greatest address
126141
/// accessible by moving along all axes does not exceed `isize::MAX`.
127142
pub(crate) fn can_index_slice_with_strides<A, D: Dimension>(
128-
data: &[A], dim: &D, strides: &Strides<D>,
143+
data: &[A], dim: &D, strides: &Strides<D>, mode: CanIndexCheckMode,
129144
) -> Result<(), ShapeError>
130145
{
131146
if let Strides::Custom(strides) = strides {
132-
can_index_slice(data, dim, strides)
147+
can_index_slice(data, dim, strides, mode)
133148
} else {
149+
// contiguous shapes: never aliasing, mode does not matter
134150
can_index_slice_not_custom(data.len(), dim)
135151
}
136152
}
@@ -239,15 +255,19 @@ where D: Dimension
239255
/// allocation. (In other words, the pointer to the first element of the array
240256
/// must be computed using `offset_from_low_addr_ptr_to_logical_ptr` so that
241257
/// negative strides are correctly handled.)
242-
pub(crate) fn can_index_slice<A, D: Dimension>(data: &[A], dim: &D, strides: &D) -> Result<(), ShapeError>
258+
///
259+
/// Note, condition (4) is guaranteed to be checked last
260+
pub(crate) fn can_index_slice<A, D: Dimension>(
261+
data: &[A], dim: &D, strides: &D, mode: CanIndexCheckMode,
262+
) -> Result<(), ShapeError>
243263
{
244264
// Check conditions 1 and 2 and calculate `max_offset`.
245265
let max_offset = max_abs_offset_check_overflow::<A, _>(dim, strides)?;
246-
can_index_slice_impl(max_offset, data.len(), dim, strides)
266+
can_index_slice_impl(max_offset, data.len(), dim, strides, mode)
247267
}
248268

249269
fn can_index_slice_impl<D: Dimension>(
250-
max_offset: usize, data_len: usize, dim: &D, strides: &D,
270+
max_offset: usize, data_len: usize, dim: &D, strides: &D, mode: CanIndexCheckMode,
251271
) -> Result<(), ShapeError>
252272
{
253273
// Check condition 3.
@@ -260,7 +280,7 @@ fn can_index_slice_impl<D: Dimension>(
260280
}
261281

262282
// Check condition 4.
263-
if !is_empty && dim_stride_overlap(dim, strides) {
283+
if !is_empty && mode != CanIndexCheckMode::ReadOnly && dim_stride_overlap(dim, strides) {
264284
return Err(from_kind(ErrorKind::Unsupported));
265285
}
266286

@@ -782,6 +802,7 @@ mod test
782802
slice_min_max,
783803
slices_intersect,
784804
solve_linear_diophantine_eq,
805+
CanIndexCheckMode,
785806
IntoDimension,
786807
};
787808
use crate::error::{from_kind, ErrorKind};
@@ -796,11 +817,11 @@ mod test
796817
let v: alloc::vec::Vec<_> = (0..12).collect();
797818
let dim = (2, 3, 2).into_dimension();
798819
let strides = (1, 2, 6).into_dimension();
799-
assert!(super::can_index_slice(&v, &dim, &strides).is_ok());
820+
assert!(super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());
800821

801822
let strides = (2, 4, 12).into_dimension();
802823
assert_eq!(
803-
super::can_index_slice(&v, &dim, &strides),
824+
super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable),
804825
Err(from_kind(ErrorKind::OutOfBounds))
805826
);
806827
}
@@ -848,71 +869,79 @@ mod test
848869
#[test]
849870
fn can_index_slice_ix0()
850871
{
851-
can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0()).unwrap();
852-
can_index_slice::<i32, _>(&[], &Ix0(), &Ix0()).unwrap_err();
872+
can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap();
873+
can_index_slice::<i32, _>(&[], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap_err();
853874
}
854875

855876
#[test]
856877
fn can_index_slice_ix1()
857878
{
858-
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0)).unwrap();
859-
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1)).unwrap();
860-
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0)).unwrap_err();
861-
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
862-
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0)).unwrap();
863-
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2)).unwrap();
864-
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize)).unwrap();
865-
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1)).unwrap_err();
866-
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err();
867-
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1)).unwrap();
868-
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap();
879+
let mode = CanIndexCheckMode::OwnedMutable;
880+
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0), mode).unwrap();
881+
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1), mode).unwrap();
882+
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0), mode).unwrap_err();
883+
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err();
884+
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0), mode).unwrap();
885+
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2), mode).unwrap();
886+
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize), mode).unwrap();
887+
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1), mode).unwrap_err();
888+
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0), mode).unwrap_err();
889+
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1), mode).unwrap();
890+
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize), mode).unwrap();
869891
}
870892

871893
#[test]
872894
fn can_index_slice_ix2()
873895
{
874-
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0)).unwrap();
875-
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1)).unwrap();
876-
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0)).unwrap();
877-
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1)).unwrap();
878-
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
879-
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
880-
can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1)).unwrap_err();
881-
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1)).unwrap();
882-
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2)).unwrap_err();
883-
can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1)).unwrap();
884-
can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1)).unwrap_err();
896+
let mode = CanIndexCheckMode::OwnedMutable;
897+
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0), mode).unwrap();
898+
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1), mode).unwrap();
899+
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0), mode).unwrap();
900+
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1), mode).unwrap();
901+
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap();
902+
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err();
903+
can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap_err();
904+
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap();
905+
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2), mode).unwrap_err();
906+
can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap();
907+
can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap_err();
908+
909+
// aliasing strides: ok when readonly
910+
can_index_slice::<i32, _>(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::OwnedMutable).unwrap_err();
911+
can_index_slice::<i32, _>(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::ReadOnly).unwrap();
885912
}
886913

887914
#[test]
888915
fn can_index_slice_ix3()
889916
{
890-
can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3)).unwrap();
891-
can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap_err();
892-
can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap();
893-
can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap_err();
894-
can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap();
917+
let mode = CanIndexCheckMode::OwnedMutable;
918+
can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3), mode).unwrap();
919+
can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap_err();
920+
can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap();
921+
can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap_err();
922+
can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap();
895923
}
896924

897925
#[test]
898926
fn can_index_slice_zero_size_elem()
899927
{
900-
can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1)).unwrap();
901-
can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1)).unwrap();
902-
can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1)).unwrap();
928+
let mode = CanIndexCheckMode::OwnedMutable;
929+
can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1), mode).unwrap();
930+
can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1), mode).unwrap();
931+
can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1), mode).unwrap();
903932

904933
// These might seem okay because the element type is zero-sized, but
905934
// there could be a zero-sized type such that the number of instances
906935
// in existence are carefully controlled.
907-
can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
908-
can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1)).unwrap_err();
936+
can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err();
937+
can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1), mode).unwrap_err();
909938

910-
can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0)).unwrap();
911-
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
939+
can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0), mode).unwrap();
940+
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap();
912941

913942
// This case would be probably be sound, but that's not entirely clear
914943
// and it's not worth the special case code.
915-
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
944+
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err();
916945
}
917946

918947
quickcheck! {
@@ -923,8 +952,8 @@ mod test
923952
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
924953
result.is_err()
925954
} else {
926-
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
927-
result == can_index_slice(&data, &dim, &dim.fortran_strides())
955+
result == can_index_slice(&data, &dim, &dim.default_strides(), CanIndexCheckMode::OwnedMutable) &&
956+
result == can_index_slice(&data, &dim, &dim.fortran_strides(), CanIndexCheckMode::OwnedMutable)
928957
}
929958
}
930959
}

src/impl_constructors.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use num_traits::{One, Zero};
2020
use std::mem;
2121
use std::mem::MaybeUninit;
2222

23-
use crate::dimension;
2423
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
24+
use crate::dimension::{self, CanIndexCheckMode};
2525
use crate::error::{self, ShapeError};
2626
use crate::extension::nonnull::nonnull_from_vec_data;
2727
use crate::imp_prelude::*;
@@ -466,7 +466,7 @@ where
466466
{
467467
let dim = shape.dim;
468468
let is_custom = shape.strides.is_custom();
469-
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?;
469+
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides, dimension::CanIndexCheckMode::OwnedMutable)?;
470470
if !is_custom && dim.size() != v.len() {
471471
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
472472
}
@@ -510,7 +510,7 @@ where
510510
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self
511511
{
512512
// debug check for issues that indicates wrong use of this constructor
513-
debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok());
513+
debug_assert!(dimension::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());
514514

515515
let ptr = nonnull_from_vec_data(&mut v).add(offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides));
516516
ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim)

src/impl_views/constructors.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
use std::ptr::NonNull;
1010

11-
use crate::dimension;
1211
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
12+
use crate::dimension::{self, CanIndexCheckMode};
1313
use crate::error::ShapeError;
1414
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
1515
use crate::imp_prelude::*;
@@ -54,7 +54,7 @@ where D: Dimension
5454
fn from_shape_impl(shape: StrideShape<D>, xs: &'a [A]) -> Result<Self, ShapeError>
5555
{
5656
let dim = shape.dim;
57-
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
57+
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::ReadOnly)?;
5858
let strides = shape.strides.strides_for_dim(&dim);
5959
unsafe {
6060
Ok(Self::new_(
@@ -157,7 +157,7 @@ where D: Dimension
157157
fn from_shape_impl(shape: StrideShape<D>, xs: &'a mut [A]) -> Result<Self, ShapeError>
158158
{
159159
let dim = shape.dim;
160-
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
160+
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::OwnedMutable)?;
161161
let strides = shape.strides.strides_for_dim(&dim);
162162
unsafe {
163163
Ok(Self::new_(

tests/array.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use defmac::defmac;
1010
use itertools::{zip, Itertools};
1111
use ndarray::indices;
1212
use ndarray::prelude::*;
13+
use ndarray::ErrorKind;
1314
use ndarray::{arr3, rcarr2};
1415
use ndarray::{Slice, SliceInfo, SliceInfoElem};
1516
use num_complex::Complex;
@@ -2060,6 +2061,22 @@ fn test_view_from_shape()
20602061
assert_eq!(a, answer);
20612062
}
20622063

2064+
#[test]
2065+
fn test_view_from_shape_allow_overlap()
2066+
{
2067+
let data = [0, 1, 2];
2068+
let view = ArrayView::from_shape((2, 3).strides((0, 1)), &data).unwrap();
2069+
assert_eq!(view, aview2(&[data; 2]));
2070+
}
2071+
2072+
#[test]
2073+
fn test_view_mut_from_shape_deny_overlap()
2074+
{
2075+
let mut data = [0, 1, 2];
2076+
let result = ArrayViewMut::from_shape((2, 3).strides((0, 1)), &mut data);
2077+
assert_matches!(result.map_err(|e| e.kind()), Err(ErrorKind::Unsupported));
2078+
}
2079+
20632080
#[test]
20642081
fn test_contiguous()
20652082
{

0 commit comments

Comments
 (0)