Skip to content

Commit c726070

Browse files
rootshalinib-ibm
authored andcommitted
Q4/Q8 Tiled Gemm Optimization.
This patch implemenrts tiled GEMM for large blocks where we pack blocks of 64x64 and perfrom matmul. 30 ~ 50 % improvement in llama-bench and llama-batched-bench with Meta-Llama3-8B Qunatized models( Q4_0 and Q8_0). Signed-off-by: root <[email protected]>
1 parent 061f0ef commit c726070

File tree

1 file changed

+273
-15
lines changed

1 file changed

+273
-15
lines changed

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 273 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,30 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
119119
#if defined(__MMA__)
120120
typedef vector unsigned char vec_t;
121121
typedef __vector_quad acc_t;
122+
123+
#include <pthread.h>
124+
125+
typedef vector unsigned char vec_t;
126+
typedef __vector_quad acc_t;
127+
128+
static pthread_key_t t_data_key;
129+
typedef struct {
130+
vec_t* A_pack;
131+
vec_t* B_pack;
132+
int* comparray;
133+
} thread_scratchpad_t;
134+
void thread_cleanup(void* arg) {
135+
thread_scratchpad_t* data = (thread_scratchpad_t*)arg;
136+
if (data) {
137+
delete[] data->A_pack;
138+
delete[] data->B_pack;
139+
delete[] data->comparray;
140+
141+
delete data;
142+
}
143+
}
144+
static bool key_created = false;
145+
122146
#endif
123147
////////////////////////////////////////////////////////////////////////////////////////////////////
124148
// VECTORIZED FUSED MULTIPLY ADD
@@ -621,6 +645,7 @@ class tinyBLAS_Q0_ARM {
621645
const block_q8_0 *const B;
622646
float *const C;
623647
const int64_t k;
648+
int64_t kc;
624649
const int64_t lda;
625650
const int64_t ldb;
626651
const int64_t ldc;
@@ -1582,34 +1607,73 @@ class tinyBLAS_Q0_PPC {
15821607
float *C, int64_t ldc,
15831608
int ith, int nth)
15841609
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1610+
kc = 64;
15851611
}
15861612

15871613
void matmul(int64_t m, int64_t n) {
1588-
mnpack(0, m, 0, n);
1614+
int mc = 64; int nc = 64;
1615+
const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
1616+
if (is_aligned) {
1617+
matmul_tiled(m, n, mc, nc, kc);
1618+
} else {
1619+
mnpack(0, m, 0, n);
1620+
}
15891621
}
15901622

15911623
private:
15921624

15931625
inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
1594-
for (int I = 0; I < RM; I++) {
1595-
for (int J = 0; J < RN; J++) {
1596-
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
1597-
}
1598-
}
1626+
for (int I = 0; I < RM; I++) {
1627+
for (int J = 0; J < RN; J++) {
1628+
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
1629+
}
1630+
}
1631+
}
1632+
1633+
inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
1634+
for (int I = 0; I < RM; I++) {
1635+
for (int J = 0; J < RN; J++) {
1636+
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
1637+
*c_ptr += *((float*)&fin_res[idx+I]+J);
1638+
}
1639+
}
15991640
}
16001641

16011642
template<int size>
16021643
inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
1603-
vector signed int vec_C[4];
1604-
vector float CA[4] = {0};
1605-
vector float res[4] = {0};
1606-
__builtin_mma_disassemble_acc(vec_C, ACC);
1607-
for (int i = 0; i < 4; i++) {
1608-
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1609-
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1610-
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1611-
}
1644+
vector signed int vec_C[4];
1645+
vector float CA[4] = {0};
1646+
vector float res[4] = {0};
1647+
__builtin_mma_disassemble_acc(vec_C, ACC);
1648+
for (int i = 0; i < 4; i++) {
1649+
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1650+
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1651+
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1652+
}
1653+
}
1654+
1655+
inline void compute_array(acc_t* ACC, int c_idx, int s_idx, int* comparray, vector float* vs, vector float* fin_res) {
1656+
vector signed int vec_C[4];
1657+
vector float CA[4] = {0};
1658+
vector float res[4] = {0};
1659+
__builtin_mma_disassemble_acc(vec_C, ACC);
1660+
for (int i = 0; i < 4; i++) {
1661+
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1662+
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1663+
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1664+
}
1665+
}
1666+
1667+
void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
1668+
for (int I = 0; I<8; I++) {
1669+
float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
1670+
for (int J = 0; J<4; J++) {
1671+
*((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
1672+
*((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
1673+
}
1674+
}
16121675
}
1676+
16131677
/* This function processes quantized data from block_q4_0 elements.
16141678
* First the we try to extract the two int4 values stored in single int8_t into two signed int8.
16151679
* And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8.
@@ -1630,6 +1694,18 @@ class tinyBLAS_Q0_PPC {
16301694
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
16311695
}
16321696

1697+
inline void process_q8_elements(const int8_t *qs, int *ca) {
1698+
vector signed char c1 = vec_xl(0, qs);
1699+
vector signed char c2 = vec_xl(16, qs);
1700+
vector signed int vsum1 = {0};
1701+
vector signed int vsum2 = {0};
1702+
vsum1 = vec_sum4s(c1, vsum1);
1703+
vsum2 = vec_sum4s(c2, vsum2);
1704+
vector signed int vsum = vec_add(vsum1, vsum2);
1705+
*ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1706+
}
1707+
1708+
16331709
template <typename V1, typename V2>
16341710
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
16351711
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
@@ -1877,6 +1953,98 @@ class tinyBLAS_Q0_PPC {
18771953
}
18781954
}
18791955
}
1956+
template<typename VA, typename VB>
1957+
void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
1958+
int64_t i, j;
1959+
block_q8_0 *aoffset = NULL;
1960+
VA *vecOffset = NULL;
1961+
block_q8_0* aoffsets[8];
1962+
__vector_pair arr[8];
1963+
VB c[8][2] = {0};
1964+
VB c1[8] = {0}; VB c2[8] = {0};
1965+
aoffset = const_cast<block_q8_0*>(a);
1966+
vecOffset = vec;
1967+
j = (rows >> 3);
1968+
int index = 0;
1969+
if (j > 0) {
1970+
do {
1971+
for (int it = 0; it < 8; it++)
1972+
aoffsets[it] = aoffset + it*lda;
1973+
aoffset += 8 * lda;
1974+
for (int blk = 0; blk < kc; blk++) {
1975+
for (int it = 0; it < 8; it++) {
1976+
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
1977+
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
1978+
c1[it] = c[it][0];
1979+
c2[it] = c[it][1];
1980+
if (comparray){
1981+
process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
1982+
}
1983+
}
1984+
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1985+
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1986+
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
1987+
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
1988+
vecOffset += 256;
1989+
}
1990+
j--;
1991+
index += 8*kc;
1992+
} while(j > 0);
1993+
}
1994+
}
1995+
1996+
void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
1997+
int64_t i, j;
1998+
TA *aoffset = NULL;
1999+
int8_t *vecOffset = NULL;
2000+
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
2001+
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
2002+
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
2003+
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
2004+
aoffset = const_cast<TA*>(a);
2005+
vecOffset = vec;
2006+
int index = 0;
2007+
j = (rows >> 3);
2008+
if (j > 0) {
2009+
do {
2010+
aoffset1 = aoffset;
2011+
aoffset2 = aoffset1 + lda;
2012+
aoffset3 = aoffset2 + lda;
2013+
aoffset4 = aoffset3 + lda;
2014+
aoffset5 = aoffset4 + lda;
2015+
aoffset6 = aoffset5 + lda;
2016+
aoffset7 = aoffset6 + lda;
2017+
aoffset8 = aoffset7 + lda;
2018+
aoffset += 8 * lda;
2019+
for (int blk = 0; blk < kc; blk++) {
2020+
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
2021+
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
2022+
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
2023+
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
2024+
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
2025+
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
2026+
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
2027+
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
2028+
2029+
process_q4_elements(c1, &comparray[index + 8*blk+0]);
2030+
process_q4_elements(c2, &comparray[index + 8*blk+1]);
2031+
process_q4_elements(c3, &comparray[index + 8*blk+2]);
2032+
process_q4_elements(c4, &comparray[index + 8*blk+3]);
2033+
process_q4_elements(c5, &comparray[index + 8*blk+4]);
2034+
process_q4_elements(c6, &comparray[index + 8*blk+5]);
2035+
process_q4_elements(c7, &comparray[index + 8*blk+6]);
2036+
process_q4_elements(c8, &comparray[index + 8*blk+7]);
2037+
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
2038+
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
2039+
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
2040+
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
2041+
vecOffset += 256;
2042+
}
2043+
j--;
2044+
index += 8*kc;
2045+
} while (j > 0);
2046+
}
2047+
}
18802048

18812049
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
18822050
int m_rem = MIN(m - m0, 16);
@@ -2057,6 +2225,95 @@ class tinyBLAS_Q0_PPC {
20572225
save_res(ii+4, jj+4, 12, fin_res);
20582226
}
20592227

2228+
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
2229+
acc_t acc[4];
2230+
for (int i = 0; i < mc ; i += 8) {
2231+
for (int j = 0; j < nc; j += 8) {
2232+
vector float fin_res[16] = {0};
2233+
vector float vs[16] = {0};
2234+
for (int64_t kk = 0; kk < kc; kk++) {
2235+
for (int x = 0; x < 4; x++) {
2236+
__builtin_mma_xxsetaccz(&acc[x]);
2237+
}
2238+
int A_block_idx = (i/8)*(16*kc) + kk*16;
2239+
int B_block_idx = (j/8)*(16*kc)+ kk*16;
2240+
vec_t *A_block = &vec_A[A_block_idx];
2241+
vec_t *B_block = &vec_B[B_block_idx];
2242+
for (int x = 0; x < 8; x++) {
2243+
__builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
2244+
__builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
2245+
__builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
2246+
__builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
2247+
}
2248+
compute_scale(ii+i, jj+j, l+kk, vs);
2249+
int c_index = (i/8)*(8*kc)+ kk*8;
2250+
int* c_block = &comparray[c_index];
2251+
compute_array(&acc[0], 0, 0, c_block, vs, fin_res);
2252+
compute_array(&acc[1], 4, 4, c_block, vs, fin_res);
2253+
compute_array(&acc[2], 0, 8, c_block, vs, fin_res);
2254+
compute_array(&acc[3], 4, 12, c_block, vs, fin_res);
2255+
}
2256+
if (l == 0) {
2257+
save_res(ii+i, jj+j, 0, fin_res);
2258+
save_res(ii+i+4, jj+j, 4, fin_res);
2259+
save_res(ii+i, jj+j+4, 8, fin_res);
2260+
save_res(ii+i+4, jj+j+4, 12, fin_res);
2261+
} else {
2262+
add_save_res(ii+i, jj+j, 0, fin_res);
2263+
add_save_res(ii+i+4, jj+j, 4, fin_res);
2264+
add_save_res(ii+i, jj+j+4, 8, fin_res);
2265+
add_save_res(ii+i+4, jj+j+4, 12, fin_res);
2266+
}
2267+
}
2268+
}
2269+
}
2270+
2271+
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
2272+
if (!key_created) {
2273+
if (pthread_key_create(&t_data_key, thread_cleanup) == 0) {
2274+
key_created = true;
2275+
} else {
2276+
return;
2277+
}
2278+
}
2279+
thread_scratchpad_t* t_data = (thread_scratchpad_t*)pthread_getspecific(t_data_key);
2280+
if (t_data == nullptr) {
2281+
t_data = new thread_scratchpad_t;
2282+
t_data->A_pack = new vec_t[mc * kc * 2];
2283+
t_data->B_pack = new vec_t[nc * kc * 2];
2284+
t_data->comparray = new int[mc * kc];
2285+
pthread_setspecific(t_data_key, t_data);
2286+
}
2287+
vec_t* A_pack = t_data->A_pack;
2288+
vec_t* B_pack = t_data->B_pack;
2289+
int* comparray = t_data->comparray;
2290+
2291+
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
2292+
int64_t ytiles = m / mc;
2293+
int64_t xtiles = n / nc;
2294+
int64_t tiles = xtiles * ytiles;
2295+
int64_t duty = (tiles + nth - 1) / nth;
2296+
int64_t start = duty * ith;
2297+
int64_t end = start + duty;
2298+
if (end > tiles) {
2299+
end = tiles;
2300+
}
2301+
for (int64_t job = start; job < end; ++job) {
2302+
int64_t ii = (job / xtiles) * mc;
2303+
int64_t jj = (job % xtiles) * nc;
2304+
for (int64_t kk = 0; kk < k; kk += kc) {
2305+
if constexpr(is_Ablock_q4) {
2306+
packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
2307+
} else {
2308+
packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
2309+
}
2310+
packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
2311+
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
2312+
}
2313+
}
2314+
}
2315+
2316+
20602317
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
20612318
int64_t ytiles = (m - m0) / RM;
20622319
int64_t xtiles = (n - n0) / RN;
@@ -2159,6 +2416,7 @@ class tinyBLAS_Q0_PPC {
21592416
const block_q8_0 *const B;
21602417
float *C;
21612418
const int64_t k;
2419+
int64_t kc;
21622420
const int64_t lda;
21632421
const int64_t ldb;
21642422
const int64_t ldc;

0 commit comments

Comments
 (0)