8
8
use crate :: dimension:: slices_intersect;
9
9
use crate :: error:: { ErrorKind , ShapeError } ;
10
10
use crate :: { ArrayViewMut , DimAdd , Dimension , Ix0 , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
11
+ use alloc:: vec:: Vec ;
12
+ use std:: convert:: TryFrom ;
11
13
use std:: fmt;
12
14
use std:: marker:: PhantomData ;
13
15
use std:: ops:: { Deref , Range , RangeFrom , RangeFull , RangeInclusive , RangeTo , RangeToInclusive } ;
@@ -402,6 +404,24 @@ where
402
404
}
403
405
}
404
406
407
+ fn check_dims_for_sliceinfo < Din , Dout > ( indices : & [ AxisSliceInfo ] ) -> Result < ( ) , ShapeError >
408
+ where
409
+ Din : Dimension ,
410
+ Dout : Dimension ,
411
+ {
412
+ if let Some ( in_ndim) = Din :: NDIM {
413
+ if in_ndim != indices. in_ndim ( ) {
414
+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
415
+ }
416
+ }
417
+ if let Some ( out_ndim) = Dout :: NDIM {
418
+ if out_ndim != indices. out_ndim ( ) {
419
+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
420
+ }
421
+ }
422
+ Ok ( ( ) )
423
+ }
424
+
405
425
impl < T , Din , Dout > SliceInfo < T , Din , Dout >
406
426
where
407
427
T : AsRef < [ AxisSliceInfo ] > ,
@@ -424,12 +444,8 @@ where
424
444
out_dim : PhantomData < Dout > ,
425
445
) -> SliceInfo < T , Din , Dout > {
426
446
if cfg ! ( debug_assertions) {
427
- if let Some ( in_ndim) = Din :: NDIM {
428
- assert_eq ! ( in_ndim, indices. as_ref( ) . in_ndim( ) ) ;
429
- }
430
- if let Some ( out_ndim) = Dout :: NDIM {
431
- assert_eq ! ( out_ndim, indices. as_ref( ) . out_ndim( ) ) ;
432
- }
447
+ check_dims_for_sliceinfo :: < Din , Dout > ( indices. as_ref ( ) )
448
+ . expect ( "`Din` and `Dout` must be consistent with `indices`." ) ;
433
449
}
434
450
SliceInfo {
435
451
in_dim,
@@ -449,21 +465,14 @@ where
449
465
///
450
466
/// Errors if `Din` or `Dout` is not consistent with `indices`.
451
467
///
468
+ /// For common types, a safe alternative is to use `TryFrom` instead.
469
+ ///
452
470
/// # Safety
453
471
///
454
472
/// The caller must ensure `indices.as_ref()` always returns the same value
455
473
/// when called multiple times.
456
474
pub unsafe fn new ( indices : T ) -> Result < SliceInfo < T , Din , Dout > , ShapeError > {
457
- if let Some ( in_ndim) = Din :: NDIM {
458
- if in_ndim != indices. as_ref ( ) . in_ndim ( ) {
459
- return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
460
- }
461
- }
462
- if let Some ( out_ndim) = Dout :: NDIM {
463
- if out_ndim != indices. as_ref ( ) . out_ndim ( ) {
464
- return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
465
- }
466
- }
475
+ check_dims_for_sliceinfo :: < Din , Dout > ( indices. as_ref ( ) ) ?;
467
476
Ok ( SliceInfo {
468
477
in_dim : PhantomData ,
469
478
out_dim : PhantomData ,
@@ -508,6 +517,79 @@ where
508
517
}
509
518
}
510
519
520
+ impl < ' a , Din , Dout > TryFrom < & ' a [ AxisSliceInfo ] > for & ' a SliceInfo < [ AxisSliceInfo ] , Din , Dout >
521
+ where
522
+ Din : Dimension ,
523
+ Dout : Dimension ,
524
+ {
525
+ type Error = ShapeError ;
526
+
527
+ fn try_from (
528
+ indices : & ' a [ AxisSliceInfo ] ,
529
+ ) -> Result < & ' a SliceInfo < [ AxisSliceInfo ] , Din , Dout > , ShapeError > {
530
+ check_dims_for_sliceinfo :: < Din , Dout > ( indices) ?;
531
+ unsafe {
532
+ // This is okay because we've already checked the correctness of
533
+ // `Din` and `Dout`, and the only non-zero-sized member of
534
+ // `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din,
535
+ // Dout>` should have the same bitwise representation as
536
+ // `&[AxisSliceInfo]`.
537
+ Ok ( & * ( indices as * const [ AxisSliceInfo ]
538
+ as * const SliceInfo < [ AxisSliceInfo ] , Din , Dout > ) )
539
+ }
540
+ }
541
+ }
542
+
543
+ impl < Din , Dout > TryFrom < Vec < AxisSliceInfo > > for SliceInfo < Vec < AxisSliceInfo > , Din , Dout >
544
+ where
545
+ Din : Dimension ,
546
+ Dout : Dimension ,
547
+ {
548
+ type Error = ShapeError ;
549
+
550
+ fn try_from (
551
+ indices : Vec < AxisSliceInfo > ,
552
+ ) -> Result < SliceInfo < Vec < AxisSliceInfo > , Din , Dout > , ShapeError > {
553
+ unsafe {
554
+ // This is okay because `Vec` always returns the same value for
555
+ // `.as_ref()`.
556
+ Self :: new ( indices)
557
+ }
558
+ }
559
+ }
560
+
561
+ macro_rules! impl_tryfrom_array_for_sliceinfo {
562
+ ( $len: expr) => {
563
+ impl <Din , Dout > TryFrom <[ AxisSliceInfo ; $len] >
564
+ for SliceInfo <[ AxisSliceInfo ; $len] , Din , Dout >
565
+ where
566
+ Din : Dimension ,
567
+ Dout : Dimension ,
568
+ {
569
+ type Error = ShapeError ;
570
+
571
+ fn try_from(
572
+ indices: [ AxisSliceInfo ; $len] ,
573
+ ) -> Result <SliceInfo <[ AxisSliceInfo ; $len] , Din , Dout >, ShapeError > {
574
+ unsafe {
575
+ // This is okay because `[AxisSliceInfo; N]` always returns
576
+ // the same value for `.as_ref()`.
577
+ Self :: new( indices)
578
+ }
579
+ }
580
+ }
581
+ } ;
582
+ }
583
+ impl_tryfrom_array_for_sliceinfo ! ( 0 ) ;
584
+ impl_tryfrom_array_for_sliceinfo ! ( 1 ) ;
585
+ impl_tryfrom_array_for_sliceinfo ! ( 2 ) ;
586
+ impl_tryfrom_array_for_sliceinfo ! ( 3 ) ;
587
+ impl_tryfrom_array_for_sliceinfo ! ( 4 ) ;
588
+ impl_tryfrom_array_for_sliceinfo ! ( 5 ) ;
589
+ impl_tryfrom_array_for_sliceinfo ! ( 6 ) ;
590
+ impl_tryfrom_array_for_sliceinfo ! ( 7 ) ;
591
+ impl_tryfrom_array_for_sliceinfo ! ( 8 ) ;
592
+
511
593
impl < T , Din , Dout > AsRef < [ AxisSliceInfo ] > for SliceInfo < T , Din , Dout >
512
594
where
513
595
T : AsRef < [ AxisSliceInfo ] > ,
0 commit comments