
Commit 3beb65b

Implement a 16x8 kernel for Q4 GEMM
This does not give any perf benefit, nor does it cause any regression.

Signed-off-by: Shalini Salomi Bodapati <[email protected]>
1 parent ee3a9fc commit 3beb65b
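
For orientation: the new kernel covers a 16x8 output tile with eight 4x4 MMA accumulators (four row groups by two column groups). The short standalone C++ sketch below only prints that tiling; it is not part of the commit, and the acc-index mapping (row_group + 4*col_group) is an illustrative assumption rather than something quoted from the kernel.

// Standalone illustration only: which 4x4 sub-tile each accumulator covers
// in a 16x8 output tile (acc index = row_group + 4*col_group is assumed here).
#include <cstdio>

int main() {
    for (int col_group = 0; col_group < 2; ++col_group) {      // columns 0-3, then 4-7
        for (int row_group = 0; row_group < 4; ++row_group) {  // rows 0-3, 4-7, 8-11, 12-15
            int acc_idx = row_group + 4 * col_group;
            std::printf("acc[%d]: rows %2d-%2d, cols %d-%d\n",
                        acc_idx,
                        4 * row_group, 4 * row_group + 3,
                        4 * col_group, 4 * col_group + 3);
        }
    }
    return 0;
}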

File tree

1 file changed (+103, -16 lines)


ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 103 additions & 16 deletions
@@ -1883,8 +1883,12 @@ class tinyBLAS_Q0_PPC {
         int n_rem = MIN(n - n0, 16);
 
         int mc = 0, nc = 0;
+        if (m_rem >= 16 && n_rem >= 8) {
+            mc = 16;
+            nc = 8;
+            gemm<16, 8>(m0, m, n0, n);
 
-        if (m_rem >= 8 && n_rem >= 8) {
+        } else if (m_rem >= 8 && n_rem >= 8) {
             mc = 8;
             nc = 8;
             gemm<8, 8>(m0, m, n0, n);
@@ -2010,16 +2014,32 @@ class tinyBLAS_Q0_PPC {
         std::array<int, 8> comparray {};
         vector float fin_res[16] = {0};
         vector float vs[16] = {0};
-        bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
+        bool constexpr isAblock_q4 = std::is_same_v<TA, block_q4_0>;
+        __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch one loop ahead
+        __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch one loop ahead
         for (int l = 0; l < k; l++) {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
             __builtin_mma_xxsetaccz(&acc_2);
             __builtin_mma_xxsetaccz(&acc_3);
-            if (std::is_same_v<TA, block_q4_0>) {
+
+            __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+            __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+            if constexpr (isAblock_q4) {
                 packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
             } else {
                 packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < 8; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
+
             }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for (int x = 0; x < 8; x++) {
@@ -2034,18 +2054,6 @@ class tinyBLAS_Q0_PPC {
                     *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
                 }
             }
-            if (!isAblock_q4) {
-                auto aoffset = A+(ii*lda)+l;
-                for (int i = 0; i < 8; i++) {
-                    comparray[i] = 0;
-                    int ca = 0;
-                    auto *at = aoffset->qs;
-                    for (int j = 0; j < 32; j++)
-                        ca += (int)*at++;
-                    comparray[i] = ca;
-                    aoffset += lda;
-                }
-            }
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
             compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
@@ -2056,6 +2064,83 @@ class tinyBLAS_Q0_PPC {
         save_res(ii, jj+4, 8, fin_res);
         save_res(ii+4, jj+4, 12, fin_res);
     }
+    void KERNEL_16x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[32], vec_B[16] = {0};  // A: 16 rows x 2 vectors per block row
+        acc_t acc[8];                      // eight 4x4 MMA accumulators
+        std::array<int, 16> comparray {};  // per-row quant sums for the 16 A rows
+        vector float fin_res[32] = {0};    // accumulated float results
+        vector float vs[32] = {0};         // A scale * B scale products
+
+        constexpr bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
+
+        for (int l = 0; l < k; l++) {
+            // Zero all 8 accumulators
+            for (int x = 0; x < 8; x++)
+                __builtin_mma_xxsetaccz(&acc[x]);
+
+            __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+            __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+            // Pack A
+            if constexpr (isAblock_q4) {
+                packNormalInt4<16>((A + (ii*lda) + l), lda, 16, 4, (int8_t*)vec_A, comparray);
+            } else {
+                packNormal<int8_t, vector signed char>((const block_q8_0*)(A + (ii*lda) + l), lda, 16, 8, (int8_t*)vec_A, false);
+                auto aoffset = A + (ii*lda) + l;
+                for (int i = 0; i < 16; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    auto *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++) ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
+            }
+
+            // Pack B
+            packNormal<uint8_t, vector unsigned char>((B + (jj*ldb) + l), ldb, 8, 8, (uint8_t*)vec_B, true);
+
+            // MMA multiply: the 16x8 tile is covered by eight 4x4 accumulators,
+            // acc[row_group + 4*col_group], four row groups x two column groups
+            for (int x = 0; x < 8; x++) {
+                // columns 0-3
+                __builtin_mma_xvi8ger4pp(&acc[0], vec_A[x],    vec_B[x]);   // rows 0-3
+                __builtin_mma_xvi8ger4pp(&acc[1], vec_A[x+8],  vec_B[x]);   // rows 4-7
+                __builtin_mma_xvi8ger4pp(&acc[2], vec_A[x+16], vec_B[x]);   // rows 8-11
+                __builtin_mma_xvi8ger4pp(&acc[3], vec_A[x+24], vec_B[x]);   // rows 12-15
+                // columns 4-7
+                __builtin_mma_xvi8ger4pp(&acc[4], vec_A[x],    vec_B[x+8]); // rows 0-3
+                __builtin_mma_xvi8ger4pp(&acc[5], vec_A[x+8],  vec_B[x+8]); // rows 4-7
+                __builtin_mma_xvi8ger4pp(&acc[6], vec_A[x+16], vec_B[x+8]); // rows 8-11
+                __builtin_mma_xvi8ger4pp(&acc[7], vec_A[x+24], vec_B[x+8]); // rows 12-15
+            }
+
+            // Compute vs: A scale * B scale (cols 0-3 in vs[I], cols 4-7 in vs[I+16])
+            for (int I = 0; I < 16; I++) {
+                float a_scale = unhalf((A + ((ii+I)*lda) + l)->d);
+                for (int J = 0; J < 4; J++) {
+                    *((float*)&vs[I] + J)    = a_scale * unhalf((B + ((jj+J)*ldb) + l)->d);
+                    *((float*)&vs[I+16] + J) = a_scale * unhalf((B + ((jj+J+4)*ldb) + l)->d);
+                }
+            }
+
+            // Accumulate into fin_res: (comparray row offset, vs/fin_res offset)
+            compute<16>(&acc[0], 0, 0, comparray, vs, fin_res);
+            compute<16>(&acc[1], 4, 4, comparray, vs, fin_res);
+            compute<16>(&acc[2], 8, 8, comparray, vs, fin_res);
+            compute<16>(&acc[3], 12, 12, comparray, vs, fin_res);
+            compute<16>(&acc[4], 0, 16, comparray, vs, fin_res);
+            compute<16>(&acc[5], 4, 20, comparray, vs, fin_res);
+            compute<16>(&acc[6], 8, 24, comparray, vs, fin_res);
+            compute<16>(&acc[7], 12, 28, comparray, vs, fin_res);
+        }
+
+        // Save results: one 4x4 block per (row group, column group)
+        for (int c = 0; c < 8; c += 4) {
+            for (int r = 0; r < 16; r += 4) {
+                save_res(ii + r, jj + c, r + 4*c, fin_res);
+            }
+        }
+    }
 
     void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
@@ -2133,7 +2218,9 @@ class tinyBLAS_Q0_PPC {
             KERNEL_8x4(ii,jj);
         } else if constexpr(RM == 8 && RN == 8) {
             KERNEL_8x8(ii,jj);
-        } else {
+        } else if constexpr(RM == 16 && RN == 8) {
+            KERNEL_16x8(ii,jj);
+        } else {
             assert(false && "RN/RM values not supported");
         }
     }
