@@ -79,6 +79,47 @@ static const npy_cfloat oneF = 1.0f, zeroF = 0.0f;
79
79
* #step1 = 1.F, 1., &oneF, &oneD#
80
80
* #step0 = 0.F, 0., &zeroF, &zeroD#
81
81
*/
82
+
83
+ static inline void
84
+ @name @_matrix_copy (npy_bool transpose ,
85
+ void * _ip , npy_intp is_m , npy_intp is_n ,
86
+ void * _op , npy_intp os_m , npy_intp os_n ,
87
+ npy_intp dm , npy_intp dn )
88
+ {
89
+
90
+ char * ip = (char * )_ip , * op = (char * )_op ;
91
+
92
+ npy_intp m , n , ib , ob ;
93
+
94
+ if (transpose ) {
95
+ ib = is_m * dm , ob = os_m * dm ;
96
+
97
+ for (n = 0 ; n < dn ; n ++ ) {
98
+ for (m = 0 ; m < dm ; m ++ ) {
99
+ * (@ctype @ * )op = * (@ctype @ * )ip ;
100
+ ip += is_m ;
101
+ op += os_m ;
102
+ }
103
+ ip += is_n - ib ;
104
+ op += os_n - ob ;
105
+ }
106
+
107
+ return ;
108
+ }
109
+
110
+ ib = is_n * dn , ob = os_n * dn ;
111
+
112
+ for (m = 0 ; m < dm ; m ++ ) {
113
+ for (n = 0 ; n < dn ; n ++ ) {
114
+ * (@ctype @ * )op = * (@ctype @ * )ip ;
115
+ ip += is_n ;
116
+ op += os_n ;
117
+ }
118
+ ip += is_m - ib ;
119
+ op += os_m - ob ;
120
+ }
121
+ }
122
+
82
123
NPY_NO_EXPORT void
83
124
@name @_gemv (void * ip1 , npy_intp is1_m , npy_intp is1_n ,
84
125
void * ip2 , npy_intp is2_n ,
@@ -429,10 +470,43 @@ NPY_NO_EXPORT void
429
470
npy_bool i2blasable = i2_c_blasable || i2_f_blasable ;
430
471
npy_bool o_c_blasable = is_blasable2d (os_m , os_p , dm , dp , sz );
431
472
npy_bool o_f_blasable = is_blasable2d (os_p , os_m , dp , dm , sz );
473
+ npy_bool oblasable = o_c_blasable || o_f_blasable ;
432
474
npy_bool vector_matrix = ((dm == 1 ) && i2blasable &&
433
475
is_blasable2d (is1_n , sz , dn , 1 , sz ));
434
476
npy_bool matrix_vector = ((dp == 1 ) && i1blasable &&
435
477
is_blasable2d (is2_n , sz , dn , 1 , sz ));
478
+ npy_bool noblas_fallback = too_big_for_blas || any_zero_dim ;
479
+ npy_bool matrix_matrix = !noblas_fallback && !special_case ;
480
+ npy_bool allocate_buffer = matrix_matrix && (
481
+ !i1blasable || !i2blasable || !oblasable
482
+ );
483
+
484
+ uint8_t * tmp_ip12op = NULL ;
485
+ void * tmp_ip1 = NULL , * tmp_ip2 = NULL , * tmp_op = NULL ;
486
+
487
+ if (allocate_buffer ){
488
+ npy_intp ip1_size = i1blasable ? 0 : sz * dm * dn ,
489
+ ip2_size = i2blasable ? 0 : sz * dn * dp ,
490
+ op_size = oblasable ? 0 : sz * dm * dp ,
491
+ total_size = ip1_size + ip2_size + op_size ;
492
+
493
+ tmp_ip12op = (uint8_t * )malloc (total_size );
494
+
495
+ if (tmp_ip12op == NULL ) {
496
+ PyGILState_STATE gil_state = PyGILState_Ensure ();
497
+ PyErr_SetString (
498
+ PyExc_MemoryError , "Out of memory in matmul"
499
+ );
500
+ PyGILState_Release (gil_state );
501
+
502
+ return ;
503
+ }
504
+
505
+ tmp_ip1 = tmp_ip12op ;
506
+ tmp_ip2 = tmp_ip12op + ip1_size ;
507
+ tmp_op = tmp_ip12op + ip1_size + ip2_size ;
508
+ }
509
+
436
510
#endif
437
511
438
512
for (iOuter = 0 ; iOuter < dOuter ; iOuter ++ ,
@@ -444,7 +518,7 @@ NPY_NO_EXPORT void
444
518
* PyUFunc_MatmulLoopSelector. But that call does not have access to
445
519
* n, m, p and strides.
446
520
*/
447
- if (too_big_for_blas || any_zero_dim ) {
521
+ if (noblas_fallback ) {
448
522
@TYPE @_matmul_inner_noblas (ip1 , is1_m , is1_n ,
449
523
ip2 , is2_n , is2_p ,
450
524
op , os_m , os_p , dm , dn , dp );
@@ -478,30 +552,73 @@ NPY_NO_EXPORT void
478
552
op , os_m , os_p , dm , dn , dp );
479
553
}
480
554
} else {
481
- /* matrix @ matrix */
482
- if (i1blasable && i2blasable && o_c_blasable ) {
483
- @TYPE @_matmul_matrixmatrix (ip1 , is1_m , is1_n ,
484
- ip2 , is2_n , is2_p ,
485
- op , os_m , os_p ,
486
- dm , dn , dp );
487
- } else if (i1blasable && i2blasable && o_f_blasable ) {
488
- /*
489
- * Use transpose equivalence:
490
- * matmul(a, b, o) == matmul(b.T, a.T, o.T)
491
- */
492
- @TYPE @_matmul_matrixmatrix (ip2 , is2_p , is2_n ,
493
- ip1 , is1_n , is1_m ,
494
- op , os_p , os_m ,
495
- dp , dn , dm );
496
- } else {
497
- /*
498
- * If parameters are castable to int and we copy the
499
- * non-blasable (or non-ccontiguous output)
500
- * we could still use BLAS, see gh-12365.
501
- */
502
- @TYPE @_matmul_inner_noblas (ip1 , is1_m , is1_n ,
503
- ip2 , is2_n , is2_p ,
504
- op , os_m , os_p , dm , dn , dp );
555
+ /* matrix @ matrix
556
+ * copy if not blasable, see gh-12365 & gh-23588 */
557
+ npy_bool i1_transpose = is1_m < is1_n ,
558
+ i2_transpose = is2_n < is2_p ,
559
+ o_transpose = os_m < os_p ;
560
+
561
+ npy_intp tmp_is1_m = i1_transpose ? sz : sz * dn ,
562
+ tmp_is1_n = i1_transpose ? sz * dm : sz ,
563
+ tmp_is2_n = i2_transpose ? sz : sz * dp ,
564
+ tmp_is2_p = i2_transpose ? sz * dn : sz ,
565
+ tmp_os_m = o_transpose ? sz : sz * dp ,
566
+ tmp_os_p = o_transpose ? sz * dm : sz ;
567
+
568
+ if (!i1blasable ) {
569
+ @TYPE @_matrix_copy (
570
+ i1_transpose , ip1 , is1_m , is1_n ,
571
+ tmp_ip1 , tmp_is1_m , tmp_is1_n ,
572
+ dm , dn
573
+ );
574
+ }
575
+
576
+ if (!i2blasable ) {
577
+ @TYPE @_matrix_copy (
578
+ i2_transpose , ip2 , is2_n , is2_p ,
579
+ tmp_ip2 , tmp_is2_n , tmp_is2_p ,
580
+ dn , dp
581
+ );
582
+ }
583
+
584
+ void * ip1_ = i1blasable ? ip1 : tmp_ip1 ,
585
+ * ip2_ = i2blasable ? ip2 : tmp_ip2 ,
586
+ * op_ = oblasable ? op : tmp_op ;
587
+
588
+ npy_intp is1_m_ = i1blasable ? is1_m : tmp_is1_m ,
589
+ is1_n_ = i1blasable ? is1_n : tmp_is1_n ,
590
+ is2_n_ = i2blasable ? is2_n : tmp_is2_n ,
591
+ is2_p_ = i2blasable ? is2_p : tmp_is2_p ,
592
+ os_m_ = oblasable ? os_m : tmp_os_m ,
593
+ os_p_ = oblasable ? os_p : tmp_os_p ;
594
+
595
+ /*
596
+ * Use transpose equivalence:
597
+ * matmul(a, b, o) == matmul(b.T, a.T, o.T)
598
+ */
599
+ if (o_f_blasable ) {
600
+ @TYPE @_matmul_matrixmatrix (
601
+ ip2_ , is2_p_ , is2_n_ ,
602
+ ip1_ , is1_n_ , is1_m_ ,
603
+ op_ , os_p_ , os_m_ ,
604
+ dp , dn , dm
605
+ );
606
+ }
607
+ else {
608
+ @TYPE @_matmul_matrixmatrix (
609
+ ip1_ , is1_m_ , is1_n_ ,
610
+ ip2_ , is2_n_ , is2_p_ ,
611
+ op_ , os_m_ , os_p_ ,
612
+ dm , dn , dp
613
+ );
614
+ }
615
+
616
+ if (!oblasable ){
617
+ @TYPE @_matrix_copy (
618
+ o_transpose , tmp_op , tmp_os_m , tmp_os_p ,
619
+ op , os_m , os_p ,
620
+ dm , dp
621
+ );
505
622
}
506
623
}
507
624
#else
@@ -511,6 +628,9 @@ NPY_NO_EXPORT void
511
628
512
629
#endif
513
630
}
631
+ #if @USEBLAS @ && defined(HAVE_CBLAS )
632
+ if (allocate_buffer ) free (tmp_ip12op );
633
+ #endif
514
634
}
515
635
516
636
/**end repeat**/
0 commit comments