Skip to content

Commit 7776bfc

Browse files
jturner314bluss
authored andcommitted
Implement CanSlice<IxDyn> for [AxisSliceInfo]
1 parent c66ad8c commit 7776bfc

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

blas-tests/tests/oper.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ extern crate num_traits;
66
use ndarray::linalg::general_mat_mul;
77
use ndarray::linalg::general_mat_vec_mul;
88
use ndarray::prelude::*;
9-
use ndarray::{AxisSliceInfo, Ix, Ixs, SliceInfo};
9+
use ndarray::{AxisSliceInfo, Ix, Ixs};
1010
use ndarray::{Data, LinalgScalar};
1111

1212
use approx::{assert_abs_diff_eq, assert_relative_eq};
@@ -432,7 +432,7 @@ fn scaled_add_3() {
432432

433433
{
434434
let mut av = a.slice_mut(s![..;s1, ..;s2]);
435-
let c = c.slice(&SliceInfo::<_, IxDyn, IxDyn>::new(cslice).unwrap());
435+
let c = c.slice(&*cslice);
436436

437437
let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
438438
answerv += &(beta * &c);

src/impl_methods.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ where
343343
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
344344
pub fn slice<I>(&self, info: &I) -> ArrayView<'_, A, I::OutDim>
345345
where
346-
I: CanSlice<D>,
346+
I: CanSlice<D> + ?Sized,
347347
S: Data,
348348
{
349349
self.view().slice_move(info)
@@ -361,7 +361,7 @@ where
361361
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
362362
pub fn slice_mut<I>(&mut self, info: &I) -> ArrayViewMut<'_, A, I::OutDim>
363363
where
364-
I: CanSlice<D>,
364+
I: CanSlice<D> + ?Sized,
365365
S: DataMut,
366366
{
367367
self.view_mut().slice_move(info)
@@ -412,7 +412,7 @@ where
412412
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
413413
pub fn slice_move<I>(mut self, info: &I) -> ArrayBase<S, I::OutDim>
414414
where
415-
I: CanSlice<D>,
415+
I: CanSlice<D> + ?Sized,
416416
{
417417
// Slice and collapse in-place without changing the number of dimensions.
418418
self.slice_collapse(info);
@@ -464,7 +464,7 @@ where
464464
/// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.)
465465
pub fn slice_collapse<I>(&mut self, info: &I)
466466
where
467-
I: CanSlice<D>,
467+
I: CanSlice<D> + ?Sized,
468468
{
469469
assert_eq!(
470470
info.in_ndim(),

src/slice.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,18 @@ where
360360
}
361361
}
362362

363+
unsafe impl CanSlice<IxDyn> for [AxisSliceInfo] {
364+
type OutDim = IxDyn;
365+
366+
fn in_ndim(&self) -> usize {
367+
self.iter().filter(|s| !s.is_new_axis()).count()
368+
}
369+
370+
fn out_ndim(&self) -> usize {
371+
self.iter().filter(|s| !s.is_index()).count()
372+
}
373+
}
374+
363375
/// Represents all of the necessary information to perform a slice.
364376
///
365377
/// The type `T` is typically `[AxisSliceInfo; n]`, `[AxisSliceInfo]`, or
@@ -422,13 +434,13 @@ where
422434
///
423435
/// Errors if `Din` or `Dout` is not consistent with `indices`.
424436
pub fn new(indices: T) -> Result<SliceInfo<T, Din, Dout>, ShapeError> {
425-
if let Some(ndim) = Din::NDIM {
426-
if ndim != indices.as_ref().iter().filter(|s| !s.is_new_axis()).count() {
437+
if let Some(in_ndim) = Din::NDIM {
438+
if in_ndim != indices.as_ref().in_ndim() {
427439
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
428440
}
429441
}
430-
if let Some(ndim) = Dout::NDIM {
431-
if ndim != indices.as_ref().iter().filter(|s| !s.is_index()).count() {
442+
if let Some(out_ndim) = Dout::NDIM {
443+
if out_ndim != indices.as_ref().out_ndim() {
432444
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape));
433445
}
434446
}
@@ -456,11 +468,7 @@ where
456468
if let Some(ndim) = Din::NDIM {
457469
ndim
458470
} else {
459-
self.indices
460-
.as_ref()
461-
.iter()
462-
.filter(|s| !s.is_new_axis())
463-
.count()
471+
self.indices.as_ref().in_ndim()
464472
}
465473
}
466474

@@ -475,11 +483,7 @@ where
475483
if let Some(ndim) = Dout::NDIM {
476484
ndim
477485
} else {
478-
self.indices
479-
.as_ref()
480-
.iter()
481-
.filter(|s| !s.is_index())
482-
.count()
486+
self.indices.as_ref().out_ndim()
483487
}
484488
}
485489
}

0 commit comments

Comments
 (0)