Skip to content

Commit e3b2aff

Browse files
committed
cleanup of the microdgemm kernel
1 parent 3797102 commit e3b2aff

File tree

1 file changed

+150
-42
lines changed

1 file changed

+150
-42
lines changed

Code/OptionG/portable.cc

Lines changed: 150 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -528,14 +528,112 @@ vmtlfre64_t vmtlfre64;
528528
vmtsfre64_t vmtsfre64;
529529
vmrotate_t vmrotate;
530530

531-
// LisSquare: reports whether RV->VLENE() is a perfect square.
// It takes the double-precision square root of RV->VLENE(), squares it again and
// compares the product with RV->VLENE(). The lines marked '-' are the previous
// bool-returning version; the lines marked '+' return the same answer as a u32
// (1 = square, 0 = not square), which microdgemm multiplies into pointer offsets.
// NOTE(review): the test relies on exact floating-point equality; presumably this is
// safe for the vector lengths exercised by this file - confirm before using other lengths.
bool LisSquare()
531+
u32 LisSquare()
532532
{
533533
double rootL = sqrt(RV->VLENE());
534-
if ((rootL*rootL) == RV->VLENE()) return true;
535-
else return false;
534+
if ((rootL*rootL) == RV->VLENE()) return 1;
535+
else return 0;
536536
}
537537

538538
// microdgemm: accumulates products of the A and B panels into vector registers 16..31
// and then updates C with alpha times the accumulated values.
//
// Visible steps:
//   1. assert that K is a multiple of lambda_eff (= RV->lambda() * lmul);
//   2. zero registers 16..31 (the accumulators, called T in the comments);
//   3. for k = 0, lambda_eff, 2*lambda_eff, ... < K: load registers 0..3 from A and
//      8..11 from B, advance the load pointers, then issue 16 vfmmacc.v0 operations,
//      each followed by a vmrotate.vv on the B register it just used;
//   4. build a table of element offsets into C, one per accumulator register;
//   5. for each accumulator: load from C at its offset, apply vfmacc.vf with alpha, store back.
//
// Parameters:
//   M, N   - scale the per-iteration pointer increments (INCA = M*lambda_eff,
//            INCB = N*lambda_eff); M/2 also enters offset[24], and N is passed as the
//            last argument of the C load/store (presumably a stride or extent - TODO confirm).
//   K      - extent of the k loop; must be a multiple of lambda_eff (asserted).
//   A, B   - base pointers of the input panels; only read through vmtlfre64.v.
//   alpha  - scalar handed to vfmacc.vf in the final update.
//   C      - base pointer of the result; loaded and stored at C + offset[...].
//   gamma  - multiplier used in several store offsets (presumably the leading
//            dimension of C - TODO confirm).
//   lmul   - multiplies RV->lambda() to give lambda_eff and is passed to vsetvl; the
//            pointer and offset setup branch on lmul == 1, 2 and 4.
// NOTE(review): gamma is signed (s32) while offset[] is u32, so a negative gamma would
// wrap when stored - confirm callers only pass non-negative values.
void microdgemm
539+
(
540+
u32 M,
541+
u32 N,
542+
u32 K,
543+
double *A,
544+
double *B,
545+
double alpha,
546+
double *C,
547+
s32 gamma,
548+
u32 lmul
549+
)
550+
{
551+
u32 L = RV->VLENE(); // L is number of elements per vector register
552+
u32 lambda_eff = RV->lambda() * lmul; // lambda_eff is the maximum lambda for this L
553+
assert(0 == K % lambda_eff); // for simplicity, K must be a multiple of lambda_eff
554+
555+
vsetvl(5, 0, 64, 1, true, true); // double-precision kernel, set VL to VLENE and LMUL to 1
556+
// XOR each accumulator register with itself to clear it
for (u32 r=16; r<32; r++) vxor.vv(r, r, r); // T = 0
557+
558+
vsetvl(5, RV->lambda() * RV->lambda(), 64, lmul, true, true); // double-precision kernel, set VL to lambda^2 and LMUL accordingly
559+
s32 INCA = M*lambda_eff; s32 INCB = N*lambda_eff; // iteration increments for A and B panels
560+
561+
// the following setup for the A and B register load pointers works because not all loads are active for all values of lmul
562+
// A1 and A3 advance by L only when LisSquare() returns 1 (otherwise they alias A0 and A2);
// A2 advances by L unconditionally unless lmul == 2, in which case it follows the same rule
double *A0 = A; double *A1 = A0 + LisSquare() * L; double *A2 = A1 + ((2 == lmul) ? LisSquare() * L : L); double *A3 = A2 + LisSquare() * L; // pointers for loads to the A registers
563+
// the B pointers are always spaced exactly L elements apart
double *B0 = B; double *B1 = B0 + L; double *B2 = B1 + L; double *B3 = B2 + L; // pointers for loads to the B registers
564+
565+
// the computation loop
566+
for (u32 k=0; k<K; k+=lambda_eff)
567+
{
568+
if (debug > 1) { std::cout << "k = " << k << std::endl; }
569+
570+
// load the 4 A registers
571+
vmtlfre64.v( 0, A0, lambda_eff); if (debug > 1) { std::cout << "VR[ 0] = "; RV->printVRf64( 0); }
572+
vmtlfre64.v( 1, A1, lambda_eff); if (debug > 1) { std::cout << "VR[ 1] = "; RV->printVRf64( 1); }
573+
vmtlfre64.v( 2, A2, lambda_eff); if (debug > 1) { std::cout << "VR[ 2] = "; RV->printVRf64( 2); }
574+
vmtlfre64.v( 3, A3, lambda_eff); if (debug > 1) { std::cout << "VR[ 3] = "; RV->printVRf64( 3); }
575+
576+
// load the 4 B registers
577+
vmtlfre64.v( 8, B0, lambda_eff); if (debug > 1) { std::cout << "VR[ 8] = "; RV->printVRf64( 8); }
578+
vmtlfre64.v( 9, B1, lambda_eff); if (debug > 1) { std::cout << "VR[ 9] = "; RV->printVRf64( 9); }
579+
vmtlfre64.v(10, B2, lambda_eff); if (debug > 1) { std::cout << "VR[10] = "; RV->printVRf64(10); }
580+
vmtlfre64.v(11, B3, lambda_eff); if (debug > 1) { std::cout << "VR[11] = "; RV->printVRf64(11); }
581+
582+
A0 = A0 + INCA ; A1 = A1 + INCA ; A2 = A2 + INCA ; A3 = A3 + INCA; // increment pointers for the A registers
583+
B0 = B0 + INCB ; B1 = B1 + INCB ; B2 = B2 + INCB ; B3 = B3 + INCB; // increment pointers for the B registers
584+
585+
// perform 16 vmmacc's, one for each target register
586+
// every A register (0..3) is paired with every B register (8..11) exactly once;
// each B register is passed through vmrotate.vv (same register as both arguments) after every use,
// i.e. four times per iteration - presumably returning it to its initial arrangement, TODO confirm
vfmmacc.v0(16, 0, 8); vmrotate.vv( 8, 8); if (debug > 1) { std::cout << "VR[16] = "; RV->printVRf64(16); }
587+
vfmmacc.v0(17, 0, 9); vmrotate.vv( 9, 9); if (debug > 1) { std::cout << "VR[17] = "; RV->printVRf64(17); }
588+
vfmmacc.v0(18, 1, 8); vmrotate.vv( 8, 8); if (debug > 1) { std::cout << "VR[18] = "; RV->printVRf64(18); }
589+
vfmmacc.v0(19, 1, 9); vmrotate.vv( 9, 9); if (debug > 1) { std::cout << "VR[19] = "; RV->printVRf64(19); }
590+
vfmmacc.v0(20, 0, 10); vmrotate.vv(10, 10); if (debug > 1) { std::cout << "VR[20] = "; RV->printVRf64(20); }
591+
vfmmacc.v0(21, 0, 11); vmrotate.vv(11, 11); if (debug > 1) { std::cout << "VR[21] = "; RV->printVRf64(21); }
592+
vfmmacc.v0(22, 1, 10); vmrotate.vv(10, 10); if (debug > 1) { std::cout << "VR[22] = "; RV->printVRf64(22); }
593+
vfmmacc.v0(23, 1, 11); vmrotate.vv(11, 11); if (debug > 1) { std::cout << "VR[23] = "; RV->printVRf64(23); }
594+
vfmmacc.v0(24, 2, 8); vmrotate.vv( 8, 8); if (debug > 1) { std::cout << "VR[24] = "; RV->printVRf64(24); }
595+
vfmmacc.v0(25, 2, 9); vmrotate.vv( 9, 9); if (debug > 1) { std::cout << "VR[25] = "; RV->printVRf64(25); }
596+
vfmmacc.v0(26, 3, 8); vmrotate.vv( 8, 8); if (debug > 1) { std::cout << "VR[26] = "; RV->printVRf64(26); }
597+
vfmmacc.v0(27, 3, 9); vmrotate.vv( 9, 9); if (debug > 1) { std::cout << "VR[27] = "; RV->printVRf64(27); }
598+
vfmmacc.v0(28, 2, 10); vmrotate.vv(10, 10); if (debug > 1) { std::cout << "VR[28] = "; RV->printVRf64(28); }
599+
vfmmacc.v0(29, 2, 11); vmrotate.vv(11, 11); if (debug > 1) { std::cout << "VR[29] = "; RV->printVRf64(29); }
600+
vfmmacc.v0(30, 3, 10); vmrotate.vv(10, 10); if (debug > 1) { std::cout << "VR[30] = "; RV->printVRf64(30); }
601+
vfmmacc.v0(31, 3, 11); vmrotate.vv(11, 11); if (debug > 1) { std::cout << "VR[31] = "; RV->printVRf64(31); }
602+
}
603+
604+
// compute the store offsets for each result register - this only has to be done once per <L,lambda> configuration
605+
// the offset vector only needs 16 elements - we use 32 for convenience and will clean up later
606+
// the table is indexed by accumulator register number, so only entries 16..31 are written and read;
// the four groups 16-19, 20-23, 24-27 and 28-31 repeat the same pattern relative to their first entry
u32 offset[32];
607+
offset[16] = 0;
608+
offset[17] = offset[16] + (((!LisSquare()) && (1 == lmul)) ? 2 * RV->lambda() : RV->lambda());
609+
offset[18] = (1 != lmul) ? (offset[17] + RV->lambda()) : (LisSquare() ? offset[16] + RV->lambda() * gamma : offset[16] + RV->lambda());
610+
offset[19] = ((!LisSquare()) && (1 == lmul)) ? offset[17] + RV->lambda() : offset[18] + RV->lambda();
611+
offset[20] = (4 == lmul) ? (offset[16] + 4*RV->lambda()) : offset[16] + (2*RV->sigma())/lmul;
612+
offset[21] = offset[20] + (((!LisSquare()) && (1 == lmul)) ? 2 * RV->lambda() : RV->lambda());
613+
offset[22] = (1 != lmul) ? (offset[21] + RV->lambda()) : (LisSquare() ? offset[20] + RV->lambda() * gamma : offset[20] + RV->lambda());
614+
offset[23] = ((!LisSquare()) && (1 == lmul)) ? offset[21] + RV->lambda() : offset[22] + RV->lambda();
615+
offset[24] = (4 == lmul) ? (offset[16] + 8*RV->lambda()) : ((LisSquare() || (1 == lmul)) ? gamma * (M/2) : offset[16] + 4*RV->lambda());
616+
offset[25] = offset[24] + (((!LisSquare()) && (1 == lmul)) ? 2 * RV->lambda() : RV->lambda());
617+
offset[26] = (1 != lmul) ? (offset[25] + RV->lambda()) : (LisSquare() ? offset[24] + RV->lambda() * gamma : offset[24] + RV->lambda());
618+
offset[27] = ((!LisSquare()) && (1 == lmul)) ? offset[25] + RV->lambda() : offset[26] + RV->lambda();
619+
offset[28] = (4 == lmul) ? (offset[24] + 4*RV->lambda()) : offset[24] + (2*RV->sigma())/lmul;
620+
offset[29] = offset[28] + (((!LisSquare()) && (1 == lmul)) ? 2 * RV->lambda() : RV->lambda());
621+
offset[30] = (1 != lmul) ? (offset[29] + RV->lambda()) : (LisSquare() ? offset[28] + RV->lambda() * gamma : offset[28] + RV->lambda());
622+
offset[31] = ((!LisSquare()) && (1 == lmul)) ? offset[29] + RV->lambda() : offset[30] + RV->lambda();
623+
624+
if (debug > 1) { for (u32 i=16; i<32; i++) std::cout << "offset[" << i << "] = " << offset[i] << std::endl; }
625+
626+
// do the scaling by alpha and update C
627+
vsetvl(5, 0, 64, 1, true, true); // double-precision kernel, set VL to VLENE and LMUL to 1
628+
// registers 0..15 are reused here as staging registers for C; the matching accumulator is vd+16
for (u32 vd=0; vd<16; vd++)
629+
{
630+
vmtlfre64.v(vd, C+offset[vd+16], N); // C[i,j] = alpha * T[i,j] + C[i,j]
631+
vfmacc.vf(vd, alpha, vd+16);
632+
vmtsfre64.v(vd, C+offset[vd+16], N);
633+
}
634+
}
635+
636+
void microdgemm_old
539637
(
540638
u32 M,
541639
u32 N,
@@ -746,13 +844,13 @@ bool run_microgemm
746844
u32 lambda_eff = 1;
747845
while ((2*lambda_eff)*(2*lambda_eff) <= L) lambda_eff *= 2;
748846
u32 LMUL = lambda_eff / lambda;
749-
std::cout << "L = " << std::setw(2) << L << ", lambda = " << std::setw(2) << RV->lambda() << ", sigma = " << std::setw(2) << RV->sigma() << ", lambda_eff = " << std::setw(2) << lambda_eff << ", LMUL = " << LMUL;
847+
std::cout << "L = " << std::setw(4) << L << ", lambda = " << std::setw(2) << RV->lambda() << ", sigma = " << std::setw(3) << RV->sigma() << ", lambda_eff = " << std::setw(2) << lambda_eff << ", LMUL = " << LMUL;
750848
std::cout << ", RMUL = " << std::setw(2) << rmul << ", CMUL = " << std::setw(2) << cmul;
751849

752850
u32 mu = rmul*RV->sigma();
753851
u32 nu = cmul*RV->sigma();
754852

755-
std::cout << ", mu = " << std::setw(3) << mu << ", nu = " << std::setw(3) << nu << ", K = " << K << std::endl;
853+
std::cout << ", mu = " << std::setw(3) << mu << ", nu = " << std::setw(3) << nu << ", K = " << std::setw(3) << K << std::endl;
756854

757855
u32 M = mu;
758856
u32 N = nu;
@@ -826,43 +924,53 @@ int main
826924
char **argv
827925
)
828926
{
829-
std::cout << "=========================================================================================================================" << std::endl;
830-
run_microgemm< 64, 1>(1);
831-
run_microgemm< 64, 1>(2);
832-
run_microgemm< 64, 1>(4);
833-
run_microgemm< 64, 1>(8);
834-
run_microgemm< 128, 1>(1);
835-
run_microgemm< 128, 1>(2);
836-
run_microgemm< 128, 1>(4);
837-
run_microgemm< 128, 1>(8);
838-
run_microgemm< 256, 1>(2);
839-
run_microgemm< 256, 1>(4);
840-
run_microgemm< 256, 1>(8);
841-
run_microgemm< 256, 2>(2);
842-
run_microgemm< 256, 2>(4);
843-
run_microgemm< 256, 2>(8);
844-
run_microgemm< 512, 1>(2);
845-
run_microgemm< 512, 1>(4);
846-
run_microgemm< 512, 1>(8);
847-
run_microgemm< 512, 2>(2);
848-
run_microgemm< 512, 2>(4);
849-
run_microgemm< 512, 2>(8);
850-
run_microgemm<1024, 1>(4);
851-
run_microgemm<1024, 1>(8);
852-
run_microgemm<1024, 1>(16);
853-
run_microgemm<1024, 2>(4);
854-
run_microgemm<1024, 2>(8);
855-
run_microgemm<1024, 2>(16);
856-
run_microgemm<1024, 4>(4);
857-
run_microgemm<1024, 4>(8);
858-
run_microgemm<1024, 4>(16);
859-
run_microgemm<2048, 2>(8);
860-
run_microgemm<2048, 2>(16);
861-
run_microgemm<2048, 4>(8);
862-
run_microgemm<2048, 4>(16);
863-
run_microgemm<4096, 2>(16);
864-
run_microgemm<4096, 4>(16);
865-
run_microgemm<4096, 8>(16);
927+
std::cout << "=================================================================================================================" << std::endl;
928+
run_microgemm< 64, 1>( 1);
929+
run_microgemm< 64, 1>( 2);
930+
run_microgemm< 64, 1>( 4);
931+
run_microgemm< 64, 1>( 8);
932+
run_microgemm< 128, 1>( 1);
933+
run_microgemm< 128, 1>( 2);
934+
run_microgemm< 128, 1>( 4);
935+
run_microgemm< 128, 1>( 8);
936+
run_microgemm< 256, 1>( 2);
937+
run_microgemm< 256, 1>( 4);
938+
run_microgemm< 256, 1>( 8);
939+
run_microgemm< 256, 2>( 2);
940+
run_microgemm< 256, 2>( 4);
941+
run_microgemm< 256, 2>( 8);
942+
run_microgemm< 512, 1>( 2);
943+
run_microgemm< 512, 1>( 4);
944+
run_microgemm< 512, 1>( 8);
945+
run_microgemm< 512, 2>( 2);
946+
run_microgemm< 512, 2>( 4);
947+
run_microgemm< 512, 2>( 8);
948+
run_microgemm< 1024, 1>( 4);
949+
run_microgemm< 1024, 1>( 8);
950+
run_microgemm< 1024, 1>(16);
951+
run_microgemm< 1024, 2>( 4);
952+
run_microgemm< 1024, 2>( 8);
953+
run_microgemm< 1024, 2>(16);
954+
run_microgemm< 1024, 4>( 4);
955+
run_microgemm< 1024, 4>( 8);
956+
run_microgemm< 1024, 4>(16);
957+
run_microgemm< 2048, 2>( 8);
958+
run_microgemm< 2048, 2>(16);
959+
run_microgemm< 2048, 4>( 8);
960+
run_microgemm< 2048, 4>(64);
961+
run_microgemm< 4096, 2>(64);
962+
run_microgemm< 4096, 4>(64);
963+
run_microgemm< 4096, 8>(64);
964+
run_microgemm< 8192, 4>(64);
965+
run_microgemm< 8192, 8>(64);
966+
run_microgemm<16384, 4>(64);
967+
run_microgemm<16384, 8>(64);
968+
run_microgemm<16384,16>(64);
969+
run_microgemm<32768, 8>(64);
970+
run_microgemm<32768,16>(64);
971+
run_microgemm<65536, 8>(64);
972+
run_microgemm<65536,16>(64);
973+
run_microgemm<65536,32>(64);
866974

867975
return 0;
868976
}

0 commit comments

Comments
 (0)