
Commit e16fee0

root authored and shalinib-ibm committed
Q4/Q8 Tiled GEMM Optimization

This patch implements tiled GEMM for large blocks: it packs 64x64 blocks and performs the matmul on them. It gives a 30-50% improvement in llama-bench and llama-batched-bench with Meta-Llama3-8B quantized models (Q4_0 and Q8_0).

Signed-off-by: root <[email protected]>
1 parent 061f0ef commit e16fee0
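The optimization itself is classic cache blocking: the matmul walks the output in 64x64 tiles and packs the corresponding A and B blocks so the working set stays cache-resident while the kernel runs. Below is a minimal sketch of that blocking idea on plain floats, assuming row-major storage and dimensions that are multiples of the tile size; the names TILE and gemm_tiled are illustrative only, since the patch operates on quantized block_q4_0/block_q8_0 data with POWER MMA intrinsics.

// Illustrative sketch only, not the patch's kernel.
#include <cstdint>

constexpr int64_t TILE = 64;  // same 64x64 blocking the commit message describes

// C (m x n) += A (m x k) * B (k x n), row-major, all dims multiples of TILE.
void gemm_tiled(const float* A, const float* B, float* C,
                int64_t m, int64_t n, int64_t k) {
    for (int64_t ii = 0; ii < m; ii += TILE)
        for (int64_t jj = 0; jj < n; jj += TILE)
            for (int64_t kk = 0; kk < k; kk += TILE)
                // one 64x64x64 block-multiply; the three tiles stay hot in cache
                for (int64_t i = ii; i < ii + TILE; ++i)
                    for (int64_t j = jj; j < jj + TILE; ++j) {
                        float acc = C[i*n + j];
                        for (int64_t p = kk; p < kk + TILE; ++p)
                            acc += A[i*k + p] * B[p*n + j];
                        C[i*n + j] = acc;
                    }
}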

File tree

1 file changed: +271, -25 lines


ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 271 additions & 25 deletions
@@ -119,6 +119,30 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
 #if defined(__MMA__)
 typedef vector unsigned char vec_t;
 typedef __vector_quad acc_t;
+
+#include <pthread.h>
+
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
+
+static pthread_key_t t_data_key;
+typedef struct {
+    vec_t* A_pack;
+    vec_t* B_pack;
+    int* comparray;
+} thread_scratchpad_t;
+void thread_cleanup(void* arg) {
+    thread_scratchpad_t* data = (thread_scratchpad_t*)arg;
+    if (data) {
+        delete[] data->A_pack;
+        delete[] data->B_pack;
+        delete[] data->comparray;
+
+        delete data;
+    }
+}
+static bool key_created = false;
+
 #endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD
@@ -621,6 +645,7 @@ class tinyBLAS_Q0_ARM {
     const block_q8_0 *const B;
     float *const C;
     const int64_t k;
+    int64_t kc;
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;
@@ -1582,34 +1607,61 @@ class tinyBLAS_Q0_PPC {
                     float *C, int64_t ldc,
                     int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+        kc = 64;
     }
 
     void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
+        int mc = 64; int nc = 64;
+        const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
+        if (is_aligned) {
+            matmul_tiled(m, n, mc, nc, kc);
+        } else {
+            mnpack(0, m, 0, n);
+        }
     }
 
     private:
 
     inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
-        for (int I = 0; I < RM; I++) {
-            for (int J = 0; J < RN; J++) {
-                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
-            }
-        }
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
+            }
+        }
     }
 
-    template<int size>
-    inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
-        vector signed int vec_C[4];
-        vector float CA[4] = {0};
-        vector float res[4] = {0};
-        __builtin_mma_disassemble_acc(vec_C, ACC);
-        for (int i = 0; i < 4; i++) {
-            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
-            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
-            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
-        }
+    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
+                *c_ptr += *((float*)&fin_res[idx+I]+J);
+            }
+        }
     }
+
+    template<typename ArrayType>
+    inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
+        vector signed int vec_C[4];
+        vector float CA[4] = {0};
+        vector float res[4] = {0};
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int i = 0; i < 4; i++) {
+            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+        }
+    }
+
+    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
+        for (int I = 0; I<8; I++) {
+            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
+            for (int J = 0; J<4; J++) {
+                *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
+                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
+            }
+        }
+    }
+
     /* This function processes quantized data from block_q4_0 elements.
      * First the we try to extract the two int4 values stored in single int8_t into two signed int8.
      * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8.
@@ -1630,6 +1682,18 @@ class tinyBLAS_Q0_PPC {
         *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
     }
 
+    inline void process_q8_elements(const int8_t *qs, int *ca) {
+        vector signed char c1 = vec_xl(0, qs);
+        vector signed char c2 = vec_xl(16, qs);
+        vector signed int vsum1 = {0};
+        vector signed int vsum2 = {0};
+        vsum1 = vec_sum4s(c1, vsum1);
+        vsum2 = vec_sum4s(c2, vsum2);
+        vector signed int vsum = vec_add(vsum1, vsum2);
+        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
+
     template <typename V1, typename V2>
     inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
         vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
@@ -1877,6 +1941,98 @@ class tinyBLAS_Q0_PPC {
             }
         }
     }
+    template<typename VA, typename VB>
+    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
+        int64_t i, j;
+        block_q8_0 *aoffset = NULL;
+        VA *vecOffset = NULL;
+        block_q8_0* aoffsets[8];
+        __vector_pair arr[8];
+        VB c[8][2] = {0};
+        VB c1[8] = {0}; VB c2[8] = {0};
+        aoffset = const_cast<block_q8_0*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        int index = 0;
+        if (j > 0) {
+            do {
+                for (int it = 0; it < 8; it++)
+                    aoffsets[it] = aoffset + it*lda;
+                aoffset += 8 * lda;
+                for (int blk = 0; blk < kc; blk++) {
+                    for (int it = 0; it < 8; it++) {
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
+                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        c1[it] = c[it][0];
+                        c2[it] = c[it][1];
+                        if (comparray){
+                            process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
+                        }
+                    }
+                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
+                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+                    vecOffset += 256;
+                }
+                j--;
+                index += 8*kc;
+            } while(j > 0);
+        }
+    }
+
+    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        int8_t *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        int index = 0;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+                for (int blk = 0; blk < kc; blk++) {
+                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
+                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
+                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
+                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
+                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
+                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
+                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
+                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
+
+                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
+                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
+                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
+                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
+                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
+                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
+                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
+                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
+                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                    vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
+                    vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
+                    vecOffset += 256;
+                }
+                j--;
+                index += 8*kc;
+            } while (j > 0);
+        }
+    }
 
     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int m_rem = MIN(m - m0, 16);
@@ -1953,8 +2109,8 @@ class tinyBLAS_Q0_PPC {
                     aoffset += lda;
                 }
             }
-            compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
+            compute(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute(&acc_1, 0, 4, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
         save_res(ii, jj+4, 4, fin_res);
@@ -1997,8 +2153,8 @@ class tinyBLAS_Q0_PPC {
                     aoffset += lda;
                 }
             }
-            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute(&acc_1, 4, 4, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
        save_res(ii+4, jj, 4, fin_res);
@@ -2046,17 +2202,106 @@ class tinyBLAS_Q0_PPC {
                     aoffset += lda;
                 }
             }
-            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
-            compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
-            compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
+            compute(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute(&acc_2, 0, 8, comparray, vs, fin_res);
+            compute(&acc_3, 4, 12, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
         save_res(ii+4, jj, 4, fin_res);
         save_res(ii, jj+4, 8, fin_res);
         save_res(ii+4, jj+4, 12, fin_res);
     }
 
+    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
+        acc_t acc[4];
+        for (int i = 0; i < mc ; i += 8) {
+            for (int j = 0; j < nc; j += 8) {
+                vector float fin_res[16] = {0};
+                vector float vs[16] = {0};
+                for (int64_t kk = 0; kk < kc; kk++) {
+                    for (int x = 0; x < 4; x++) {
+                        __builtin_mma_xxsetaccz(&acc[x]);
+                    }
+                    int A_block_idx = (i/8)*(16*kc) + kk*16;
+                    int B_block_idx = (j/8)*(16*kc)+ kk*16;
+                    vec_t *A_block = &vec_A[A_block_idx];
+                    vec_t *B_block = &vec_B[B_block_idx];
+                    for (int x = 0; x < 8; x++) {
+                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
+                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
+                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
+                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
+                    }
+                    compute_scale(ii+i, jj+j, l+kk, vs);
+                    int c_index = (i/8)*(8*kc)+ kk*8;
+                    int* c_block = &comparray[c_index];
+                    compute(&acc[0], 0, 0, c_block, vs, fin_res);
+                    compute(&acc[1], 4, 4, c_block, vs, fin_res);
+                    compute(&acc[2], 0, 8, c_block, vs, fin_res);
+                    compute(&acc[3], 4, 12, c_block, vs, fin_res);
+                }
+                if (l == 0) {
+                    save_res(ii+i, jj+j, 0, fin_res);
+                    save_res(ii+i+4, jj+j, 4, fin_res);
+                    save_res(ii+i, jj+j+4, 8, fin_res);
+                    save_res(ii+i+4, jj+j+4, 12, fin_res);
+                } else {
+                    add_save_res(ii+i, jj+j, 0, fin_res);
+                    add_save_res(ii+i+4, jj+j, 4, fin_res);
+                    add_save_res(ii+i, jj+j+4, 8, fin_res);
+                    add_save_res(ii+i+4, jj+j+4, 12, fin_res);
+                }
+            }
+        }
+    }
+
+    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+        if (!key_created) {
+            if (pthread_key_create(&t_data_key, thread_cleanup) == 0) {
+                key_created = true;
+            } else {
+                return;
+            }
+        }
+        thread_scratchpad_t* t_data = (thread_scratchpad_t*)pthread_getspecific(t_data_key);
+        if (t_data == nullptr) {
+            t_data = new thread_scratchpad_t;
+            t_data->A_pack = new vec_t[mc * kc * 2];
+            t_data->B_pack = new vec_t[nc * kc * 2];
+            t_data->comparray = new int[mc * kc];
+            pthread_setspecific(t_data_key, t_data);
+        }
+        vec_t* A_pack = t_data->A_pack;
+        vec_t* B_pack = t_data->B_pack;
+        int* comparray = t_data->comparray;
+
+        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
+        int64_t ytiles = m / mc;
+        int64_t xtiles = n / nc;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles) {
+            end = tiles;
+        }
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = (job / xtiles) * mc;
+            int64_t jj = (job % xtiles) * nc;
+            for (int64_t kk = 0; kk < k; kk += kc) {
+                if constexpr(is_Ablock_q4) {
+                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
+                } else {
+                    packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
+                }
+                packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
+                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
+            }
+        }
+    }
+
+
     void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
@@ -2159,6 +2404,7 @@ class tinyBLAS_Q0_PPC {
     const block_q8_0 *const B;
     float *C;
     const int64_t k;
+    int64_t kc;
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;
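For reference, matmul_tiled above keeps its packed A/B tiles and the comparray in pthread thread-local storage so each worker thread reuses its scratch buffers across calls, with thread_cleanup registered as the key destructor. Here is a standalone sketch of that pattern with hypothetical names (scratch_t, scratch_get), using pthread_once for the one-time key creation instead of the patch's key_created flag.

// Illustrative sketch of per-thread scratch buffers via pthread keys.
#include <pthread.h>
#include <cstddef>

struct scratch_t {
    float* buf;          // stand-in for the patch's A_pack/B_pack/comparray
    size_t len;
};

static pthread_key_t  scratch_key;
static pthread_once_t scratch_once = PTHREAD_ONCE_INIT;

static void scratch_destroy(void* p) {           // runs when a thread exits
    scratch_t* s = static_cast<scratch_t*>(p);
    if (s) { delete[] s->buf; delete s; }
}

static void scratch_make_key() {
    pthread_key_create(&scratch_key, scratch_destroy);
}

// Returns this thread's scratch buffer, allocating it on first use.
static scratch_t* scratch_get(size_t n) {
    pthread_once(&scratch_once, scratch_make_key);
    scratch_t* s = static_cast<scratch_t*>(pthread_getspecific(scratch_key));
    if (s == nullptr) {
        s = new scratch_t{ new float[n], n };
        pthread_setspecific(scratch_key, s);
    }
    return s;   // note: the size is fixed by the first call on each thread
}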
