Skip to content

Commit 8712bac

Browse files
emmatyping authored and bluss committed
Merge complex and real gemm calls
1 parent 67ec3ad commit 8712bac

File tree

1 file changed

+19
-50
lines changed

1 file changed

+19
-50
lines changed

src/linalg/impl_linalg.rs

Lines changed: 19 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,23 @@ fn mat_mul_impl<A>(
416416
rhs_trans = CblasTrans;
417417
}
418418

419+
macro_rules! cast_ty {
420+
(f32, $var:ident) => {
421+
cast_as(&$var)
422+
};
423+
(f64, $var:ident) => {
424+
cast_as(&$var)
425+
};
426+
(c32, $var:ident) => {
427+
&$var as *const A as *const _
428+
};
429+
(c64, $var:ident) => {
430+
&$var as *const A as *const _
431+
};
432+
}
433+
419434
macro_rules! gemm {
420-
($ty:ty, $gemm:ident) => {
435+
($ty:tt, $gemm:ident) => {
421436
if blas_row_major_2d::<$ty, _>(&lhs_)
422437
&& blas_row_major_2d::<$ty, _>(&rhs_)
423438
&& blas_row_major_2d::<$ty, _>(&c_)
@@ -437,9 +452,9 @@ fn mat_mul_impl<A>(
437452
let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
438453
let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
439454
let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);
440-
441455
// gemm is C ← αA^Op B^Op + βC
442456
// Where Op is notrans/trans/conjtrans
457+
443458
unsafe {
444459
blas_sys::$gemm(
445460
CblasRowMajor,
@@ -448,12 +463,12 @@ fn mat_mul_impl<A>(
448463
m as blas_index, // m, rows of Op(a)
449464
n as blas_index, // n, cols of Op(b)
450465
k as blas_index, // k, cols of Op(a)
451-
cast_as(&alpha), // alpha
466+
cast_ty!($ty, alpha), // alpha
452467
lhs_.ptr.as_ptr() as *const _, // a
453468
lhs_stride, // lda
454469
rhs_.ptr.as_ptr() as *const _, // b
455470
rhs_stride, // ldb
456-
cast_as(&beta), // beta
471+
cast_ty!($ty, beta), // beta
457472
c_.ptr.as_ptr() as *mut _, // c
458473
c_stride, // ldc
459474
);
@@ -465,52 +480,6 @@ fn mat_mul_impl<A>(
465480
gemm!(f32, cblas_sgemm);
466481
gemm!(f64, cblas_dgemm);
467482

468-
macro_rules! gemm {
469-
($ty:ty, $gemm:ident) => {
470-
if blas_row_major_2d::<$ty, _>(&lhs_)
471-
&& blas_row_major_2d::<$ty, _>(&rhs_)
472-
&& blas_row_major_2d::<$ty, _>(&c_)
473-
{
474-
let (m, k) = match lhs_trans {
475-
CblasNoTrans => lhs_.dim(),
476-
_ => {
477-
let (rows, cols) = lhs_.dim();
478-
(cols, rows)
479-
}
480-
};
481-
let n = match rhs_trans {
482-
CblasNoTrans => rhs_.raw_dim()[1],
483-
_ => rhs_.raw_dim()[0],
484-
};
485-
// adjust strides, these may [1, 1] for column matrices
486-
let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
487-
let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
488-
let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);
489-
490-
// gemm is C ← αA^Op B^Op + βC
491-
// Where Op is notrans/trans/conjtrans
492-
unsafe {
493-
blas_sys::$gemm(
494-
CblasRowMajor,
495-
lhs_trans,
496-
rhs_trans,
497-
m as blas_index, // m, rows of Op(a)
498-
n as blas_index, // n, cols of Op(b)
499-
k as blas_index, // k, cols of Op(a)
500-
&alpha as *const A as *const _, // alpha
501-
lhs_.ptr.as_ptr() as *const _, // a
502-
lhs_stride, // lda
503-
rhs_.ptr.as_ptr() as *const _, // b
504-
rhs_stride, // ldb
505-
&beta as *const A as *const _, // beta
506-
c_.ptr.as_ptr() as *mut _, // c
507-
c_stride, // ldc
508-
);
509-
}
510-
return;
511-
}
512-
};
513-
}
514483
gemm!(c32, cblas_cgemm);
515484
gemm!(c64, cblas_zgemm);
516485
}

0 commit comments

Comments
 (0)