Skip to content

Commit f104291

Browse files
authored
ENH: speed up matmul for non-contiguous operands
At least for larger matrices, copying the matrix into a temporary and then applying blas matrix multiplication can have large speed advantages (since optimized matrix multiplication is much better than the trivial approach used otherwise). Thus, this creates temporary copies and then applies it.
1 parent ede4009 commit f104291

File tree

2 files changed

+183
-25
lines changed

2 files changed

+183
-25
lines changed

benchmarks/benchmarks/bench_linalg.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,41 @@ def time_transpose(self, shape, npdtypes):
217217

218218
def time_vdot(self, shape, npdtypes):
219219
np.vdot(self.xarg, self.x2arg)
220+
221+
222+
class MatmulStrided(Benchmark):
223+
# some interesting points selected from
224+
# https://github.com/numpy/numpy/pull/23752#issuecomment-2629521597
225+
# (m, p, n, batch_size)
226+
args = [
227+
(2, 2, 2, 1), (2, 2, 2, 10), (5, 5, 5, 1), (5, 5, 5, 10),
228+
(10, 10, 10, 1), (10, 10, 10, 10), (20, 20, 20, 1), (20, 20, 20, 10),
229+
(50, 50, 50, 1), (50, 50, 50, 10),
230+
(150, 150, 100, 1), (150, 150, 100, 10),
231+
(400, 400, 100, 1), (400, 400, 100, 10)
232+
]
233+
234+
param_names = ['configuration']
235+
236+
def __init__(self):
237+
self.args_map = {
238+
'matmul_m%03d_p%03d_n%03d_bs%02d' % arg: arg for arg in self.args
239+
}
240+
241+
self.params = [list(self.args_map.keys())]
242+
243+
def setup(self, configuration):
244+
m, p, n, batch_size = self.args_map[configuration]
245+
246+
self.a1raw = np.random.rand(batch_size * m * 2 * n).reshape(
247+
(batch_size, m, 2 * n)
248+
)
249+
250+
self.a1 = self.a1raw[:, :, ::2]
251+
252+
self.a2 = np.random.rand(batch_size * n * p).reshape(
253+
(batch_size, n, p)
254+
)
255+
256+
def time_matmul(self, configuration):
257+
return np.matmul(self.a1, self.a2)

numpy/_core/src/umath/matmul.c.src

Lines changed: 145 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,47 @@ static const npy_cfloat oneF = 1.0f, zeroF = 0.0f;
7979
* #step1 = 1.F, 1., &oneF, &oneD#
8080
* #step0 = 0.F, 0., &zeroF, &zeroD#
8181
*/
82+
83+
static inline void
84+
@name@_matrix_copy(npy_bool transpose,
85+
void *_ip, npy_intp is_m, npy_intp is_n,
86+
void *_op, npy_intp os_m, npy_intp os_n,
87+
npy_intp dm, npy_intp dn)
88+
{
89+
90+
char *ip = (char *)_ip, *op = (char *)_op;
91+
92+
npy_intp m, n, ib, ob;
93+
94+
if (transpose) {
95+
ib = is_m * dm, ob = os_m * dm;
96+
97+
for (n = 0; n < dn; n++) {
98+
for (m = 0; m < dm; m++) {
99+
*(@ctype@ *)op = *(@ctype@ *)ip;
100+
ip += is_m;
101+
op += os_m;
102+
}
103+
ip += is_n - ib;
104+
op += os_n - ob;
105+
}
106+
107+
return;
108+
}
109+
110+
ib = is_n * dn, ob = os_n * dn;
111+
112+
for (m = 0; m < dm; m++) {
113+
for (n = 0; n < dn; n++) {
114+
*(@ctype@ *)op = *(@ctype@ *)ip;
115+
ip += is_n;
116+
op += os_n;
117+
}
118+
ip += is_m - ib;
119+
op += os_m - ob;
120+
}
121+
}
122+
82123
NPY_NO_EXPORT void
83124
@name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
84125
void *ip2, npy_intp is2_n,
@@ -429,10 +470,43 @@ NPY_NO_EXPORT void
429470
npy_bool i2blasable = i2_c_blasable || i2_f_blasable;
430471
npy_bool o_c_blasable = is_blasable2d(os_m, os_p, dm, dp, sz);
431472
npy_bool o_f_blasable = is_blasable2d(os_p, os_m, dp, dm, sz);
473+
npy_bool oblasable = o_c_blasable || o_f_blasable;
432474
npy_bool vector_matrix = ((dm == 1) && i2blasable &&
433475
is_blasable2d(is1_n, sz, dn, 1, sz));
434476
npy_bool matrix_vector = ((dp == 1) && i1blasable &&
435477
is_blasable2d(is2_n, sz, dn, 1, sz));
478+
npy_bool noblas_fallback = too_big_for_blas || any_zero_dim;
479+
npy_bool matrix_matrix = !noblas_fallback && !special_case;
480+
npy_bool allocate_buffer = matrix_matrix && (
481+
!i1blasable || !i2blasable || !oblasable
482+
);
483+
484+
uint8_t *tmp_ip12op = NULL;
485+
void *tmp_ip1 = NULL, *tmp_ip2 = NULL, *tmp_op = NULL;
486+
487+
if (allocate_buffer){
488+
npy_intp ip1_size = i1blasable ? 0 : sz * dm * dn,
489+
ip2_size = i2blasable ? 0 : sz * dn * dp,
490+
op_size = oblasable ? 0 : sz * dm * dp,
491+
total_size = ip1_size + ip2_size + op_size;
492+
493+
tmp_ip12op = (uint8_t*)malloc(total_size);
494+
495+
if (tmp_ip12op == NULL) {
496+
PyGILState_STATE gil_state = PyGILState_Ensure();
497+
PyErr_SetString(
498+
PyExc_MemoryError, "Out of memory in matmul"
499+
);
500+
PyGILState_Release(gil_state);
501+
502+
return;
503+
}
504+
505+
tmp_ip1 = tmp_ip12op;
506+
tmp_ip2 = tmp_ip12op + ip1_size;
507+
tmp_op = tmp_ip12op + ip1_size + ip2_size;
508+
}
509+
436510
#endif
437511

438512
for (iOuter = 0; iOuter < dOuter; iOuter++,
@@ -444,7 +518,7 @@ NPY_NO_EXPORT void
444518
* PyUFunc_MatmulLoopSelector. But that call does not have access to
445519
* n, m, p and strides.
446520
*/
447-
if (too_big_for_blas || any_zero_dim) {
521+
if (noblas_fallback) {
448522
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
449523
ip2, is2_n, is2_p,
450524
op, os_m, os_p, dm, dn, dp);
@@ -478,30 +552,73 @@ NPY_NO_EXPORT void
478552
op, os_m, os_p, dm, dn, dp);
479553
}
480554
} else {
481-
/* matrix @ matrix */
482-
if (i1blasable && i2blasable && o_c_blasable) {
483-
@TYPE@_matmul_matrixmatrix(ip1, is1_m, is1_n,
484-
ip2, is2_n, is2_p,
485-
op, os_m, os_p,
486-
dm, dn, dp);
487-
} else if (i1blasable && i2blasable && o_f_blasable) {
488-
/*
489-
* Use transpose equivalence:
490-
* matmul(a, b, o) == matmul(b.T, a.T, o.T)
491-
*/
492-
@TYPE@_matmul_matrixmatrix(ip2, is2_p, is2_n,
493-
ip1, is1_n, is1_m,
494-
op, os_p, os_m,
495-
dp, dn, dm);
496-
} else {
497-
/*
498-
* If parameters are castable to int and we copy the
499-
* non-blasable (or non-ccontiguous output)
500-
* we could still use BLAS, see gh-12365.
501-
*/
502-
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
503-
ip2, is2_n, is2_p,
504-
op, os_m, os_p, dm, dn, dp);
555+
/* matrix @ matrix
556+
* copy if not blasable, see gh-12365 & gh-23588 */
557+
npy_bool i1_transpose = is1_m < is1_n,
558+
i2_transpose = is2_n < is2_p,
559+
o_transpose = os_m < os_p;
560+
561+
npy_intp tmp_is1_m = i1_transpose ? sz : sz*dn,
562+
tmp_is1_n = i1_transpose ? sz*dm : sz,
563+
tmp_is2_n = i2_transpose ? sz : sz*dp,
564+
tmp_is2_p = i2_transpose ? sz*dn : sz,
565+
tmp_os_m = o_transpose ? sz : sz*dp,
566+
tmp_os_p = o_transpose ? sz*dm : sz;
567+
568+
if (!i1blasable) {
569+
@TYPE@_matrix_copy(
570+
i1_transpose, ip1, is1_m, is1_n,
571+
tmp_ip1, tmp_is1_m, tmp_is1_n,
572+
dm, dn
573+
);
574+
}
575+
576+
if (!i2blasable) {
577+
@TYPE@_matrix_copy(
578+
i2_transpose, ip2, is2_n, is2_p,
579+
tmp_ip2, tmp_is2_n, tmp_is2_p,
580+
dn, dp
581+
);
582+
}
583+
584+
void *ip1_ = i1blasable ? ip1 : tmp_ip1,
585+
*ip2_ = i2blasable ? ip2 : tmp_ip2,
586+
*op_ = oblasable ? op : tmp_op;
587+
588+
npy_intp is1_m_ = i1blasable ? is1_m : tmp_is1_m,
589+
is1_n_ = i1blasable ? is1_n : tmp_is1_n,
590+
is2_n_ = i2blasable ? is2_n : tmp_is2_n,
591+
is2_p_ = i2blasable ? is2_p : tmp_is2_p,
592+
os_m_ = oblasable ? os_m : tmp_os_m,
593+
os_p_ = oblasable ? os_p : tmp_os_p;
594+
595+
/*
596+
* Use transpose equivalence:
597+
* matmul(a, b, o) == matmul(b.T, a.T, o.T)
598+
*/
599+
if (o_f_blasable) {
600+
@TYPE@_matmul_matrixmatrix(
601+
ip2_, is2_p_, is2_n_,
602+
ip1_, is1_n_, is1_m_,
603+
op_, os_p_, os_m_,
604+
dp, dn, dm
605+
);
606+
}
607+
else {
608+
@TYPE@_matmul_matrixmatrix(
609+
ip1_, is1_m_, is1_n_,
610+
ip2_, is2_n_, is2_p_,
611+
op_, os_m_, os_p_,
612+
dm, dn, dp
613+
);
614+
}
615+
616+
if(!oblasable){
617+
@TYPE@_matrix_copy(
618+
o_transpose, tmp_op, tmp_os_m, tmp_os_p,
619+
op, os_m, os_p,
620+
dm, dp
621+
);
505622
}
506623
}
507624
#else
@@ -511,6 +628,9 @@ NPY_NO_EXPORT void
511628

512629
#endif
513630
}
631+
#if @USEBLAS@ && defined(HAVE_CBLAS)
632+
if (allocate_buffer) free(tmp_ip12op);
633+
#endif
514634
}
515635

516636
/**end repeat**/

0 commit comments

Comments
 (0)