@@ -528,14 +528,112 @@ vmtlfre64_t vmtlfre64;
528528vmtsfre64_t vmtsfre64;
529529vmrotate_t vmrotate;
530530
531- bool LisSquare ()
531+ u32 LisSquare ()
532532{
533533 double rootL = sqrt (RV->VLENE ());
534- if ((rootL*rootL) == RV->VLENE ()) return true ;
535- else return false ;
534+ if ((rootL*rootL) == RV->VLENE ()) return 1 ;
535+ else return 0 ;
536536}
537537
/// Double-precision GEMM micro-kernel for the simulated matrix-vector ISA.
/// Accumulates T += A * B over the K dimension in vector registers 16..31,
/// then performs the rank-update C[i,j] = alpha * T[i,j] + C[i,j].
/// Register usage: VR 0..3 hold A panels, VR 8..11 hold B panels,
/// VR 16..31 hold the 4x4 grid of accumulator tiles.
/// @param M     rows of the C micro-tile (used for the A-panel increment)
/// @param N     columns of the C micro-tile (used for the B-panel increment
///              and as the row stride for C loads/stores)
/// @param K     depth of the update; must be a multiple of lambda_eff
/// @param A     packed A panel — assumed panel-major with stride M*lambda_eff; TODO confirm against the packing routine
/// @param B     packed B panel — assumed panel-major with stride N*lambda_eff; TODO confirm against the packing routine
/// @param alpha scale factor applied to the accumulated product
/// @param C     output tile, updated in place
/// @param gamma row-offset factor used when computing the C store offsets
/// @param lmul  vector register grouping factor (LMUL) for this configuration
void microdgemm
(
    u32 M,
    u32 N,
    u32 K,
    double *A,
    double *B,
    double alpha,
    double *C,
    s32 gamma,
    u32 lmul
)
{
    u32 L = RV->VLENE (); // L is number of elements per vector register
    u32 lambda_eff = RV->lambda () * lmul; // lambda_eff is the maximum lambda for this L
    assert (0 == K % lambda_eff); // for simplicity, K must be a multiple of lambda_eff

    vsetvl (5, 0, 64, 1, true, true); // double-precision kernel, set VL to VLENE and LMUL to 1
    for (u32 r=16; r<32; r++) vxor.vv (r, r, r); // T = 0 (clear all 16 accumulator registers)

    vsetvl (5, RV->lambda () * RV->lambda (), 64, lmul, true, true); // double-precision kernel, set VL to lambda^2 and LMUL accordingly
    s32 INCA = M*lambda_eff; s32 INCB = N*lambda_eff; // iteration increments for A and B panels

    // the following setup for the A and B register load pointers works because not all loads are active for all values of lmul
    double *A0 = A; double *A1 = A0 + LisSquare () * L; double *A2 = A1 + ((2 == lmul) ? LisSquare () * L : L); double *A3 = A2 + LisSquare () * L; // pointers for loads to the A registers
    double *B0 = B; double *B1 = B0 + L; double *B2 = B1 + L; double *B3 = B2 + L; // pointers for loads to the B registers

    // the computation loop: each iteration consumes lambda_eff columns of A and rows of B
    for (u32 k=0; k<K; k+=lambda_eff)
    {
        if (debug > 1) { std::cout << " k = " << k << std::endl; }

        // load the 4 A registers
        vmtlfre64.v ( 0, A0, lambda_eff); if (debug > 1) { std::cout << " VR[ 0] = "; RV->printVRf64 ( 0); }
        vmtlfre64.v ( 1, A1, lambda_eff); if (debug > 1) { std::cout << " VR[ 1] = "; RV->printVRf64 ( 1); }
        vmtlfre64.v ( 2, A2, lambda_eff); if (debug > 1) { std::cout << " VR[ 2] = "; RV->printVRf64 ( 2); }
        vmtlfre64.v ( 3, A3, lambda_eff); if (debug > 1) { std::cout << " VR[ 3] = "; RV->printVRf64 ( 3); }

        // load the 4 B registers
        vmtlfre64.v ( 8, B0, lambda_eff); if (debug > 1) { std::cout << " VR[ 8] = "; RV->printVRf64 ( 8); }
        vmtlfre64.v ( 9, B1, lambda_eff); if (debug > 1) { std::cout << " VR[ 9] = "; RV->printVRf64 ( 9); }
        vmtlfre64.v (10, B2, lambda_eff); if (debug > 1) { std::cout << " VR[10] = "; RV->printVRf64 (10); }
        vmtlfre64.v (11, B3, lambda_eff); if (debug > 1) { std::cout << " VR[11] = "; RV->printVRf64 (11); }

        A0 = A0 + INCA ; A1 = A1 + INCA ; A2 = A2 + INCA ; A3 = A3 + INCA; // increment pointers for the A registers
        B0 = B0 + INCB ; B1 = B1 + INCB ; B2 = B2 + INCB ; B3 = B3 + INCB; // increment pointers for the B registers

        // perform 16 vmmacc's, one for each target register: destination 16+d pairs
        // A register {0,1,2,3} with B register {8,9,10,11} in a 4x4 tiling.
        // NOTE(review): each B register is rotated after every use — presumably to
        // realign lanes between accumulations; confirm against the vmrotate implementation.
        vfmmacc.v0 (16, 0, 8); vmrotate.vv ( 8, 8); if (debug > 1) { std::cout << " VR[16] = "; RV->printVRf64 (16); }
        vfmmacc.v0 (17, 0, 9); vmrotate.vv ( 9, 9); if (debug > 1) { std::cout << " VR[17] = "; RV->printVRf64 (17); }
        vfmmacc.v0 (18, 1, 8); vmrotate.vv ( 8, 8); if (debug > 1) { std::cout << " VR[18] = "; RV->printVRf64 (18); }
        vfmmacc.v0 (19, 1, 9); vmrotate.vv ( 9, 9); if (debug > 1) { std::cout << " VR[19] = "; RV->printVRf64 (19); }
        vfmmacc.v0 (20, 0, 10); vmrotate.vv (10, 10); if (debug > 1) { std::cout << " VR[20] = "; RV->printVRf64 (20); }
        vfmmacc.v0 (21, 0, 11); vmrotate.vv (11, 11); if (debug > 1) { std::cout << " VR[21] = "; RV->printVRf64 (21); }
        vfmmacc.v0 (22, 1, 10); vmrotate.vv (10, 10); if (debug > 1) { std::cout << " VR[22] = "; RV->printVRf64 (22); }
        vfmmacc.v0 (23, 1, 11); vmrotate.vv (11, 11); if (debug > 1) { std::cout << " VR[23] = "; RV->printVRf64 (23); }
        vfmmacc.v0 (24, 2, 8); vmrotate.vv ( 8, 8); if (debug > 1) { std::cout << " VR[24] = "; RV->printVRf64 (24); }
        vfmmacc.v0 (25, 2, 9); vmrotate.vv ( 9, 9); if (debug > 1) { std::cout << " VR[25] = "; RV->printVRf64 (25); }
        vfmmacc.v0 (26, 3, 8); vmrotate.vv ( 8, 8); if (debug > 1) { std::cout << " VR[26] = "; RV->printVRf64 (26); }
        vfmmacc.v0 (27, 3, 9); vmrotate.vv ( 9, 9); if (debug > 1) { std::cout << " VR[27] = "; RV->printVRf64 (27); }
        vfmmacc.v0 (28, 2, 10); vmrotate.vv (10, 10); if (debug > 1) { std::cout << " VR[28] = "; RV->printVRf64 (28); }
        vfmmacc.v0 (29, 2, 11); vmrotate.vv (11, 11); if (debug > 1) { std::cout << " VR[29] = "; RV->printVRf64 (29); }
        vfmmacc.v0 (30, 3, 10); vmrotate.vv (10, 10); if (debug > 1) { std::cout << " VR[30] = "; RV->printVRf64 (30); }
        vfmmacc.v0 (31, 3, 11); vmrotate.vv (11, 11); if (debug > 1) { std::cout << " VR[31] = "; RV->printVRf64 (31); }
    }

    // compute the store offsets for each result register - this only has to be done once per <L,lambda> configuration
    // the offset vector only needs 16 elements - we use 32 for convenience and will cleanup later
    // (offset[16+d] is the element offset into C where accumulator VR[16+d] is written back;
    // the case analysis depends on lmul and on whether L is a perfect square)
    u32 offset[32];
    offset[16] = 0;
    offset[17] = offset[16] + (((!LisSquare ()) && (1 == lmul)) ? 2 * RV->lambda () : RV->lambda ());
    offset[18] = (1 != lmul) ? (offset[17] + RV->lambda ()) : (LisSquare () ? offset[16] + RV->lambda () * gamma : offset[16] + RV->lambda ());
    offset[19] = ((!LisSquare ()) && (1 == lmul)) ? offset[17] + RV->lambda () : offset[18] + RV->lambda ();
    offset[20] = (4 == lmul) ? (offset[16] + 4*RV->lambda ()) : offset[16] + (2*RV->sigma ())/lmul;
    offset[21] = offset[20] + (((!LisSquare ()) && (1 == lmul)) ? 2 * RV->lambda () : RV->lambda ());
    offset[22] = (1 != lmul) ? (offset[21] + RV->lambda ()) : (LisSquare () ? offset[20] + RV->lambda () * gamma : offset[20] + RV->lambda ());
    offset[23] = ((!LisSquare ()) && (1 == lmul)) ? offset[21] + RV->lambda () : offset[22] + RV->lambda ();
    offset[24] = (4 == lmul) ? (offset[16] + 8*RV->lambda ()) : ((LisSquare () || (1 == lmul)) ? gamma * (M/2) : offset[16] + 4*RV->lambda ());
    offset[25] = offset[24] + (((!LisSquare ()) && (1 == lmul)) ? 2 * RV->lambda () : RV->lambda ());
    offset[26] = (1 != lmul) ? (offset[25] + RV->lambda ()) : (LisSquare () ? offset[24] + RV->lambda () * gamma : offset[24] + RV->lambda ());
    offset[27] = ((!LisSquare ()) && (1 == lmul)) ? offset[25] + RV->lambda () : offset[26] + RV->lambda ();
    offset[28] = (4 == lmul) ? (offset[24] + 4*RV->lambda ()) : offset[24] + (2*RV->sigma ())/lmul;
    offset[29] = offset[28] + (((!LisSquare ()) && (1 == lmul)) ? 2 * RV->lambda () : RV->lambda ());
    offset[30] = (1 != lmul) ? (offset[29] + RV->lambda ()) : (LisSquare () ? offset[28] + RV->lambda () * gamma : offset[28] + RV->lambda ());
    offset[31] = ((!LisSquare ()) && (1 == lmul)) ? offset[29] + RV->lambda () : offset[30] + RV->lambda ();

    if (debug > 1) { for (u32 i=16; i<32; i++) std::cout << " offset[" << i << " ] = " << offset[i] << std::endl; }

    // do the scaling by alpha and update C: for each accumulator VR[16+vd],
    // load the matching C tile into VR[vd], fuse-multiply-accumulate, store back
    vsetvl (5, 0, 64, 1, true, true); // double-precision kernel, set VL to VLENE and LMUL to 1
    for (u32 vd=0; vd<16; vd++)
    {
        vmtlfre64.v (vd, C+offset[vd+16], N); // C[i,j] = alpha * T[i,j] + C[i,j]
        vfmacc.vf (vd, alpha, vd+16);
        vmtsfre64.v (vd, C+offset[vd+16], N);
    }
}
635+
636+ void microdgemm_old
539637(
540638 u32 M,
541639 u32 N,
@@ -746,13 +844,13 @@ bool run_microgemm
746844 u32 lambda_eff = 1 ;
747845 while ((2 *lambda_eff)*(2 *lambda_eff) <= L) lambda_eff *= 2 ;
748846 u32 LMUL = lambda_eff / lambda;
749- std::cout << " L = " << std::setw (2 ) << L << " , lambda = " << std::setw (2 ) << RV->lambda () << " , sigma = " << std::setw (2 ) << RV->sigma () << " , lambda_eff = " << std::setw (2 ) << lambda_eff << " , LMUL = " << LMUL;
847+ std::cout << " L = " << std::setw (4 ) << L << " , lambda = " << std::setw (2 ) << RV->lambda () << " , sigma = " << std::setw (3 ) << RV->sigma () << " , lambda_eff = " << std::setw (2 ) << lambda_eff << " , LMUL = " << LMUL;
750848 std::cout << " , RMUL = " << std::setw (2 ) << rmul << " , CMUL = " << std::setw (2 ) << cmul;
751849
752850 u32 mu = rmul*RV->sigma ();
753851 u32 nu = cmul*RV->sigma ();
754852
755- std::cout << " , mu = " << std::setw (3 ) << mu << " , nu = " << std::setw (3 ) << nu << " , K = " << K << std::endl;
853+ std::cout << " , mu = " << std::setw (3 ) << mu << " , nu = " << std::setw (3 ) << nu << " , K = " << std::setw ( 3 ) << K << std::endl;
756854
757855 u32 M = mu;
758856 u32 N = nu;
@@ -826,43 +924,53 @@ int main
826924 char **argv
827925)
828926{
829- std::cout << " =========================================================================================================================" << std::endl;
830- run_microgemm< 64 , 1 >(1 );
831- run_microgemm< 64 , 1 >(2 );
832- run_microgemm< 64 , 1 >(4 );
833- run_microgemm< 64 , 1 >(8 );
834- run_microgemm< 128 , 1 >(1 );
835- run_microgemm< 128 , 1 >(2 );
836- run_microgemm< 128 , 1 >(4 );
837- run_microgemm< 128 , 1 >(8 );
838- run_microgemm< 256 , 1 >(2 );
839- run_microgemm< 256 , 1 >(4 );
840- run_microgemm< 256 , 1 >(8 );
841- run_microgemm< 256 , 2 >(2 );
842- run_microgemm< 256 , 2 >(4 );
843- run_microgemm< 256 , 2 >(8 );
844- run_microgemm< 512 , 1 >(2 );
845- run_microgemm< 512 , 1 >(4 );
846- run_microgemm< 512 , 1 >(8 );
847- run_microgemm< 512 , 2 >(2 );
848- run_microgemm< 512 , 2 >(4 );
849- run_microgemm< 512 , 2 >(8 );
850- run_microgemm<1024 , 1 >(4 );
851- run_microgemm<1024 , 1 >(8 );
852- run_microgemm<1024 , 1 >(16 );
853- run_microgemm<1024 , 2 >(4 );
854- run_microgemm<1024 , 2 >(8 );
855- run_microgemm<1024 , 2 >(16 );
856- run_microgemm<1024 , 4 >(4 );
857- run_microgemm<1024 , 4 >(8 );
858- run_microgemm<1024 , 4 >(16 );
859- run_microgemm<2048 , 2 >(8 );
860- run_microgemm<2048 , 2 >(16 );
861- run_microgemm<2048 , 4 >(8 );
862- run_microgemm<2048 , 4 >(16 );
863- run_microgemm<4096 , 2 >(16 );
864- run_microgemm<4096 , 4 >(16 );
865- run_microgemm<4096 , 8 >(16 );
927+ std::cout << " =================================================================================================================" << std::endl;
928+ run_microgemm< 64 , 1 >( 1 );
929+ run_microgemm< 64 , 1 >( 2 );
930+ run_microgemm< 64 , 1 >( 4 );
931+ run_microgemm< 64 , 1 >( 8 );
932+ run_microgemm< 128 , 1 >( 1 );
933+ run_microgemm< 128 , 1 >( 2 );
934+ run_microgemm< 128 , 1 >( 4 );
935+ run_microgemm< 128 , 1 >( 8 );
936+ run_microgemm< 256 , 1 >( 2 );
937+ run_microgemm< 256 , 1 >( 4 );
938+ run_microgemm< 256 , 1 >( 8 );
939+ run_microgemm< 256 , 2 >( 2 );
940+ run_microgemm< 256 , 2 >( 4 );
941+ run_microgemm< 256 , 2 >( 8 );
942+ run_microgemm< 512 , 1 >( 2 );
943+ run_microgemm< 512 , 1 >( 4 );
944+ run_microgemm< 512 , 1 >( 8 );
945+ run_microgemm< 512 , 2 >( 2 );
946+ run_microgemm< 512 , 2 >( 4 );
947+ run_microgemm< 512 , 2 >( 8 );
948+ run_microgemm< 1024 , 1 >( 4 );
949+ run_microgemm< 1024 , 1 >( 8 );
950+ run_microgemm< 1024 , 1 >(16 );
951+ run_microgemm< 1024 , 2 >( 4 );
952+ run_microgemm< 1024 , 2 >( 8 );
953+ run_microgemm< 1024 , 2 >(16 );
954+ run_microgemm< 1024 , 4 >( 4 );
955+ run_microgemm< 1024 , 4 >( 8 );
956+ run_microgemm< 1024 , 4 >(16 );
957+ run_microgemm< 2048 , 2 >( 8 );
958+ run_microgemm< 2048 , 2 >(16 );
959+ run_microgemm< 2048 , 4 >( 8 );
960+ run_microgemm< 2048 , 4 >(64 );
961+ run_microgemm< 4096 , 2 >(64 );
962+ run_microgemm< 4096 , 4 >(64 );
963+ run_microgemm< 4096 , 8 >(64 );
964+ run_microgemm< 8192 , 4 >(64 );
965+ run_microgemm< 8192 , 8 >(64 );
966+ run_microgemm<16384 , 4 >(64 );
967+ run_microgemm<16384 , 8 >(64 );
968+ run_microgemm<16384 ,16 >(64 );
969+ run_microgemm<32768 , 8 >(64 );
970+ run_microgemm<32768 ,16 >(64 );
971+ run_microgemm<65536 , 8 >(64 );
972+ run_microgemm<65536 ,16 >(64 );
973+ run_microgemm<65536 ,32 >(64 );
866974
867975 return 0 ;
868976}
0 commit comments