
Commit 5664dfd

used template syntax for microgemm
1 parent e3b2aff commit 5664dfd

File tree

1 file changed: +27 -12 lines


Code/OptionG/portable.cc

Lines changed: 27 additions & 12 deletions
@@ -535,7 +535,22 @@ u32 LisSquare()
     else return 0;
 }
 
-void microdgemm
+template<typename T>
+void microgemm
+(
+ u32 M,
+ u32 N,
+ u32 K,
+ T *A,
+ T *B,
+ T alpha,
+ T *C,
+ s32 gamma,
+ u32 lmul
+);
+
+template<>
+void microgemm<double>
 (
  u32 M,
  u32 N,
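
For reference, the pattern this hunk introduces -- declare a primary function template, then provide an explicit specialization for double and call it with an explicit template argument -- looks like this in isolation. The sketch below uses a deliberately trimmed-down, hypothetical function (scale) rather than the real microgemm interface:

    #include <cstdint>
    #include <iostream>

    using u32 = std::uint32_t;

    // Primary template: declared for any element type T, no generic definition given.
    template<typename T>
    void scale(u32 n, T *x, T alpha);

    // Explicit specialization: the only definition provided, for T = double.
    template<>
    void scale<double>(u32 n, double *x, double alpha)
    {
        for (u32 i = 0; i < n; i++) x[i] = alpha * x[i];
    }

    int main()
    {
        double v[4] = {1.0, 2.0, 3.0, 4.0};
        scale<double>(4, v, 0.5);   // explicit template argument at the call site,
                                    // mirroring microgemm<double>(...) in this commit
        for (double d : v) std::cout << d << " ";
        std::cout << std::endl;
    }

Declaring only the primary template and defining just the double specialization keeps the existing double-precision kernel unchanged while leaving room for specializations with other element types later.
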
@@ -548,32 +563,32 @@ void microdgemm
  u32 lmul
 )
 {
-    u32 L = RV->VLENE();                   // L is number of elements per vector register
-    u32 lambda_eff = RV->lambda() * lmul;  // lambda_eff is the maximum lambda for this L
+    u32 L = RV->VLENE();                   // L is number of elements per vector register
+    u32 lambda_eff = RV->lambda() * lmul;  // lambda_eff is the maximum lambda for this L
     assert(0 == K % lambda_eff);           // for simplicty, K must be a multiple of lambda_eff
 
     vsetvl(5, 0, 64, 1, true, true);             // double-precision kernel, set VL to VLENE and LMUL to 1
     for (u32 r=16; r<32; r++) vxor.vv(r, r, r);  // T = 0
 
     vsetvl(5, RV->lambda() * RV->lambda(), 64, lmul, true, true);  // double-precision kernel, set VL to lambda^2 and LMUL accordingly
-    s32 INCA = M*lambda_eff; s32 INCB = N*lambda_eff;              // iteration increments for A and B panels
+    s32 INCA = M*lambda_eff; s32 INCB = N*lambda_eff;              // iteration increments for A and B panels
 
     // the following setup for the A and B register load pointers works because not all loads are active for all values of lmul
-    double *A0 = A; double *A1 = A0 + LisSquare() * L; double *A2 = A1 + ((2 == lmul) ? LisSquare() * L : L); double *A3 = A2 + LisSquare() * L;  // pointers for loads to the A registers
-    double *B0 = B; double *B1 = B0 + L; double *B2 = B1 + L; double *B3 = B2 + L;  // pointers for loads to the B registers
+    double *A0 = A; double *A1 = A0 + LisSquare() * L; double *A2 = A1 + ((2 == lmul) ? LisSquare() * L : L); double *A3 = A2 + LisSquare() * L;  // pointers for loads to the A registers
+    double *B0 = B; double *B1 = B0 + L; double *B2 = B1 + L; double *B3 = B2 + L;  // pointers for loads to the B registers
 
     // the computation loop
     for (u32 k=0; k<K; k+=lambda_eff)
     {
        if (debug > 1) { std::cout << "k = " << k << std::endl; }
 
-       // load the 4 A registers
+       // load the 4 A registers
        vmtlfre64.v( 0, A0, lambda_eff); if (debug > 1) { std::cout << "VR[ 0] = "; RV->printVRf64( 0); }
        vmtlfre64.v( 1, A1, lambda_eff); if (debug > 1) { std::cout << "VR[ 1] = "; RV->printVRf64( 1); }
        vmtlfre64.v( 2, A2, lambda_eff); if (debug > 1) { std::cout << "VR[ 2] = "; RV->printVRf64( 2); }
        vmtlfre64.v( 3, A3, lambda_eff); if (debug > 1) { std::cout << "VR[ 3] = "; RV->printVRf64( 3); }
 
-       // load the 4 B registers
+       // load the 4 B registers
        vmtlfre64.v( 8, B0, lambda_eff); if (debug > 1) { std::cout << "VR[ 8] = "; RV->printVRf64( 8); }
        vmtlfre64.v( 9, B1, lambda_eff); if (debug > 1) { std::cout << "VR[ 9] = "; RV->printVRf64( 9); }
        vmtlfre64.v(10, B2, lambda_eff); if (debug > 1) { std::cout << "VR[10] = "; RV->printVRf64(10); }
@@ -582,7 +597,7 @@ void microdgemm
        A0 = A0 + INCA ; A1 = A1 + INCA ; A2 = A2 + INCA ; A3 = A3 + INCA;  // increment pointers for the A registers
        B0 = B0 + INCB ; B1 = B1 + INCB ; B2 = B2 + INCB ; B3 = B3 + INCB;  // increment pointers for the B registers
 
-       // perform 16 vmmacc's, one for each target register
+       // perform 16 vmmacc's, one for each target register
        vfmmacc.v0(16, 0, 8); vmrotate.vv( 8, 8); if (debug > 1) { std::cout << "VR[16] = "; RV->printVRf64(16); }
        vfmmacc.v0(17, 0, 9); vmrotate.vv( 9, 9); if (debug > 1) { std::cout << "VR[17] = "; RV->printVRf64(17); }
        vfmmacc.v0(18, 1, 8); vmrotate.vv( 8, 8); if (debug > 1) { std::cout << "VR[18] = "; RV->printVRf64(18); }
@@ -627,13 +642,13 @@ void microdgemm
     vsetvl(5, 0, 64, 1, true, true);  // double-precision kernel, set VL to VLENE and LMUL to 1
     for (u32 vd=0; vd<16; vd++)
     {
-       vmtlfre64.v(vd, C+offset[vd+16], N);  // C[i,j] = alpha * T[i,j] + C[i,j]
+       vmtlfre64.v(vd, C+offset[vd+16], N);  // C[i,j] = alpha * T[i,j] + C[i,j]
        vfmacc.vf(vd, alpha, vd+16);
        vmtsfre64.v(vd, C+offset[vd+16], N);
     }
 }
 
-void microdgemm_old
+void microdgemm
 (
  u32 M,
  u32 N,
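
The store-back loop in the hunk above (load a tile of C, vfmacc.vf with alpha, store it back) implements the update noted in the comment, C[i,j] = alpha * T[i,j] + C[i,j]. A plain scalar restatement of that update is sketched below; the row-major layout, the ldc stride, and the name epilogue_update are illustrative assumptions, with T standing for the accumulator tile held in vector registers 16 through 31:

    // Hypothetical scalar sketch of the epilogue update; M, N, ldc and the
    // row-major layout are assumptions, not the kernel's actual register/panel layout.
    void epilogue_update(unsigned M, unsigned N, unsigned ldc,
                         const double *T, double alpha, double *C)
    {
        for (unsigned i = 0; i < M; i++)
            for (unsigned j = 0; j < N; j++)
                C[i*ldc + j] = alpha * T[i*N + j] + C[i*ldc + j];  // C[i,j] = alpha*T[i,j] + C[i,j]
    }
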
@@ -882,7 +897,7 @@ bool run_microgemm
        if (debug > 1) { std::cout << "Bp[" << k/lambda_eff << "] = "; print(N, lambda_eff, Bp+k*nu); }
     }
 
-    microdgemm(M, N, K, Ap, Bp, alpha, D, N, LMUL);
+    microgemm<double>(M, N, K, Ap, Bp, alpha, D, N, LMUL);
 
     // Check the result
     for (u32 j=0; j<N; j++)
