Skip to content

Commit 01bb218

Browse files
committed
blas: Fix to skip array with too short stride
If we have a matrix of dimension say 5 x 5, BLAS requires the leading stride to be >= 5. Smaller cases are possible for read-only array views in ndarray(broadcasting and custom strides). In this case we mark the array as not BLAS compatible
1 parent 27e347c commit 01bb218

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

src/linalg/impl_linalg.rs

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,7 @@ where
863863

864864
#[cfg(feature = "blas")]
865865
#[derive(Copy, Clone)]
866+
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
866867
enum MemoryOrder
867868
{
868869
C,
@@ -887,24 +888,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
887888
let (m, n) = dim.into_pattern();
888889
let s0 = stride[0] as isize;
889890
let s1 = stride[1] as isize;
890-
let (inner_stride, outer_dim) = match order {
891-
MemoryOrder::C => (s1, n),
892-
MemoryOrder::F => (s0, m),
891+
let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
892+
MemoryOrder::C => (s1, s0, m, n),
893+
MemoryOrder::F => (s0, s1, n, m),
893894
};
895+
894896
if !(inner_stride == 1 || outer_dim == 1) {
895897
return false;
896898
}
899+
897900
if s0 < 1 || s1 < 1 {
898901
return false;
899902
}
903+
900904
if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize)
901905
|| (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize)
902906
{
903907
return false;
904908
}
909+
910+
// leading stride must >= the dimension (no broadcasting/aliasing)
911+
if inner_dim > 1 && (outer_stride as usize) < outer_dim {
912+
return false;
913+
}
914+
905915
if m > blas_index::MAX as usize || n > blas_index::MAX as usize {
906916
return false;
907917
}
918+
908919
true
909920
}
910921

@@ -1068,8 +1079,26 @@ mod blas_tests
10681079
}
10691080

10701081
#[test]
1071-
fn test()
1082+
fn blas_too_short_stride()
10721083
{
1073-
//WIP test that stride is larger than other dimension
1084+
// leading stride must be longer than the other dimension
1085+
// Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS.
1086+
1087+
const N: usize = 5;
1088+
const MAXSTRIDE: usize = N + 2;
1089+
let mut data = [0; MAXSTRIDE * N];
1090+
let mut iter = 0..data.len();
1091+
data.fill_with(|| iter.next().unwrap());
1092+
1093+
for stride in 1..=MAXSTRIDE {
1094+
let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();
1095+
eprintln!("{:?}", m);
1096+
1097+
if stride < N {
1098+
assert_eq!(get_blas_compatible_layout(&m), None);
1099+
} else {
1100+
assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C));
1101+
}
1102+
}
10741103
}
10751104
}

0 commit comments

Comments
 (0)