@@ -863,6 +863,7 @@ where
863
863
864
864
#[ cfg( feature = "blas" ) ]
865
865
#[ derive( Copy , Clone ) ]
866
+ #[ cfg_attr( test, derive( PartialEq , Eq , Debug ) ) ]
866
867
enum MemoryOrder
867
868
{
868
869
C ,
@@ -887,24 +888,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
887
888
let ( m, n) = dim. into_pattern ( ) ;
888
889
let s0 = stride[ 0 ] as isize ;
889
890
let s1 = stride[ 1 ] as isize ;
890
- let ( inner_stride, outer_dim) = match order {
891
- MemoryOrder :: C => ( s1, n) ,
892
- MemoryOrder :: F => ( s0, m) ,
891
+ let ( inner_stride, outer_stride , inner_dim , outer_dim) = match order {
892
+ MemoryOrder :: C => ( s1, s0 , m , n) ,
893
+ MemoryOrder :: F => ( s0, s1 , n , m) ,
893
894
} ;
895
+
894
896
if !( inner_stride == 1 || outer_dim == 1 ) {
895
897
return false ;
896
898
}
899
+
897
900
if s0 < 1 || s1 < 1 {
898
901
return false ;
899
902
}
903
+
900
904
if ( s0 > blas_index:: MAX as isize || s0 < blas_index:: MIN as isize )
901
905
|| ( s1 > blas_index:: MAX as isize || s1 < blas_index:: MIN as isize )
902
906
{
903
907
return false ;
904
908
}
909
+
910
+ // leading stride must >= the dimension (no broadcasting/aliasing)
911
+ if inner_dim > 1 && ( outer_stride as usize ) < outer_dim {
912
+ return false ;
913
+ }
914
+
905
915
if m > blas_index:: MAX as usize || n > blas_index:: MAX as usize {
906
916
return false ;
907
917
}
918
+
908
919
true
909
920
}
910
921
@@ -1068,8 +1079,26 @@ mod blas_tests
1068
1079
}
1069
1080
1070
1081
#[ test]
1071
- fn test ( )
1082
+ fn blas_too_short_stride ( )
1072
1083
{
1073
- //WIP test that stride is larger than other dimension
1084
+ // leading stride must be longer than the other dimension
1085
+ // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS.
1086
+
1087
+ const N : usize = 5 ;
1088
+ const MAXSTRIDE : usize = N + 2 ;
1089
+ let mut data = [ 0 ; MAXSTRIDE * N ] ;
1090
+ let mut iter = 0 ..data. len ( ) ;
1091
+ data. fill_with ( || iter. next ( ) . unwrap ( ) ) ;
1092
+
1093
+ for stride in 1 ..=MAXSTRIDE {
1094
+ let m = ArrayView :: from_shape ( ( N , N ) . strides ( ( stride, 1 ) ) , & data) . unwrap ( ) ;
1095
+ eprintln ! ( "{:?}" , m) ;
1096
+
1097
+ if stride < N {
1098
+ assert_eq ! ( get_blas_compatible_layout( & m) , None ) ;
1099
+ } else {
1100
+ assert_eq ! ( get_blas_compatible_layout( & m) , Some ( MemoryOrder :: C ) ) ;
1101
+ }
1102
+ }
1074
1103
}
1075
1104
}
0 commit comments