Skip to content

Commit 3ba6ceb

Browse files
jturner314bluss
authored andcommitted
Add some impls of TryFrom for SliceInfo
1 parent e66e3c8 commit 3ba6ceb

File tree

3 files changed

+129
-53
lines changed

3 files changed

+129
-53
lines changed

src/slice.rs

Lines changed: 98 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
use crate::dimension::slices_intersect;
99
use crate::error::{ErrorKind, ShapeError};
1010
use crate::{ArrayViewMut, DimAdd, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
11+
use alloc::vec::Vec;
12+
use std::convert::TryFrom;
1113
use std::fmt;
1214
use std::marker::PhantomData;
1315
use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
@@ -402,6 +404,24 @@ where
402404
}
403405
}
404406

407+
fn check_dims_for_sliceinfo<Din, Dout>(indices: &[AxisSliceInfo]) -> Result<(), ShapeError>
408+
where
409+
Din: Dimension,
410+
Dout: Dimension,
411+
{
412+
if let Some(in_ndim) = Din::NDIM {
413+
if in_ndim != indices.in_ndim() {
414+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
415+
}
416+
}
417+
if let Some(out_ndim) = Dout::NDIM {
418+
if out_ndim != indices.out_ndim() {
419+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
420+
}
421+
}
422+
Ok(())
423+
}
424+
405425
impl<T, Din, Dout> SliceInfo<T, Din, Dout>
406426
where
407427
T: AsRef<[AxisSliceInfo]>,
@@ -424,12 +444,8 @@ where
424444
out_dim: PhantomData<Dout>,
425445
) -> SliceInfo<T, Din, Dout> {
426446
if cfg!(debug_assertions) {
427-
if let Some(in_ndim) = Din::NDIM {
428-
assert_eq!(in_ndim, indices.as_ref().in_ndim());
429-
}
430-
if let Some(out_ndim) = Dout::NDIM {
431-
assert_eq!(out_ndim, indices.as_ref().out_ndim());
432-
}
447+
check_dims_for_sliceinfo::<Din, Dout>(indices.as_ref())
448+
.expect("`Din` and `Dout` must be consistent with `indices`.");
433449
}
434450
SliceInfo {
435451
in_dim,
@@ -449,21 +465,14 @@ where
449465
///
450466
/// Errors if `Din` or `Dout` is not consistent with `indices`.
451467
///
468+
/// For common types, a safe alternative is to use `TryFrom` instead.
469+
///
452470
/// # Safety
453471
///
454472
/// The caller must ensure `indices.as_ref()` always returns the same value
455473
/// when called multiple times.
456474
pub unsafe fn new(indices: T) -> Result<SliceInfo<T, Din, Dout>, ShapeError> {
457-
if let Some(in_ndim) = Din::NDIM {
458-
if in_ndim != indices.as_ref().in_ndim() {
459-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
460-
}
461-
}
462-
if let Some(out_ndim) = Dout::NDIM {
463-
if out_ndim != indices.as_ref().out_ndim() {
464-
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
465-
}
466-
}
475+
check_dims_for_sliceinfo::<Din, Dout>(indices.as_ref())?;
467476
Ok(SliceInfo {
468477
in_dim: PhantomData,
469478
out_dim: PhantomData,
@@ -508,6 +517,79 @@ where
508517
}
509518
}
510519

520+
impl<'a, Din, Dout> TryFrom<&'a [AxisSliceInfo]> for &'a SliceInfo<[AxisSliceInfo], Din, Dout>
521+
where
522+
Din: Dimension,
523+
Dout: Dimension,
524+
{
525+
type Error = ShapeError;
526+
527+
fn try_from(
528+
indices: &'a [AxisSliceInfo],
529+
) -> Result<&'a SliceInfo<[AxisSliceInfo], Din, Dout>, ShapeError> {
530+
check_dims_for_sliceinfo::<Din, Dout>(indices)?;
531+
unsafe {
532+
// This is okay because we've already checked the correctness of
533+
// `Din` and `Dout`, and the only non-zero-sized member of
534+
// `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din,
535+
// Dout>` should have the same bitwise representation as
536+
// `&[AxisSliceInfo]`.
537+
Ok(&*(indices as *const [AxisSliceInfo]
538+
as *const SliceInfo<[AxisSliceInfo], Din, Dout>))
539+
}
540+
}
541+
}
542+
543+
impl<Din, Dout> TryFrom<Vec<AxisSliceInfo>> for SliceInfo<Vec<AxisSliceInfo>, Din, Dout>
544+
where
545+
Din: Dimension,
546+
Dout: Dimension,
547+
{
548+
type Error = ShapeError;
549+
550+
fn try_from(
551+
indices: Vec<AxisSliceInfo>,
552+
) -> Result<SliceInfo<Vec<AxisSliceInfo>, Din, Dout>, ShapeError> {
553+
unsafe {
554+
// This is okay because `Vec` always returns the same value for
555+
// `.as_ref()`.
556+
Self::new(indices)
557+
}
558+
}
559+
}
560+
561+
macro_rules! impl_tryfrom_array_for_sliceinfo {
562+
($len:expr) => {
563+
impl<Din, Dout> TryFrom<[AxisSliceInfo; $len]>
564+
for SliceInfo<[AxisSliceInfo; $len], Din, Dout>
565+
where
566+
Din: Dimension,
567+
Dout: Dimension,
568+
{
569+
type Error = ShapeError;
570+
571+
fn try_from(
572+
indices: [AxisSliceInfo; $len],
573+
) -> Result<SliceInfo<[AxisSliceInfo; $len], Din, Dout>, ShapeError> {
574+
unsafe {
575+
// This is okay because `[AxisSliceInfo; N]` always returns
576+
// the same value for `.as_ref()`.
577+
Self::new(indices)
578+
}
579+
}
580+
}
581+
};
582+
}
583+
impl_tryfrom_array_for_sliceinfo!(0);
584+
impl_tryfrom_array_for_sliceinfo!(1);
585+
impl_tryfrom_array_for_sliceinfo!(2);
586+
impl_tryfrom_array_for_sliceinfo!(3);
587+
impl_tryfrom_array_for_sliceinfo!(4);
588+
impl_tryfrom_array_for_sliceinfo!(5);
589+
impl_tryfrom_array_for_sliceinfo!(6);
590+
impl_tryfrom_array_for_sliceinfo!(7);
591+
impl_tryfrom_array_for_sliceinfo!(8);
592+
511593
impl<T, Din, Dout> AsRef<[AxisSliceInfo]> for SliceInfo<T, Din, Dout>
512594
where
513595
T: AsRef<[AxisSliceInfo]>,

tests/array.rs

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use ndarray::prelude::*;
1313
use ndarray::{arr3, rcarr2};
1414
use ndarray::indices;
1515
use ndarray::{AxisSliceInfo, Slice, SliceInfo};
16+
use std::convert::TryFrom;
1617

1718
macro_rules! assert_panics {
1819
($body:expr) => {
@@ -216,15 +217,13 @@ fn test_slice_dyninput_array_fixed() {
216217
#[test]
217218
fn test_slice_array_dyn() {
218219
let mut arr = Array3::<f64>::zeros((5, 2, 5));
219-
let info = &unsafe {
220-
SliceInfo::<_, Ix3, IxDyn>::new([
221-
AxisSliceInfo::from(1..),
222-
AxisSliceInfo::from(1),
223-
AxisSliceInfo::from(NewAxis),
224-
AxisSliceInfo::from(..).step_by(2),
225-
])
226-
.unwrap()
227-
};
220+
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([
221+
AxisSliceInfo::from(1..),
222+
AxisSliceInfo::from(1),
223+
AxisSliceInfo::from(NewAxis),
224+
AxisSliceInfo::from(..).step_by(2),
225+
])
226+
.unwrap();
228227
arr.slice(info);
229228
arr.slice_mut(info);
230229
arr.view().slice_move(info);
@@ -234,15 +233,13 @@ fn test_slice_array_dyn() {
234233
#[test]
235234
fn test_slice_dyninput_array_dyn() {
236235
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
237-
let info = &unsafe {
238-
SliceInfo::<_, Ix3, IxDyn>::new([
239-
AxisSliceInfo::from(1..),
240-
AxisSliceInfo::from(1),
241-
AxisSliceInfo::from(NewAxis),
242-
AxisSliceInfo::from(..).step_by(2),
243-
])
244-
.unwrap()
245-
};
236+
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from([
237+
AxisSliceInfo::from(1..),
238+
AxisSliceInfo::from(1),
239+
AxisSliceInfo::from(NewAxis),
240+
AxisSliceInfo::from(..).step_by(2),
241+
])
242+
.unwrap();
246243
arr.slice(info);
247244
arr.slice_mut(info);
248245
arr.view().slice_move(info);
@@ -252,15 +249,13 @@ fn test_slice_dyninput_array_dyn() {
252249
#[test]
253250
fn test_slice_dyninput_vec_fixed() {
254251
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
255-
let info = &unsafe {
256-
SliceInfo::<_, Ix3, Ix3>::new(vec![
257-
AxisSliceInfo::from(1..),
258-
AxisSliceInfo::from(1),
259-
AxisSliceInfo::from(NewAxis),
260-
AxisSliceInfo::from(..).step_by(2),
261-
])
262-
.unwrap()
263-
};
252+
let info = &SliceInfo::<_, Ix3, Ix3>::try_from(vec![
253+
AxisSliceInfo::from(1..),
254+
AxisSliceInfo::from(1),
255+
AxisSliceInfo::from(NewAxis),
256+
AxisSliceInfo::from(..).step_by(2),
257+
])
258+
.unwrap();
264259
arr.slice(info);
265260
arr.slice_mut(info);
266261
arr.view().slice_move(info);
@@ -270,15 +265,13 @@ fn test_slice_dyninput_vec_fixed() {
270265
#[test]
271266
fn test_slice_dyninput_vec_dyn() {
272267
let mut arr = Array3::<f64>::zeros((5, 2, 5)).into_dyn();
273-
let info = &unsafe {
274-
SliceInfo::<_, Ix3, IxDyn>::new(vec![
275-
AxisSliceInfo::from(1..),
276-
AxisSliceInfo::from(1),
277-
AxisSliceInfo::from(NewAxis),
278-
AxisSliceInfo::from(..).step_by(2),
279-
])
280-
.unwrap()
281-
};
268+
let info = &SliceInfo::<_, Ix3, IxDyn>::try_from(vec![
269+
AxisSliceInfo::from(1..),
270+
AxisSliceInfo::from(1),
271+
AxisSliceInfo::from(NewAxis),
272+
AxisSliceInfo::from(..).step_by(2),
273+
])
274+
.unwrap();
282275
arr.slice(info);
283276
arr.slice_mut(info);
284277
arr.view().slice_move(info);

tests/oper.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ fn scaled_add_2() {
562562
fn scaled_add_3() {
563563
use approx::assert_relative_eq;
564564
use ndarray::{SliceInfo, AxisSliceInfo};
565+
use std::convert::TryFrom;
565566

566567
let beta = -2.3;
567568
let sizes = vec![
@@ -595,7 +596,7 @@ fn scaled_add_3() {
595596

596597
{
597598
let mut av = a.slice_mut(s![..;s1, ..;s2]);
598-
let c = c.slice(&unsafe { SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap() });
599+
let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::try_from(cslice).unwrap());
599600

600601
let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
601602
answerv += &(beta * &c);

0 commit comments

Comments
 (0)