@@ -1883,8 +1883,12 @@ class tinyBLAS_Q0_PPC {
18831883 int n_rem = MIN (n - n0, 16 );
18841884
18851885 int mc = 0 , nc = 0 ;
1886+ if (m_rem >=16 && n_rem >= 8 ) {
1887+ mc = 16 ;
1888+ nc= 8 ;
1889+ gemm<16 ,8 >(m0, m,n0,n);
18861890
1887- if (m_rem >= 8 && n_rem >= 8 ) {
1891+ } else if (m_rem >= 8 && n_rem >= 8 ) {
18881892 mc = 8 ;
18891893 nc = 8 ;
18901894 gemm<8 , 8 >(m0, m, n0, n);
@@ -2010,16 +2014,32 @@ class tinyBLAS_Q0_PPC {
20102014 std::array<int , 8 > comparray {};
20112015 vector float fin_res[16 ] = {0 };
20122016 vector float vs[16 ] = {0 };
2013- bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2017+ bool constexpr isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2018+ __builtin_prefetch ((A+(ii*lda)+0 )->qs , 0 , 1 ); // prefetch one loop ahead
2019+ __builtin_prefetch ((B+(jj*ldb)+0 )->qs , 0 , 1 ); // prefetch one loop ahead
20142020 for (int l = 0 ; l < k; l++) {
20152021 __builtin_mma_xxsetaccz (&acc_0);
20162022 __builtin_mma_xxsetaccz (&acc_1);
20172023 __builtin_mma_xxsetaccz (&acc_2);
20182024 __builtin_mma_xxsetaccz (&acc_3);
2019- if (std::is_same_v<TA, block_q4_0>) {
2025+
2026+ __builtin_prefetch ((A+(ii*lda)+(l+1 ))->qs , 0 , 1 ); // prefetch one loop ahead
2027+ __builtin_prefetch ((B+(jj*ldb)+(l+1 ))->qs , 0 , 1 ); // prefetch one loop ahead
2028+ if constexpr (isAblock_q4) {
20202029 packNormalInt4<8 >((A+(ii*lda)+l), lda, 8 , 4 , (int8_t *)vec_A, comparray);
20212030 } else {
20222031 packNormal<int8_t , vector signed char >((const block_q8_0*)(A+(ii*lda)+l), lda, 8 , 8 , (int8_t *)vec_A, false );
2032+ auto aoffset = A+(ii*lda)+l;
2033+ for (int i = 0 ; i < 8 ; i++) {
2034+ comparray[i] = 0 ;
2035+ int ca = 0 ;
2036+ auto *at = aoffset->qs ;
2037+ for (int j = 0 ; j < 32 ; j++)
2038+ ca += (int )*at++;
2039+ comparray[i] = ca;
2040+ aoffset += lda;
2041+ }
2042+
20232043 }
20242044 packNormal<uint8_t , vector unsigned char >((B+(jj*ldb)+l), ldb, 8 , 8 , (uint8_t *)vec_B, true );
20252045 for (int x = 0 ; x < 8 ; x++) {
@@ -2034,18 +2054,6 @@ class tinyBLAS_Q0_PPC {
20342054 *((float *)&vs[I+8 ]+J) = (unhalf ((A+((ii+I)*lda)+l)->d ) * unhalf ((B+((jj+J+4 )*ldb)+l)->d ));
20352055 }
20362056 }
2037- if (!isAblock_q4) {
2038- auto aoffset = A+(ii*lda)+l;
2039- for (int i = 0 ; i < 8 ; i++) {
2040- comparray[i] = 0 ;
2041- int ca = 0 ;
2042- auto *at = aoffset->qs ;
2043- for (int j = 0 ; j < 32 ; j++)
2044- ca += (int )*at++;
2045- comparray[i] = ca;
2046- aoffset += lda;
2047- }
2048- }
20492057 compute<8 >(&acc_0, 0 , 0 , comparray, vs, fin_res);
20502058 compute<8 >(&acc_1, 4 , 4 , comparray, vs, fin_res);
20512059 compute<8 >(&acc_2, 0 , 8 , comparray, vs, fin_res);
@@ -2056,6 +2064,83 @@ class tinyBLAS_Q0_PPC {
20562064 save_res (ii, jj+4 , 8 , fin_res);
20572065 save_res (ii+4 , jj+4 , 12 , fin_res);
20582066 }
2067+ void KERNEL_16x8 (int64_t ii, int64_t jj) {
2068+ vec_t vec_A[32 ], vec_B[16 ] = {0 }; // 16 rows × 2 blocks for A
2069+ acc_t acc[8 ]; // 8 accumulators
2070+ std::array<int , 16 > comparray {}; // 16 rows
2071+ vector float fin_res[32 ] = {0 }; // final results
2072+ vector float vs[32 ] = {0 }; // scale * B
2073+
2074+ constexpr bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2075+
2076+ for (int l = 0 ; l < k; l++) {
2077+ // Zero all 8 accumulators
2078+ for (int x = 0 ; x < 8 ; x++)
2079+ __builtin_mma_xxsetaccz (&acc[x]);
2080+
2081+ __builtin_prefetch ((A+(ii*lda)+(l+1 ))->qs , 0 , 1 ); // prefetch one loop ahead
2082+ __builtin_prefetch ((B+(jj*ldb)+(l+1 ))->qs , 0 , 1 ); // prefetch one loop ahead
2083+ // Pack A
2084+ if constexpr (isAblock_q4) {
2085+ packNormalInt4<16 >((A + (ii*lda) + l), lda, 16 , 4 , (int8_t *)vec_A, comparray);
2086+ } else {
2087+ packNormal<int8_t , vector signed char >((const block_q8_0*)(A + (ii*lda) + l), lda, 16 , 8 , (int8_t *)vec_A, false );
2088+ auto aoffset = A + (ii*lda) + l;
2089+ for (int i = 0 ; i < 16 ; i++) {
2090+ comparray[i] = 0 ;
2091+ int ca = 0 ;
2092+ auto *at = aoffset->qs ;
2093+ for (int j = 0 ; j < 32 ; j++) ca += (int )*at++;
2094+ comparray[i] = ca;
2095+ aoffset += lda;
2096+ }
2097+ }
2098+
2099+ // Pack B
2100+ packNormal<uint8_t , vector unsigned char >((B + (jj*ldb) + l), ldb, 8 , 8 , (uint8_t *)vec_B, true );
2101+
2102+ // MMA multiply: 16 rows × 8 cols → 4×4 quadrants for each 8×4
2103+ for (int x = 0 ; x < 8 ; x++) {
2104+ // top-left / top-right quadrants
2105+ __builtin_mma_xvi8ger4pp (&acc[0 ], vec_A[x], vec_B[x]); // 0–7 rows, 0–3 cols
2106+ __builtin_mma_xvi8ger4pp (&acc[2 ], vec_A[x], vec_B[x+8 ]); // 0–7 rows, 4–7 cols
2107+ // bottom-left / bottom-right quadrants
2108+ __builtin_mma_xvi8ger4pp (&acc[1 ], vec_A[x+8 ], vec_B[x]); // 8–15 rows, 0–3 cols
2109+ __builtin_mma_xvi8ger4pp (&acc[3 ], vec_A[x+8 ], vec_B[x+8 ]); // 8–15 rows, 4–7 cols
2110+ // extra quadrants (double-buffer or prefetch)
2111+ __builtin_mma_xvi8ger4pp (&acc[4 ], vec_A[x], vec_B[x]); // same as top-left, can be reused
2112+ __builtin_mma_xvi8ger4pp (&acc[6 ], vec_A[x], vec_B[x+8 ]); // same as top-right
2113+ __builtin_mma_xvi8ger4pp (&acc[5 ], vec_A[x+8 ], vec_B[x]); // same as bottom-left
2114+ __builtin_mma_xvi8ger4pp (&acc[7 ], vec_A[x+8 ], vec_B[x+8 ]); // same as bottom-right
2115+ }
2116+
2117+ // Compute vs: scale * B
2118+ for (int I = 0 ; I < 16 ; I++) {
2119+ float a_scale = unhalf ((A + ((ii+I)*lda) + l)->d );
2120+ for (int J = 0 ; J < 4 ; J++) {
2121+ *((float *)&vs[I] + J) = a_scale * unhalf ((B + ((jj+J)*ldb) + l)->d );
2122+ *((float *)&vs[I+16 ] + J) = a_scale * unhalf ((B + ((jj+J+4 )*ldb) + l)->d );
2123+ }
2124+ }
2125+
2126+ // Compute accumulators
2127+ compute<16 >(&acc[0 ], 0 , 0 , comparray, vs, fin_res);
2128+ compute<16 >(&acc[2 ], 0 , 4 , comparray, vs, fin_res);
2129+ compute<16 >(&acc[1 ], 8 , 0 , comparray, vs, fin_res);
2130+ compute<16 >(&acc[3 ], 8 , 4 , comparray, vs, fin_res);
2131+ compute<16 >(&acc[4 ], 0 , 0 , comparray, vs, fin_res);
2132+ compute<16 >(&acc[6 ], 0 , 4 , comparray, vs, fin_res);
2133+ compute<16 >(&acc[5 ], 8 , 0 , comparray, vs, fin_res);
2134+ compute<16 >(&acc[7 ], 8 , 4 , comparray, vs, fin_res);
2135+ }
2136+
2137+ // Save results
2138+ for (int r = 0 ; r < 16 ; r += 8 ) {
2139+ for (int c = 0 ; c < 8 ; c += 4 ) {
2140+ save_res (ii + r, jj + c, /* offset*/ r + c, fin_res);
2141+ }
2142+ }
2143+ }
20592144
20602145 void gemm_small (int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
20612146 int64_t ytiles = (m - m0) / RM;
@@ -2133,7 +2218,9 @@ class tinyBLAS_Q0_PPC {
21332218 KERNEL_8x4 (ii,jj);
21342219 } else if constexpr (RM == 8 && RN == 8 ) {
21352220 KERNEL_8x8 (ii,jj);
2136- } else {
2221+ } else if constexpr (RM == 16 && RN == 8 ) {
2222+ KERNEL_16x8 (ii,jj);
2223+ }else {
21372224 assert (false && " RN/RM values not supported" );
21382225 }
21392226 }
0 commit comments