Skip to content

Commit 267063b

Browse files
jturner314 and bluss committed
Fix blas usage with vector of stride <= 0
Stride == 0 is unsupported for vector increments; stride < 0 would be supported, but the code needs to be adapted to pass the right pointer for this case (the lowest-in-memory pointer). Co-authored-by: bluss <[email protected]>
1 parent abfdf4d commit 267063b

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/linalg/impl_linalg.rs

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88

99
use crate::imp_prelude::*;
1010
use crate::numeric_util;
11+
#[cfg(feature = "blas")]
12+
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
1113

1214
use crate::{LinalgScalar, Zip};
1315

@@ -649,6 +651,12 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
649651
}
650652
};
651653

654+
// Low addr in memory pointers required for x, y
655+
let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
656+
let x_ptr = x.ptr.as_ptr().sub(x_offset);
657+
let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides);
658+
let y_ptr = y.ptr.as_ptr().sub(y_offset);
659+
652660
let x_stride = x.strides()[0] as blas_index;
653661
let y_stride = y.strides()[0] as blas_index;
654662

@@ -660,10 +668,10 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
660668
cast_as(&alpha), // alpha
661669
a.ptr.as_ptr() as *const _, // a
662670
a_stride, // lda
663-
x.ptr.as_ptr() as *const _, // x
671+
x_ptr as *const _, // x
664672
x_stride,
665-
cast_as(&beta), // beta
666-
y.ptr.as_ptr() as *mut _, // x
673+
cast_as(&beta), // beta
674+
y_ptr as *mut _, // y
667675
y_stride,
668676
);
669677
return;
@@ -719,7 +727,10 @@ where
719727
return false;
720728
}
721729
let stride = a.strides()[0];
722-
if stride > blas_index::max_value() as isize || stride < blas_index::min_value() as isize {
730+
if stride == 0
731+
|| stride > blas_index::max_value() as isize
732+
|| stride < blas_index::min_value() as isize
733+
{
723734
return false;
724735
}
725736
true

xtest-blas/tests/oper.rs

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,25 @@ fn mat_vec_product_1d() {
2121
assert_eq!(a.t().dot(&b), ans);
2222
}
2323

24+
#[test]
25+
fn mat_vec_product_1d_broadcast() {
26+
let a = arr2(&[[1.], [2.], [3.]]);
27+
let b = arr1(&[1.]);
28+
let b = b.broadcast(3).unwrap();
29+
let ans = arr1(&[6.]);
30+
assert_eq!(a.t().dot(&b), ans);
31+
}
32+
33+
#[test]
34+
fn mat_vec_product_1d_inverted_axis() {
35+
let a = arr2(&[[1.], [2.], [3.]]);
36+
let mut b = arr1(&[1., 2., 3.]);
37+
b.invert_axis(Axis(0));
38+
39+
let ans = arr1(&[3. + 4. + 3.]);
40+
assert_eq!(a.t().dot(&b), ans);
41+
}
42+
2443
fn range_mat(m: Ix, n: Ix) -> Array2<f32> {
2544
Array::linspace(0., (m * n) as f32 - 1., m * n)
2645
.into_shape((m, n))

0 commit comments

Comments
 (0)