@@ -2673,7 +2673,6 @@ class tinyBLAS_PPC {
 
   private:
 
-    void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
 
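+    // packTranspose loads a rows x cols block and transposes it into
+    // contiguous VSX vectors, the operand layout the MMA kernels expect.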
     template <typename VA>
     void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
@@ -2920,7 +2919,72 @@ class tinyBLAS_PPC {
         }
     }
 
-    void KERNEL_4x4(int64_t ii, int64_t jj) {
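+    // Accumulate one 4x4 f32 tile: each xvf32gerpp is a rank-1
+    // (outer-product) update of the accumulator from one packed vector
+    // of A and one packed vector of B; the loop over k lives in the caller.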
+    void KERNEL_4x4(vec_t * vec_A, vec_t * vec_B, acc_t * acc) {
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[0], vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[1], vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[2], vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[3], vec_B[3]);
+    }
+
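+    // 4x8 tile: acc[0] covers columns jj..jj+3 and acc[1] columns jj+4..jj+7;
+    // even/odd entries of vec_B carry the two 4-column groups per k step.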
+    void KERNEL_4x8(vec_t * vec_A, vec_t * vec_B, acc_t * acc) {
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[0], (vec_t)vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[0], (vec_t)vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[1], (vec_t)vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[1], (vec_t)vec_B[3]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[2], (vec_t)vec_B[4]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[2], (vec_t)vec_B[5]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[3], (vec_t)vec_B[6]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[3], (vec_t)vec_B[7]);
+    }
+
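+    // 8x4 tile: acc[0] covers rows ii..ii+3 and acc[1] rows ii+4..ii+7;
+    // even/odd entries of vec_A carry the two 4-row groups per k step.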
+    void KERNEL_8x4(vec_t * vec_A, vec_t * vec_B, acc_t * acc) {
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[0], vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[1], vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[2], vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[3], vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[4], vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[5], vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[6], vec_B[3]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[7], vec_B[3]);
+    }
+
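+    // 8x8 tile as a 2x2 grid of 4x4 accumulators; the caller packs
+    // 8 k-steps at a time, so the 16 vectors are consumed pairwise.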
+    void KERNEL_8x8(vec_t * vec_A, vec_t * vec_B, acc_t * acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x], vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x + 1], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x + 1], vec_B[x + 1]);
+        }
+    }
+
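+    // 8x16 tile as a 2x4 grid of accumulators; vec_B[x+16] / vec_B[x+17]
+    // address the second 8-column block of the packed B panel.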
+    void KERNEL_8x16(vec_t * vec_A, vec_t * vec_B, acc_t * acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x], (vec_t)vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x], (vec_t)vec_B[x + 16]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x], (vec_t)vec_B[x + 17]);
+            __builtin_mma_xvf32gerpp(&acc[4], (vec_t)vec_A[x + 1], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[5], (vec_t)vec_A[x + 1], (vec_t)vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[6], (vec_t)vec_A[x + 1], (vec_t)vec_B[x + 16]);
+            __builtin_mma_xvf32gerpp(&acc[7], (vec_t)vec_A[x + 1], (vec_t)vec_B[x + 17]);
+        }
+    }
+
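+    // 16x8 tile as a 4x2 grid of accumulators; vec_A[x+16] / vec_A[x+17]
+    // address the second 8-row block of the packed A panel.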
+    void KERNEL_16x8(vec_t * vec_A, vec_t * vec_B, acc_t * acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x + 0], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x + 0], (vec_t)vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x + 1], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x + 1], (vec_t)vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[4], (vec_t)vec_A[x + 16], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[5], (vec_t)vec_A[x + 16], (vec_t)vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[6], (vec_t)vec_A[x + 17], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[7], (vec_t)vec_A[x + 17], (vec_t)vec_B[x + 1]);
+        }
+    }
+
+    /* void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
@@ -2998,7 +3062,7 @@ class tinyBLAS_PPC {
         SAVE_ACC(&acc_1, ii, jj+4);
         SAVE_ACC(&acc_2, ii+4, jj);
         SAVE_ACC(&acc_3, ii+4, jj+4);
-    }
+    }*/
 
     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t mc, nc, mp, np;
@@ -3204,6 +3268,112 @@ class tinyBLAS_PPC {
             }
         }
     }
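+    // Single templated entry point replacing the member-function-pointer
+    // dispatch: mnpack() picks a tile shape, gemm<RM,RN>() walks the tiles,
+    // and `if constexpr` selects the pack + MMA sequence for that shape at
+    // compile time. Each branch packs an RM x RN tile of A and B into VSX
+    // vectors, accumulates over k, then stores every 4x4 accumulator via
+    // SAVE_ACC.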
+    template <int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr (RM == 4 && RN == 4) {
+            vec_t vec_A[4], vec_B[4], vec_C[4];
+            acc_t acc[1];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            for (int l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                KERNEL_4x4(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+        } else if constexpr (RM == 4 && RN == 8) {
+            vec_t vec_A[4], vec_B[8], vec_C[4];
+            acc_t acc[2];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            for (int64_t l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+                KERNEL_4x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+        } else if constexpr (RM == 8 && RN == 4) {
+            vec_t vec_A[8], vec_B[4], vec_C[4];
+            acc_t acc[2];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            for (int64_t l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                KERNEL_8x4(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii+4, jj);
+        } else if constexpr (RM == 8 && RN == 8) {
+            vec_t vec_A[16], vec_B[16], vec_C[4];
+            acc_t acc[4];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+                KERNEL_8x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii+4, jj);
+            SAVE_ACC(&acc[3], ii+4, jj+4);
+        } else if constexpr (RM == 8 && RN == 16) {
+            vec_t vec_A[16], vec_B[32], vec_C[4];
+            acc_t acc[8];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            __builtin_mma_xxsetaccz(&acc[4]);
+            __builtin_mma_xxsetaccz(&acc[5]);
+            __builtin_mma_xxsetaccz(&acc[6]);
+            __builtin_mma_xxsetaccz(&acc[7]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 16, (float*)vec_B);
+                KERNEL_8x16(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii, jj+8);
+            SAVE_ACC(&acc[3], ii, jj+12);
+            SAVE_ACC(&acc[4], ii+4, jj);
+            SAVE_ACC(&acc[5], ii+4, jj+4);
+            SAVE_ACC(&acc[6], ii+4, jj+8);
+            SAVE_ACC(&acc[7], ii+4, jj+12);
+        } else if constexpr (RM == 16 && RN == 8) {
+            vec_t vec_A[32], vec_B[16], vec_C[4];
+            acc_t acc[8];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            __builtin_mma_xxsetaccz(&acc[4]);
+            __builtin_mma_xxsetaccz(&acc[5]);
+            __builtin_mma_xxsetaccz(&acc[6]);
+            __builtin_mma_xxsetaccz(&acc[7]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 16, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+                KERNEL_16x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii+4, jj);
+            SAVE_ACC(&acc[3], ii+4, jj+4);
+            SAVE_ACC(&acc[4], ii+8, jj);
+            SAVE_ACC(&acc[5], ii+8, jj+4);
+            SAVE_ACC(&acc[6], ii+12, jj);
+            SAVE_ACC(&acc[7], ii+12, jj+4);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
 
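+    // gemm<RM,RN>: splits the m0..m x n0..n block into RM x RN tiles and
+    // hands each thread a contiguous range of tile jobs; each job now
+    // dispatches straight to the templated kernel above.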
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@@ -3213,21 +3383,12 @@ class tinyBLAS_PPC {
         int64_t duty = (tiles + nth - 1) / nth;
         int64_t start = duty * ith;
         int64_t end = start + duty;
-        if (RM == 4 && RN == 4) {
-            kernel = &tinyBLAS_PPC::KERNEL_4x4;
-        } else if (RM == 4 && RN == 8) {
-            kernel = &tinyBLAS_PPC::KERNEL_4x8;
-        } else if (RM == 8 && RN == 4) {
-            kernel = &tinyBLAS_PPC::KERNEL_8x4;
-        } else if (RM == 8 && RN == 8) {
-            kernel = &tinyBLAS_PPC::KERNEL_8x8;
-        }
         if (end > tiles)
             end = tiles;
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            (this->*kernel)(ii, jj);
+            kernel<RM, RN>(ii, jj);
         }
     }
 