Skip to content

Commit 67ec3ad

Browse files
emmatypingbluss
authored andcommitted
Actually use blas with complex numbers
1 parent 1175edb commit 67ec3ad

File tree

2 files changed

+86
-14
lines changed

2 files changed

+86
-14
lines changed

src/linalg/impl_linalg.rs

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,23 @@
77
// except according to those terms.
88

99
use crate::imp_prelude::*;
10-
use crate::numeric_util;
10+
1111
#[cfg(feature = "blas")]
1212
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
13+
use crate::numeric_util;
1314

1415
use crate::{LinalgScalar, Zip};
1516

1617
use std::any::TypeId;
1718
use std::mem::MaybeUninit;
1819
use alloc::vec::Vec;
1920

21+
#[cfg(feature = "blas")]
22+
use libc::c_int;
2023
#[cfg(feature = "blas")]
2124
use std::cmp;
2225
#[cfg(feature = "blas")]
2326
use std::mem::swap;
24-
#[cfg(feature = "blas")]
25-
use libc::c_int;
2627

2728
#[cfg(feature = "blas")]
2829
use cblas_sys as blas_sys;
@@ -377,11 +378,15 @@ fn mat_mul_impl<A>(
377378
) where
378379
A: LinalgScalar,
379380
{
380-
381381
// size cutoff for using BLAS
382382
let cut = GEMM_BLAS_CUTOFF;
383383
let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
384-
if !(m > cut || n > cut || a > cut) || !(same_type::<A, f32>() || same_type::<A, f64>()) {
384+
if !(m > cut || n > cut || a > cut)
385+
|| !(same_type::<A, f32>()
386+
|| same_type::<A, f64>()
387+
|| same_type::<A, c32>()
388+
|| same_type::<A, c64>())
389+
{
385390
return mat_mul_general(alpha, lhs, rhs, beta, c);
386391
}
387392
{
@@ -459,6 +464,53 @@ fn mat_mul_impl<A>(
459464
}
460465
gemm!(f32, cblas_sgemm);
461466
gemm!(f64, cblas_dgemm);
467+
468+
macro_rules! gemm {
469+
($ty:ty, $gemm:ident) => {
470+
if blas_row_major_2d::<$ty, _>(&lhs_)
471+
&& blas_row_major_2d::<$ty, _>(&rhs_)
472+
&& blas_row_major_2d::<$ty, _>(&c_)
473+
{
474+
let (m, k) = match lhs_trans {
475+
CblasNoTrans => lhs_.dim(),
476+
_ => {
477+
let (rows, cols) = lhs_.dim();
478+
(cols, rows)
479+
}
480+
};
481+
let n = match rhs_trans {
482+
CblasNoTrans => rhs_.raw_dim()[1],
483+
_ => rhs_.raw_dim()[0],
484+
};
485+
// adjust strides, these may [1, 1] for column matrices
486+
let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
487+
let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
488+
let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);
489+
490+
// gemm is C ← αA^Op B^Op + βC
491+
// Where Op is notrans/trans/conjtrans
492+
unsafe {
493+
blas_sys::$gemm(
494+
CblasRowMajor,
495+
lhs_trans,
496+
rhs_trans,
497+
m as blas_index, // m, rows of Op(a)
498+
n as blas_index, // n, cols of Op(b)
499+
k as blas_index, // k, cols of Op(a)
500+
&alpha as *const A as *const _, // alpha
501+
lhs_.ptr.as_ptr() as *const _, // a
502+
lhs_stride, // lda
503+
rhs_.ptr.as_ptr() as *const _, // b
504+
rhs_stride, // ldb
505+
&beta as *const A as *const _, // beta
506+
c_.ptr.as_ptr() as *mut _, // c
507+
c_stride, // ldc
508+
);
509+
}
510+
return;
511+
}
512+
};
513+
}
462514
gemm!(c32, cblas_cgemm);
463515
gemm!(c64, cblas_zgemm);
464516
}
@@ -609,9 +661,7 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
609661
S3: DataMut<Elem = A>,
610662
A: LinalgScalar,
611663
{
612-
unsafe {
613-
general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
614-
}
664+
unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
615665
}
616666

617667
/// General matrix-vector multiplication

xtest-blas/tests/oper.rs

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
extern crate approx;
2+
extern crate blas_src;
23
extern crate defmac;
34
extern crate ndarray;
4-
extern crate num_traits;
5-
extern crate blas_src;
65
extern crate num_complex;
6+
extern crate num_traits;
77

88
use ndarray::prelude::*;
99

@@ -275,8 +275,19 @@ fn gemm_c64_1_f() {
275275
let x = range_mat_complex64(n, 1);
276276
let mut y = range_mat_complex64(m, 1);
277277
let answer = reference_mat_mul(&a, &x) + &y;
278-
general_mat_mul(Complex64::new(1.0, 0.), &a, &x, Complex64::new(1.0, 0.), &mut y);
279-
assert_relative_eq!(y.mapv(|i| i.norm_sqr()), answer.mapv(|i| i.norm_sqr()), epsilon = 1e-12, max_relative = 1e-7);
278+
general_mat_mul(
279+
Complex64::new(1.0, 0.),
280+
&a,
281+
&x,
282+
Complex64::new(1.0, 0.),
283+
&mut y,
284+
);
285+
assert_relative_eq!(
286+
y.mapv(|i| i.norm_sqr()),
287+
answer.mapv(|i| i.norm_sqr()),
288+
epsilon = 1e-12,
289+
max_relative = 1e-7
290+
);
280291
}
281292

282293
#[test]
@@ -287,8 +298,19 @@ fn gemm_c32_1_f() {
287298
let x = range_mat_complex(n, 1);
288299
let mut y = range_mat_complex(m, 1);
289300
let answer = reference_mat_mul(&a, &x) + &y;
290-
general_mat_mul(Complex32::new(1.0, 0.), &a, &x, Complex32::new(1.0, 0.), &mut y);
291-
assert_relative_eq!(y.mapv(|i| i.norm_sqr()), answer.mapv(|i| i.norm_sqr()), epsilon = 1e-12, max_relative = 1e-7);
301+
general_mat_mul(
302+
Complex32::new(1.0, 0.),
303+
&a,
304+
&x,
305+
Complex32::new(1.0, 0.),
306+
&mut y,
307+
);
308+
assert_relative_eq!(
309+
y.mapv(|i| i.norm_sqr()),
310+
answer.mapv(|i| i.norm_sqr()),
311+
epsilon = 1e-12,
312+
max_relative = 1e-7
313+
);
292314
}
293315

294316
#[test]

0 commit comments

Comments
 (0)