@@ -416,8 +416,23 @@ fn mat_mul_impl<A>(
416
416
rhs_trans = CblasTrans ;
417
417
}
418
418
419
+ macro_rules! cast_ty {
420
+ ( f32 , $var: ident) => {
421
+ cast_as( & $var)
422
+ } ;
423
+ ( f64 , $var: ident) => {
424
+ cast_as( & $var)
425
+ } ;
426
+ ( c32, $var: ident) => {
427
+ & $var as * const A as * const _
428
+ } ;
429
+ ( c64, $var: ident) => {
430
+ & $var as * const A as * const _
431
+ } ;
432
+ }
433
+
419
434
macro_rules! gemm {
420
- ( $ty: ty , $gemm: ident) => {
435
+ ( $ty: tt , $gemm: ident) => {
421
436
if blas_row_major_2d:: <$ty, _>( & lhs_)
422
437
&& blas_row_major_2d:: <$ty, _>( & rhs_)
423
438
&& blas_row_major_2d:: <$ty, _>( & c_)
@@ -437,9 +452,9 @@ fn mat_mul_impl<A>(
437
452
let lhs_stride = cmp:: max( lhs_. strides( ) [ 0 ] as blas_index, k as blas_index) ;
438
453
let rhs_stride = cmp:: max( rhs_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
439
454
let c_stride = cmp:: max( c_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
440
-
441
455
// gemm is C ← αA^Op B^Op + βC
442
456
// Where Op is notrans/trans/conjtrans
457
+
443
458
unsafe {
444
459
blas_sys:: $gemm(
445
460
CblasRowMajor ,
@@ -448,12 +463,12 @@ fn mat_mul_impl<A>(
448
463
m as blas_index, // m, rows of Op(a)
449
464
n as blas_index, // n, cols of Op(b)
450
465
k as blas_index, // k, cols of Op(a)
451
- cast_as ( & alpha) , // alpha
466
+ cast_ty! ( $ty , alpha) , // alpha
452
467
lhs_. ptr. as_ptr( ) as * const _, // a
453
468
lhs_stride, // lda
454
469
rhs_. ptr. as_ptr( ) as * const _, // b
455
470
rhs_stride, // ldb
456
- cast_as ( & beta) , // beta
471
+ cast_ty! ( $ty , beta) , // beta
457
472
c_. ptr. as_ptr( ) as * mut _, // c
458
473
c_stride, // ldc
459
474
) ;
@@ -465,52 +480,6 @@ fn mat_mul_impl<A>(
465
480
gemm ! ( f32 , cblas_sgemm) ;
466
481
gemm ! ( f64 , cblas_dgemm) ;
467
482
468
- macro_rules! gemm {
469
- ( $ty: ty, $gemm: ident) => {
470
- if blas_row_major_2d:: <$ty, _>( & lhs_)
471
- && blas_row_major_2d:: <$ty, _>( & rhs_)
472
- && blas_row_major_2d:: <$ty, _>( & c_)
473
- {
474
- let ( m, k) = match lhs_trans {
475
- CblasNoTrans => lhs_. dim( ) ,
476
- _ => {
477
- let ( rows, cols) = lhs_. dim( ) ;
478
- ( cols, rows)
479
- }
480
- } ;
481
- let n = match rhs_trans {
482
- CblasNoTrans => rhs_. raw_dim( ) [ 1 ] ,
483
- _ => rhs_. raw_dim( ) [ 0 ] ,
484
- } ;
485
- // adjust strides, these may [1, 1] for column matrices
486
- let lhs_stride = cmp:: max( lhs_. strides( ) [ 0 ] as blas_index, k as blas_index) ;
487
- let rhs_stride = cmp:: max( rhs_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
488
- let c_stride = cmp:: max( c_. strides( ) [ 0 ] as blas_index, n as blas_index) ;
489
-
490
- // gemm is C ← αA^Op B^Op + βC
491
- // Where Op is notrans/trans/conjtrans
492
- unsafe {
493
- blas_sys:: $gemm(
494
- CblasRowMajor ,
495
- lhs_trans,
496
- rhs_trans,
497
- m as blas_index, // m, rows of Op(a)
498
- n as blas_index, // n, cols of Op(b)
499
- k as blas_index, // k, cols of Op(a)
500
- & alpha as * const A as * const _, // alpha
501
- lhs_. ptr. as_ptr( ) as * const _, // a
502
- lhs_stride, // lda
503
- rhs_. ptr. as_ptr( ) as * const _, // b
504
- rhs_stride, // ldb
505
- & beta as * const A as * const _, // beta
506
- c_. ptr. as_ptr( ) as * mut _, // c
507
- c_stride, // ldc
508
- ) ;
509
- }
510
- return ;
511
- }
512
- } ;
513
- }
514
483
gemm ! ( c32, cblas_cgemm) ;
515
484
gemm ! ( c64, cblas_zgemm) ;
516
485
}
0 commit comments