 // except according to those terms.
 
 use crate::imp_prelude::*;
-use crate::numeric_util;
+
 #[cfg(feature = "blas")]
 use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
+use crate::numeric_util;
 
 use crate::{LinalgScalar, Zip};
 
 use std::any::TypeId;
 use std::mem::MaybeUninit;
 use alloc::vec::Vec;
 
+#[cfg(feature = "blas")]
+use libc::c_int;
 #[cfg(feature = "blas")]
 use std::cmp;
 #[cfg(feature = "blas")]
 use std::mem::swap;
-#[cfg(feature = "blas")]
-use libc::c_int;
 
 #[cfg(feature = "blas")]
 use cblas_sys as blas_sys;
@@ -377,11 +378,15 @@ fn mat_mul_impl<A>(
 ) where
     A: LinalgScalar,
 {
-
     // size cutoff for using BLAS
     let cut = GEMM_BLAS_CUTOFF;
     let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
-    if !(m > cut || n > cut || a > cut) || !(same_type::<A, f32>() || same_type::<A, f64>()) {
+    if !(m > cut || n > cut || a > cut)
+        || !(same_type::<A, f32>()
+            || same_type::<A, f64>()
+            || same_type::<A, c32>()
+            || same_type::<A, c64>())
+    {
         return mat_mul_general(alpha, lhs, rhs, beta, c);
     }
     {
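
A side note on the dispatch above: `same_type` compares `TypeId`s at runtime, and for concrete element types the comparison can be resolved to a constant, so only `f32`/`f64`/`c32`/`c64` keep the BLAS branch. A minimal standalone sketch of that pattern (an illustration of the technique, not necessarily the crate's exact private helper):

```rust
use std::any::TypeId;

// Compare two 'static types for identity at runtime; with concrete type
// parameters this typically folds to a compile-time constant, so the BLAS
// branch can be optimized away for non-matching element types.
fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}

fn main() {
    assert!(same_type::<f64, f64>());
    assert!(!same_type::<f32, f64>());
}
```
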
@@ -459,6 +464,53 @@ fn mat_mul_impl<A>(
         }
         gemm!(f32, cblas_sgemm);
         gemm!(f64, cblas_dgemm);
+
+        macro_rules! gemm {
+            ($ty:ty, $gemm:ident) => {
+                if blas_row_major_2d::<$ty, _>(&lhs_)
+                    && blas_row_major_2d::<$ty, _>(&rhs_)
+                    && blas_row_major_2d::<$ty, _>(&c_)
+                {
+                    let (m, k) = match lhs_trans {
+                        CblasNoTrans => lhs_.dim(),
+                        _ => {
+                            let (rows, cols) = lhs_.dim();
+                            (cols, rows)
+                        }
+                    };
+                    let n = match rhs_trans {
+                        CblasNoTrans => rhs_.raw_dim()[1],
+                        _ => rhs_.raw_dim()[0],
+                    };
+                    // adjust strides; these may be [1, 1] for column matrices
+                    let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
+                    let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
+                    let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);
+
+                    // gemm is C ← αA^Op B^Op + βC
+                    // where Op is notrans/trans/conjtrans
+                    unsafe {
+                        blas_sys::$gemm(
+                            CblasRowMajor,
+                            lhs_trans,
+                            rhs_trans,
+                            m as blas_index,                // m, rows of Op(a)
+                            n as blas_index,                // n, cols of Op(b)
+                            k as blas_index,                // k, cols of Op(a)
+                            &alpha as *const A as *const _, // alpha
+                            lhs_.ptr.as_ptr() as *const _,  // a
+                            lhs_stride,                     // lda
+                            rhs_.ptr.as_ptr() as *const _,  // b
+                            rhs_stride,                     // ldb
+                            &beta as *const A as *const _,  // beta
+                            c_.ptr.as_ptr() as *mut _,      // c
+                            c_stride,                       // ldc
+                        );
+                    }
+                    return;
+                }
+            };
+        }
         gemm!(c32, cblas_cgemm);
         gemm!(c64, cblas_zgemm);
     }
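
With the shadowed `gemm!` macro above in place, complex matrix products of sufficient size can be routed to `cblas_cgemm`/`cblas_zgemm`. A hedged usage sketch of the public entry point (assuming `ndarray` built with the `blas` feature and a BLAS backend linked, plus the `num-complex` crate; without the feature the same call still works through the `mat_mul_general` fallback):

```rust
use ndarray::linalg::general_mat_mul;
use ndarray::Array2;
use num_complex::Complex64;

fn main() {
    // Intended to be large enough to pass the GEMM_BLAS_CUTOFF size check.
    let n = 64;
    let a = Array2::from_elem((n, n), Complex64::new(1.0, 2.0));
    let b = Array2::from_elem((n, n), Complex64::new(0.5, -1.0));
    let mut c = Array2::from_elem((n, n), Complex64::new(0.0, 0.0));

    // c = alpha * a.dot(&b) + beta * c
    general_mat_mul(
        Complex64::new(1.0, 0.0),
        &a,
        &b,
        Complex64::new(0.0, 0.0),
        &mut c,
    );

    // (1 + 2i)(0.5 - 1i) = 2.5, summed over n inner-product terms
    let expected = Complex64::new(2.5 * n as f64, 0.0);
    assert!((c[(0, 0)] - expected).norm() < 1e-9);
}
```
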
@@ -609,9 +661,7 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
     S3: DataMut<Elem = A>,
     A: LinalgScalar,
 {
-    unsafe {
-        general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut())
-    }
+    unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) }
 }
 
 /// General matrix-vector multiplication
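
For completeness, the safe wrapper reformatted in the last hunk is called as below (a small sketch with made-up values; the `blas` feature only changes the internal dispatch, not this API):

```rust
use ndarray::linalg::general_mat_vec_mul;
use ndarray::{arr1, arr2};

fn main() {
    let a = arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]);
    let x = arr1(&[1.0_f64, 1.0]);
    let mut y = arr1(&[10.0_f64, 10.0]);

    // y = alpha * a.dot(&x) + beta * y  =>  [3, 7] + 0.5 * [10, 10]
    general_mat_vec_mul(1.0, &a, &x, 0.5, &mut y);

    assert_eq!(y, arr1(&[8.0, 12.0]));
}
```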