@@ -535,7 +535,22 @@ u32 LisSquare()
535535 else return 0 ;
536536}
537537
538- void microdgemm
538+ template <typename T>
539+ void microgemm
540+ (
541+ u32 M,
542+ u32 N,
543+ u32 K,
544+ T *A,
545+ T *B,
546+ T alpha,
547+ T *C,
548+ s32 gamma,
549+ u32 lmul
550+ );
551+
552+ template <>
553+ void microgemm<double >
539554(
540555 u32 M,
541556 u32 N,
@@ -548,32 +563,32 @@ void microdgemm
548563 u32 lmul
549564)
550565{
551- u32 L = RV->VLENE (); // L is number of elements per vector register
552- u32 lambda_eff = RV->lambda () * lmul; // lambda_eff is the maximum lambda for this L
566+ u32 L = RV->VLENE (); // L is number of elements per vector register
567+ u32 lambda_eff = RV->lambda () * lmul; // lambda_eff is the maximum lambda for this L
553568 assert (0 == K % lambda_eff); // for simplicty, K must be a multiple of lambda_eff
554569
555570 vsetvl (5 , 0 , 64 , 1 , true , true ); // double-precision kernel, set VL to VLENE and LMUL to 1
556571 for (u32 r=16 ; r<32 ; r++) vxor.vv (r, r, r); // T = 0
557572
558573 vsetvl (5 , RV->lambda () * RV->lambda (), 64 , lmul, true , true ); // double-precision kernel, set VL to lambda^2 and LMUL accordingly
559- s32 INCA = M*lambda_eff; s32 INCB = N*lambda_eff; // iteration increments for A and B panels
574+ s32 INCA = M*lambda_eff; s32 INCB = N*lambda_eff; // iteration increments for A and B panels
560575
561576 // the following setup for the A and B register load pointers works because not all loads are active for all values of lmul
562- double *A0 = A; double *A1 = A0 + LisSquare () * L; double *A2 = A1 + ((2 == lmul) ? LisSquare () * L : L); double *A3 = A2 + LisSquare () * L; // pointers for loads to the A registers
563- double *B0 = B; double *B1 = B0 + L; double *B2 = B1 + L; double *B3 = B2 + L; // pointers for loads to the B registers
577+ double *A0 = A; double *A1 = A0 + LisSquare () * L; double *A2 = A1 + ((2 == lmul) ? LisSquare () * L : L); double *A3 = A2 + LisSquare () * L; // pointers for loads to the A registers
578+ double *B0 = B; double *B1 = B0 + L; double *B2 = B1 + L; double *B3 = B2 + L; // pointers for loads to the B registers
564579
565580 // the computation loop
566581 for (u32 k=0 ; k<K; k+=lambda_eff)
567582 {
568583 if (debug > 1 ) { std::cout << " k = " << k << std::endl; }
569584
570- // load the 4 A registers
585+ // load the 4 A registers
571586 vmtlfre64.v ( 0 , A0, lambda_eff); if (debug > 1 ) { std::cout << " VR[ 0] = " ; RV->printVRf64 ( 0 ); }
572587 vmtlfre64.v ( 1 , A1, lambda_eff); if (debug > 1 ) { std::cout << " VR[ 1] = " ; RV->printVRf64 ( 1 ); }
573588 vmtlfre64.v ( 2 , A2, lambda_eff); if (debug > 1 ) { std::cout << " VR[ 2] = " ; RV->printVRf64 ( 2 ); }
574589 vmtlfre64.v ( 3 , A3, lambda_eff); if (debug > 1 ) { std::cout << " VR[ 3] = " ; RV->printVRf64 ( 3 ); }
575590
576- // load the 4 B registers
591+ // load the 4 B registers
577592 vmtlfre64.v ( 8 , B0, lambda_eff); if (debug > 1 ) { std::cout << " VR[ 8] = " ; RV->printVRf64 ( 8 ); }
578593 vmtlfre64.v ( 9 , B1, lambda_eff); if (debug > 1 ) { std::cout << " VR[ 9] = " ; RV->printVRf64 ( 9 ); }
579594 vmtlfre64.v (10 , B2, lambda_eff); if (debug > 1 ) { std::cout << " VR[10] = " ; RV->printVRf64 (10 ); }
@@ -582,7 +597,7 @@ void microdgemm
582597 A0 = A0 + INCA ; A1 = A1 + INCA ; A2 = A2 + INCA ; A3 = A3 + INCA; // increment pointers for the A registers
583598 B0 = B0 + INCB ; B1 = B1 + INCB ; B2 = B2 + INCB ; B3 = B3 + INCB; // increment pointers for the B registers
584599
585- // perform 16 vmmacc's, one for each target register
600+ // perform 16 vmmacc's, one for each target register
586601 vfmmacc.v0 (16 , 0 , 8 ); vmrotate.vv ( 8 , 8 ); if (debug > 1 ) { std::cout << " VR[16] = " ; RV->printVRf64 (16 ); }
587602 vfmmacc.v0 (17 , 0 , 9 ); vmrotate.vv ( 9 , 9 ); if (debug > 1 ) { std::cout << " VR[17] = " ; RV->printVRf64 (17 ); }
588603 vfmmacc.v0 (18 , 1 , 8 ); vmrotate.vv ( 8 , 8 ); if (debug > 1 ) { std::cout << " VR[18] = " ; RV->printVRf64 (18 ); }
@@ -627,13 +642,13 @@ void microdgemm
627642 vsetvl (5 , 0 , 64 , 1 , true , true ); // double-precision kernel, set VL to VLENE and LMUL to 1
628643 for (u32 vd=0 ; vd<16 ; vd++)
629644 {
630- vmtlfre64.v (vd, C+offset[vd+16 ], N); // C[i,j] = alpha * T[i,j] + C[i,j]
645+ vmtlfre64.v (vd, C+offset[vd+16 ], N); // C[i,j] = alpha * T[i,j] + C[i,j]
631646 vfmacc.vf (vd, alpha, vd+16 );
632647 vmtsfre64.v (vd, C+offset[vd+16 ], N);
633648 }
634649}
635650
636- void microdgemm_old
651+ void microdgemm
637652(
638653 u32 M,
639654 u32 N,
@@ -882,7 +897,7 @@ bool run_microgemm
882897 if (debug > 1 ) { std::cout << " Bp[" << k/lambda_eff << " ] = " ; print (N, lambda_eff, Bp+k*nu); }
883898 }
884899
885- microdgemm (M, N, K, Ap, Bp, alpha, D, N, LMUL);
900+ microgemm< double > (M, N, K, Ap, Bp, alpha, D, N, LMUL);
886901
887902 // Check the result
888903 for (u32 j=0 ; j<N; j++)
0 commit comments