Commit 179a4d2

Move packing from inside MMA kernels

This patch moves the calls to the packing routines from inside the MMA kernels to one step above.

Current call stack: matmul -> mnpack -> gemm -> kernel -> packTranspose + MMA instructions
Changed call stack: matmul -> mnpack -> gemm -> packTranspose -> kernel -> MMA instructions

Not seeing much perf difference with this change.

Signed-off-by: Shalini Salomi Bodapati <[email protected]>
1 parent e298d2f commit 179a4d2
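
The heart of the patch is the dispatch mechanism: the old code selected a kernel by assigning a member-function pointer at runtime inside gemm, while the new code dispatches through a kernel<RM, RN> member template whose branches are resolved by if constexpr at compile time. A minimal standalone sketch of the two patterns (OldDispatch, NewDispatch, and the printf bodies are illustrative, not from the patch):

    #include <cstdint>
    #include <cstdio>

    // Runtime dispatch, as in the old code: a member-function pointer set in gemm().
    struct OldDispatch {
        void (OldDispatch::*kernel)(int64_t, int64_t);
        void KERNEL_4x4(int64_t ii, int64_t jj) { std::printf("4x4 at (%lld, %lld)\n", (long long)ii, (long long)jj); }
        void KERNEL_8x8(int64_t ii, int64_t jj) { std::printf("8x8 at (%lld, %lld)\n", (long long)ii, (long long)jj); }
        void gemm(int RM, int RN, int64_t ii, int64_t jj) {
            if (RM == 4 && RN == 4) {
                kernel = &OldDispatch::KERNEL_4x4;
            } else {
                kernel = &OldDispatch::KERNEL_8x8;
            }
            (this->*kernel)(ii, jj);          // indirect call per tile
        }
    };

    // Compile-time dispatch, as in the new code: if constexpr picks the body and
    // unsupported tile shapes fail to compile (here via a dependent-false assert).
    struct NewDispatch {
        template <int RM, int RN>
        void kernel(int64_t ii, int64_t jj) {
            if constexpr (RM == 4 && RN == 4) {
                std::printf("4x4 at (%lld, %lld)\n", (long long)ii, (long long)jj);
            } else if constexpr (RM == 8 && RN == 8) {
                std::printf("8x8 at (%lld, %lld)\n", (long long)ii, (long long)jj);
            } else {
                static_assert(RM != RM, "RM/RN values not supported");
            }
        }
        template <int RM, int RN>
        void gemm(int64_t ii, int64_t jj) {
            kernel<RM, RN>(ii, jj);           // direct, inlinable call
        }
    };

    int main() {
        OldDispatch o; o.gemm(4, 4, 0, 0);
        NewDispatch n; n.gemm<8, 8>(8, 8);
        return 0;
    }

The compile-time version drops both the per-tile indirect call and the mutable kernel member, which is why that data member is deleted in the first hunk below.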

File tree: 1 file changed

ggml/src/ggml-cpu/llamafile/sgemm.cpp (+174 −13)
@@ -2673,7 +2673,6 @@ class tinyBLAS_PPC {
 
   private:
 
-    void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
 
     template<typename VA>
     void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
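
For context, packTranspose (whose signature is visible above, and which this patch does not change) gathers a rows x cols block of a source matrix with leading dimension lda into a contiguous buffer, transposed so the kernels can consume one packed vector at a time. A scalar model of that packing step, assuming a row-major source; the real routine does this with VSX loads and merge instructions:

    #include <cstdint>

    // Illustrative scalar stand-in for packTranspose: copy a rows x cols tile
    // starting at `a` (leading dimension lda) into `vec`, transposed so that
    // consecutive output elements come from one column of the input.
    template <typename TA>
    void packTransposeModel(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < cols; ++c)
                vec[c * rows + r] = a[r * lda + c];
    }

The patch keeps this routine as-is and only changes who calls it: previously each KERNEL_* packed its own operands, while now the templated kernel<RM, RN> packs and the KERNEL_* helpers receive ready-made vectors.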
@@ -2920,7 +2919,72 @@ class tinyBLAS_PPC {
         }
     }
 
-    void KERNEL_4x4(int64_t ii, int64_t jj) {
+    void KERNEL_4x4(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        for (int l = 0; l < k; l += 4) {
+            __builtin_mma_xvf32gerpp(&acc[0], vec_A[0], vec_B[0]);
+            __builtin_mma_xvf32gerpp(&acc[0], vec_A[1], vec_B[1]);
+            __builtin_mma_xvf32gerpp(&acc[0], vec_A[2], vec_B[2]);
+            __builtin_mma_xvf32gerpp(&acc[0], vec_A[3], vec_B[3]);
+        }
+    }
+
+    void KERNEL_4x8(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[0], (vec_t)vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[0], (vec_t)vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[1], (vec_t)vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[1], (vec_t)vec_B[3]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[2], (vec_t)vec_B[4]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[2], (vec_t)vec_B[5]);
+        __builtin_mma_xvf32gerpp(&acc[0], vec_A[3], (vec_t)vec_B[6]);
+        __builtin_mma_xvf32gerpp(&acc[1], vec_A[3], (vec_t)vec_B[7]);
+
+    }
+
+    void KERNEL_8x4(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[0], vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[1], vec_B[0]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[2], vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[3], vec_B[1]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[4], vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[5], vec_B[2]);
+        __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[6], vec_B[3]);
+        __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[7], vec_B[3]);
+    }
+
+    void KERNEL_8x8(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x], vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x+1], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x+1], vec_B[x+1]);
+        }
+    }
+    void KERNEL_8x16(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x], (vec_t)vec_B[x+16]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x], (vec_t)vec_B[x+17]);
+            __builtin_mma_xvf32gerpp(&acc[4], (vec_t)vec_A[x+1], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[5], (vec_t)vec_A[x+1], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[6], (vec_t)vec_A[x+1], (vec_t)vec_B[x+16]);
+            __builtin_mma_xvf32gerpp(&acc[7], (vec_t)vec_A[x+1], (vec_t)vec_B[x+17]);
+        }
+    }
+
+    void KERNEL_16x8(vec_t* vec_A, vec_t* vec_B, acc_t* acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], (vec_t)vec_A[x+0], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], (vec_t)vec_A[x+0], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[2], (vec_t)vec_A[x+1], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[3], (vec_t)vec_A[x+1], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[4], (vec_t)vec_A[x+16], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[5], (vec_t)vec_A[x+16], (vec_t)vec_B[x+1]);
+            __builtin_mma_xvf32gerpp(&acc[6], (vec_t)vec_A[x+17], (vec_t)vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[7], (vec_t)vec_A[x+17], (vec_t)vec_B[x+1]);
+        }
+    }
+    /*void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
@@ -2998,7 +3062,7 @@ class tinyBLAS_PPC {
         SAVE_ACC(&acc_1, ii, jj+4);
         SAVE_ACC(&acc_2, ii+4, jj);
         SAVE_ACC(&acc_3, ii+4, jj+4);
-    }
+    }*/
 
     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t mc, nc, mp, np;
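
All of the new KERNEL_* helpers reduce to one building block: __builtin_mma_xvf32gerpp is a POWER10 MMA rank-1 update that multiplies a vector of four floats from A by a vector of four floats from B as an outer product and accumulates the 4x4 result into an accumulator tile (__builtin_mma_xxsetaccz zeroes that tile first). A scalar model of a single call; the real operands are VSX vectors and a __vector_quad accumulator:

    #include <array>

    using Vec4   = std::array<float, 4>;                 // stands in for one vec_t
    using Acc4x4 = std::array<std::array<float, 4>, 4>;  // stands in for one acc_t

    // acc[i][j] += a[i] * b[j]: the outer-product accumulate that
    // __builtin_mma_xvf32gerpp performs in a single instruction.
    inline void xvf32gerpp_model(Acc4x4& acc, const Vec4& a, const Vec4& b) {
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 4; ++j)
                acc[i][j] += a[i] * b[j];
    }

Chained over the k loop, these updates accumulate the full dot product for one 4x4 output tile; the wider kernels simply drive several accumulators from the same packed vectors, which is why KERNEL_8x16 and KERNEL_16x8 index acc[0] through acc[7].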
@@ -3204,6 +3268,112 @@ class tinyBLAS_PPC {
             }
         }
     }
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr (RM == 4 && RN == 4) {
+            acc_t acc[1];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            vec_t vec_A[4], vec_B[4], vec_C[4];
+            for (int l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                KERNEL_4x4(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+        } else if constexpr (RM == 4 && RN == 8) {
+            vec_t vec_A[4], vec_B[8], vec_C[4];
+            acc_t acc[2];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            for (int64_t l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+                KERNEL_4x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+
+        } else if constexpr (RM == 8 && RN == 4) {
+            vec_t vec_A[8], vec_B[4], vec_C[4];
+            acc_t acc[2];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            for (int64_t l = 0; l < k; l += 4) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                KERNEL_8x4(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii+4, jj);
+        } else if constexpr (RM == 8 && RN == 8) {
+            vec_t vec_A[16], vec_B[16], vec_C[4];
+            acc_t acc[4];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+                KERNEL_8x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii+4, jj);
+            SAVE_ACC(&acc[3], ii+4, jj+4);
+        } else if constexpr (RM == 8 && RN == 16) {
+            vec_t vec_A[16], vec_B[32], vec_C[4];
+            acc_t acc[8];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            __builtin_mma_xxsetaccz(&acc[4]);
+            __builtin_mma_xxsetaccz(&acc[5]);
+            __builtin_mma_xxsetaccz(&acc[6]);
+            __builtin_mma_xxsetaccz(&acc[7]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 16, (float*)vec_B);
+                KERNEL_8x16(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii, jj+8);
+            SAVE_ACC(&acc[3], ii, jj+12);
+            SAVE_ACC(&acc[4], ii+4, jj);
+            SAVE_ACC(&acc[5], ii+4, jj+4);
+            SAVE_ACC(&acc[6], ii+4, jj+8);
+            SAVE_ACC(&acc[7], ii+4, jj+12);
+
+        } else if constexpr (RM == 16 && RN == 8) {
+            vec_t vec_A[32], vec_B[16], vec_C[4];
+            acc_t acc[8];
+            __builtin_mma_xxsetaccz(&acc[0]);
+            __builtin_mma_xxsetaccz(&acc[1]);
+            __builtin_mma_xxsetaccz(&acc[2]);
+            __builtin_mma_xxsetaccz(&acc[3]);
+            __builtin_mma_xxsetaccz(&acc[4]);
+            __builtin_mma_xxsetaccz(&acc[5]);
+            __builtin_mma_xxsetaccz(&acc[6]);
+            __builtin_mma_xxsetaccz(&acc[7]);
+            for (int l = 0; l < k; l += 8) {
+                packTranspose<vector float>(A+(ii*lda)+l, lda, 16, 8, (float*)vec_A);
+                packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+                KERNEL_16x8(vec_A, vec_B, acc);
+            }
+            SAVE_ACC(&acc[0], ii, jj);
+            SAVE_ACC(&acc[1], ii, jj+4);
+            SAVE_ACC(&acc[2], ii+4, jj);
+            SAVE_ACC(&acc[3], ii+4, jj+4);
+            SAVE_ACC(&acc[4], ii+8, jj);
+            SAVE_ACC(&acc[5], ii+8, jj+4);
+            SAVE_ACC(&acc[6], ii+12, jj);
+            SAVE_ACC(&acc[7], ii+12, jj+4);
+        } else {
+            static_assert(false, "RN/RM values not supported");
+        }
+    }
 
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@@ -3213,21 +3383,12 @@ class tinyBLAS_PPC {
         int64_t duty = (tiles + nth - 1) / nth;
         int64_t start = duty * ith;
         int64_t end = start + duty;
-        if (RM == 4 && RN == 4) {
-            kernel = &tinyBLAS_PPC::KERNEL_4x4;
-        } else if (RM == 4 && RN == 8) {
-            kernel = &tinyBLAS_PPC::KERNEL_4x8;
-        } else if (RM == 8 && RN == 4) {
-            kernel = &tinyBLAS_PPC::KERNEL_8x4;
-        } else if (RM == 8 && RN == 8) {
-            kernel = &tinyBLAS_PPC::KERNEL_8x8;
-        }
         if (end > tiles)
            end = tiles;
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            (this->*kernel)(ii, jj);
+            kernel<RM,RN>(ii, jj);
         }
     }
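
The tile partitioning around the call site is unchanged: gemm carves the output into RM x RN tiles, gives each of nth threads a contiguous run of tile indices, and decodes each linear job index into a tile origin (ii, jj). A standalone sketch of that arithmetic with illustrative sizes (ytiles and xtiles are computed the same way earlier in gemm):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Illustrative sizes: a 16x16 output, 8x8 tiles, thread 0 of 2.
        const int64_t m0 = 0, m = 16, n0 = 0, n = 16;
        const int64_t RM = 8, RN = 8, ith = 0, nth = 2;

        const int64_t ytiles = (m - m0) / RM;           // rows of tiles
        const int64_t xtiles = (n - n0) / RN;           // columns of tiles
        const int64_t tiles  = xtiles * ytiles;
        const int64_t duty   = (tiles + nth - 1) / nth; // tiles per thread, rounded up
        const int64_t start  = duty * ith;
        int64_t end = start + duty;
        if (end > tiles) end = tiles;

        for (int64_t job = start; job < end; ++job) {
            const int64_t ii = m0 + job / xtiles * RM;  // tile origin row
            const int64_t jj = n0 + job % xtiles * RN;  // tile origin column
            std::printf("job %lld -> C tile at (%lld, %lld)\n",
                        (long long)job, (long long)ii, (long long)jj);
            // here the patch now calls kernel<RM, RN>(ii, jj)
        }
        return 0;
    }

With the function-pointer setup gone, gemm's body is just this loop plus the direct kernel<RM, RN> call, which matches the author's note that the change is structural rather than a performance win.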
