Skip to content

Commit e52b66d

Browse files
author
Muhammad Alhroob
committed
Rewrite the matrix multplication function. rearrange the loops and use blocking to enhance the speed (a factor of 6 for very large matrices)
1 parent 80a2ed7 commit e52b66d

File tree

1 file changed

+45
-13
lines changed

1 file changed

+45
-13
lines changed

math/matrix/src/TMatrixT.cxx

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3076,23 +3076,55 @@ TMatrixT<Element> &TMatrixTAutoloadOps::ElementDiv(TMatrixT<Element> &target, co
30763076
////////////////////////////////////////////////////////////////////////////////
30773077
/// Elementary routine to calculate matrix multiplication A*B
30783078

3079+
30793080
template <class Element>
30803081
void TMatrixTAutoloadOps::AMultB(const Element *const ap, Int_t na, Int_t ncolsa, const Element *const bp, Int_t nb,
3081-
Int_t ncolsb, Element *cp)
3082-
{
3083-
const Element *arp0 = ap; // Pointer to A[i,0];
3084-
while (arp0 < ap + na) {
3085-
for (const Element *bcp = bp; bcp < bp + ncolsb;) { // Pointer to the j-th column of B, Start bcp = B[0,0]
3086-
const Element *arp = arp0; // Pointer to the i-th row of A, reset to A[i,0]
3087-
Element cij = 0;
3088-
while (bcp < bp + nb) { // Scan the i-th row of A and
3089-
cij += *arp++ * *bcp; // the j-th col of B
3090-
bcp += ncolsb;
3082+
Int_t ncolsb, Element *cp) {
3083+
// i,k,j loop order with blocking and unrolling
3084+
const Int_t M = na / ncolsa; // Rows of A
3085+
const Int_t N = ncolsa; // Columns of A, rows of B
3086+
const Int_t P = ncolsb; // Columns of B and C
3087+
3088+
const Int_t BLOCK = 32;
3089+
3090+
#ifdef _OPENMP
3091+
#pragma omp parallel for collapse(2) if(M * P > 10000)
3092+
#endif
3093+
for (Int_t i0 = 0; i0 < M; i0 += BLOCK) {
3094+
for (Int_t k0 = 0; k0 < N; k0 += BLOCK) {
3095+
for (Int_t j0 = 0; j0 < P; j0 += BLOCK) {
3096+
const Int_t iMax = (i0 + BLOCK < M) ? i0 + BLOCK : M;
3097+
const Int_t kMax = (k0 + BLOCK < N) ? k0 + BLOCK : N;
3098+
const Int_t jMax = (j0 + BLOCK < P) ? j0 + BLOCK : P;
3099+
for (Int_t i = i0; i < iMax; ++i) {
3100+
for (Int_t k = k0; k < kMax; ++k) {
3101+
Element aik = ap[i * N + k]; // Hoist A[i,k]
3102+
Int_t j = j0;
3103+
#pragma GCC ivdep
3104+
for (; j <= jMax - 4; j += 4) {
3105+
// Unroll by 4: update C[i,j], C[i,j+1], C[i,j+2], C[i,j+3]
3106+
Element cij0 = cp[i * P + j];
3107+
Element cij1 = cp[i * P + (j + 1)];
3108+
Element cij2 = cp[i * P + (j + 2)];
3109+
Element cij3 = cp[i * P + (j + 3)];
3110+
cij0 += aik * bp[k * P + j];
3111+
cij1 += aik * bp[k * P + (j + 1)];
3112+
cij2 += aik * bp[k * P + (j + 2)];
3113+
cij3 += aik * bp[k * P + (j + 3)];
3114+
cp[i * P + j] = cij0;
3115+
cp[i * P + (j + 1)] = cij1;
3116+
cp[i * P + (j + 2)] = cij2;
3117+
cp[i * P + (j + 3)] = cij3;
3118+
}
3119+
#pragma GCC ivdep
3120+
for (; j < jMax; ++j) {
3121+
// Cleanup loop for remaining j
3122+
cp[i * P + j] += aik * bp[k * P + j];
3123+
}
3124+
}
3125+
}
30913126
}
3092-
*cp++ = cij;
3093-
bcp -= nb - 1; // Set bcp to the (j+1)-th col
30943127
}
3095-
arp0 += ncolsa; // Set ap to the (i+1)-th row
30963128
}
30973129
}
30983130

0 commit comments

Comments
 (0)