@@ -119,6 +119,30 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
119119#if defined(__MMA__)
120120typedef vector unsigned char vec_t ;
121121typedef __vector_quad acc_t ;
122+
123+ #include < pthread.h>
124+
125+ typedef vector unsigned char vec_t ;
126+ typedef __vector_quad acc_t ;
127+
128+ static pthread_key_t t_data_key;
129+ typedef struct {
130+ vec_t * A_pack;
131+ vec_t * B_pack;
132+ int * comparray;
133+ } thread_scratchpad_t ;
134+ void thread_cleanup (void * arg) {
135+ thread_scratchpad_t * data = (thread_scratchpad_t *)arg;
136+ if (data) {
137+ delete[] data->A_pack ;
138+ delete[] data->B_pack ;
139+ delete[] data->comparray ;
140+
141+ delete data;
142+ }
143+ }
144+ static bool key_created = false ;
145+
122146#endif
123147// //////////////////////////////////////////////////////////////////////////////////////////////////
124148// VECTORIZED FUSED MULTIPLY ADD
@@ -621,6 +645,7 @@ class tinyBLAS_Q0_ARM {
621645 const block_q8_0 *const B;
622646 float *const C;
623647 const int64_t k;
648+ int64_t kc;
624649 const int64_t lda;
625650 const int64_t ldb;
626651 const int64_t ldc;
@@ -1582,34 +1607,61 @@ class tinyBLAS_Q0_PPC {
15821607 float *C, int64_t ldc,
15831608 int ith, int nth)
15841609 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1610+ kc = 64 ;
15851611 }
15861612
15871613 void matmul (int64_t m, int64_t n) {
1588- mnpack (0 , m, 0 , n);
1614+ int mc = 64 ; int nc = 64 ;
1615+ const bool is_aligned = ((m & (mc - 1 )) == 0 ) & ((n & (nc - 1 )) == 0 ) & ((k & (kc - 1 )) == 0 );
1616+ if (is_aligned) {
1617+ matmul_tiled (m, n, mc, nc, kc);
1618+ } else {
1619+ mnpack (0 , m, 0 , n);
1620+ }
15891621 }
15901622
15911623 private:
15921624
15931625 inline void save_res (int ii, int jj, int idx, vector float * fin_res, int RM=4 , int RN=4 ) {
1594- for (int I = 0 ; I < RM; I++) {
1595- for (int J = 0 ; J < RN; J++) {
1596- *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&fin_res[idx+I]+J);
1597- }
1598- }
1626+ for (int I = 0 ; I < RM; I++) {
1627+ for (int J = 0 ; J < RN; J++) {
1628+ *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&fin_res[idx+I]+J);
1629+ }
1630+ }
15991631 }
16001632
1601- template <int size>
1602- inline void compute (acc_t * ACC, int c_idx, int s_idx, std::array<int , size>& comparray, vector float * vs, vector float * fin_res) {
1603- vector signed int vec_C[4 ];
1604- vector float CA[4 ] = {0 };
1605- vector float res[4 ] = {0 };
1606- __builtin_mma_disassemble_acc (vec_C, ACC);
1607- for (int i = 0 ; i < 4 ; i++) {
1608- CA[i] = vec_splats ((float )(((double )comparray[c_idx+i]) * -128.0 ));
1609- res[i] = vec_add (vec_ctf (vec_C[i], 0 ), CA[i]);
1610- fin_res[s_idx+i] = vec_madd (res[i], vs[s_idx+i], fin_res[s_idx+i]);
1611- }
1633+ inline void add_save_res (int ii, int jj, int idx, vector float * fin_res, int RM=4 , int RN=4 ) {
1634+ for (int I = 0 ; I < RM; I++) {
1635+ for (int J = 0 ; J < RN; J++) {
1636+ float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
1637+ *c_ptr += *((float *)&fin_res[idx+I]+J);
1638+ }
1639+ }
16121640 }
1641+
1642+ template <typename ArrayType>
1643+ inline void compute (acc_t * ACC, int c_idx, int s_idx, ArrayType& comparray, vector float * vs, vector float * fin_res) {
1644+ vector signed int vec_C[4 ];
1645+ vector float CA[4 ] = {0 };
1646+ vector float res[4 ] = {0 };
1647+ __builtin_mma_disassemble_acc (vec_C, ACC);
1648+ for (int i = 0 ; i < 4 ; i++) {
1649+ CA[i] = vec_splats ((float )(((double )comparray[c_idx+i]) * -128.0 ));
1650+ res[i] = vec_add (vec_ctf (vec_C[i], 0 ), CA[i]);
1651+ fin_res[s_idx+i] = vec_madd (res[i], vs[s_idx+i], fin_res[s_idx+i]);
1652+ }
1653+ }
1654+
1655+ void compute_scale (int64_t ii, int64_t jj, int blk, vector float * vs){
1656+ for (int I = 0 ; I<8 ; I++) {
1657+ float a_scale = unhalf ((A+((ii+I)*lda)+blk)->d );
1658+ for (int J = 0 ; J<4 ; J++) {
1659+ *((float *)&vs[I]+J) = (a_scale * unhalf ((B+((jj+J)*ldb)+blk)->d ));
1660+ *((float *)&vs[I+8 ]+J) = (a_scale * unhalf ((B+((jj+J+4 )*ldb)+blk)->d ));
1661+ }
1662+ }
1663+ }
1664+
16131665 /* This function processes quantized data from block_q4_0 elements.
16141666 * First we try to extract the two int4 values stored in a single int8_t into two signed int8 values.
16151667 * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8.
@@ -1630,6 +1682,18 @@ class tinyBLAS_Q0_PPC {
16301682 *(ca) = vsum[0 ] + vsum[1 ] + vsum[2 ] + vsum[3 ];
16311683 }
16321684
1685+ inline void process_q8_elements (const int8_t *qs, int *ca) {
1686+ vector signed char c1 = vec_xl (0 , qs);
1687+ vector signed char c2 = vec_xl (16 , qs);
1688+ vector signed int vsum1 = {0 };
1689+ vector signed int vsum2 = {0 };
1690+ vsum1 = vec_sum4s (c1, vsum1);
1691+ vsum2 = vec_sum4s (c2, vsum2);
1692+ vector signed int vsum = vec_add (vsum1, vsum2);
1693+ *ca = vsum[0 ] + vsum[1 ] + vsum[2 ] + vsum[3 ];
1694+ }
1695+
1696+
16331697 template <typename V1, typename V2>
16341698 inline void vector_permute_store (V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
16351699 vector unsigned char swiz1 = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 };
@@ -1877,6 +1941,98 @@ class tinyBLAS_Q0_PPC {
18771941 }
18781942 }
18791943 }
1944+ template <typename VA, typename VB>
1945+ void packNormal_large (const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int * comparray=nullptr ) {
1946+ int64_t i, j;
1947+ block_q8_0 *aoffset = NULL ;
1948+ VA *vecOffset = NULL ;
1949+ block_q8_0* aoffsets[8 ];
1950+ __vector_pair arr[8 ];
1951+ VB c[8 ][2 ] = {0 };
1952+ VB c1[8 ] = {0 }; VB c2[8 ] = {0 };
1953+ aoffset = const_cast <block_q8_0*>(a);
1954+ vecOffset = vec;
1955+ j = (rows >> 3 );
1956+ int index = 0 ;
1957+ if (j > 0 ) {
1958+ do {
1959+ for (int it = 0 ; it < 8 ; it++)
1960+ aoffsets[it] = aoffset + it*lda;
1961+ aoffset += 8 * lda;
1962+ for (int blk = 0 ; blk < kc; blk++) {
1963+ for (int it = 0 ; it < 8 ; it++) {
1964+ arr[it] = __builtin_vsx_lxvp (0 , (__vector_pair*)(aoffsets[it]+blk)->qs );
1965+ __builtin_vsx_disassemble_pair (c[it], &arr[it]);
1966+ c1[it] = c[it][0 ];
1967+ c2[it] = c[it][1 ];
1968+ if (comparray){
1969+ process_q8_elements ((aoffsets[it]+ blk)->qs , &comparray[index + 8 *blk + it]);
1970+ }
1971+ }
1972+ vector_permute_store<VA, VB>(c1[0 ], c1[1 ], c1[2 ], c1[3 ], vecOffset, flip);
1973+ vector_permute_store<VA, VB>(c2[0 ], c2[1 ], c2[2 ], c2[3 ], vecOffset+64 , flip);
1974+ vector_permute_store<VA, VB>(c1[4 ], c1[5 ], c1[6 ], c1[7 ], vecOffset+128 , flip);
1975+ vector_permute_store<VA, VB>(c2[4 ], c2[5 ], c2[6 ], c2[7 ], vecOffset+192 , flip);
1976+ vecOffset += 256 ;
1977+ }
1978+ j--;
1979+ index += 8 *kc;
1980+ } while (j > 0 );
1981+ }
1982+ }
1983+
1984+ void packNormalInt4_large (const TA* a, int64_t lda, int rows, int cols, int8_t * vec, int *comparray) {
1985+ int64_t i, j;
1986+ TA *aoffset = NULL ;
1987+ int8_t *vecOffset = NULL ;
1988+ TA *aoffset1 = NULL , *aoffset2 = NULL , *aoffset3 = NULL , *aoffset4 = NULL ;
1989+ TA *aoffset5 = NULL , *aoffset6 = NULL , *aoffset7 = NULL , *aoffset8 = NULL ;
1990+ vector signed char c1[2 ] = {0 }, c2[2 ] = {0 }, c3[2 ] = {0 }, c4[2 ] = {0 };
1991+ vector signed char c5[2 ] = {0 }, c6[2 ] = {0 }, c7[2 ] = {0 }, c8[2 ] = {0 };
1992+ aoffset = const_cast <TA*>(a);
1993+ vecOffset = vec;
1994+ int index = 0 ;
1995+ j = (rows >> 3 );
1996+ if (j > 0 ) {
1997+ do {
1998+ aoffset1 = aoffset;
1999+ aoffset2 = aoffset1 + lda;
2000+ aoffset3 = aoffset2 + lda;
2001+ aoffset4 = aoffset3 + lda;
2002+ aoffset5 = aoffset4 + lda;
2003+ aoffset6 = aoffset5 + lda;
2004+ aoffset7 = aoffset6 + lda;
2005+ aoffset8 = aoffset7 + lda;
2006+ aoffset += 8 * lda;
2007+ for (int blk = 0 ; blk < kc; blk++) {
2008+ c1[1 ] = reinterpret_cast <vector signed char >(vec_xl (0 , (aoffset1+blk)->qs ));
2009+ c2[1 ] = reinterpret_cast <vector signed char >(vec_xl (0 , (aoffset2+blk)->qs ));
2010+ c3[1 ] = reinterpret_cast <vector signed char >(vec_xl (0 , (aoffset3+blk)->qs ));
2011+ c4[1 ] = reinterpret_cast <vector signed char >(vec_xl (0 , (aoffset4+blk)->qs ));
2012+ c5[1 ] = reinterpret_cast <vector signed char >(vec_xl (0 , (aoffset5+blk)->qs ));
2013+ c6[1 ] = reinterpret_cast <vector signed char >(vec_xl (0 , (aoffset6+blk)->qs ));
2014+ c7[1 ] = reinterpret_cast <vector signed char >(vec_xl (0 , (aoffset7+blk)->qs ));
2015+ c8[1 ] = reinterpret_cast <vector signed char >(vec_xl (0 , (aoffset8+blk)->qs ));
2016+
2017+ process_q4_elements (c1, &comparray[index + 8 *blk+0 ]);
2018+ process_q4_elements (c2, &comparray[index + 8 *blk+1 ]);
2019+ process_q4_elements (c3, &comparray[index + 8 *blk+2 ]);
2020+ process_q4_elements (c4, &comparray[index + 8 *blk+3 ]);
2021+ process_q4_elements (c5, &comparray[index + 8 *blk+4 ]);
2022+ process_q4_elements (c6, &comparray[index + 8 *blk+5 ]);
2023+ process_q4_elements (c7, &comparray[index + 8 *blk+6 ]);
2024+ process_q4_elements (c8, &comparray[index + 8 *blk+7 ]);
2025+ vector_permute_store<int8_t , vector signed char >(c1[0 ], c2[0 ], c3[0 ], c4[0 ], vecOffset, false );
2026+ vector_permute_store<int8_t , vector signed char >(c1[1 ], c2[1 ], c3[1 ], c4[1 ], vecOffset+64 , false );
2027+ vector_permute_store<int8_t , vector signed char >(c5[0 ], c6[0 ], c7[0 ], c8[0 ], vecOffset+128 , false );
2028+ vector_permute_store<int8_t , vector signed char >(c5[1 ], c6[1 ], c7[1 ], c8[1 ], vecOffset+192 , false );
2029+ vecOffset += 256 ;
2030+ }
2031+ j--;
2032+ index += 8 *kc;
2033+ } while (j > 0 );
2034+ }
2035+ }
18802036
18812037 void mnpack (int64_t m0, int64_t m, int64_t n0, int64_t n) {
18822038 int m_rem = MIN (m - m0, 16 );
@@ -1953,8 +2109,8 @@ class tinyBLAS_Q0_PPC {
19532109 aoffset += lda;
19542110 }
19552111 }
1956- compute< 4 > (&acc_0, 0 , 0 , comparray, vs, fin_res);
1957- compute< 4 > (&acc_1, 0 , 4 , comparray, vs, fin_res);
2112+ compute (&acc_0, 0 , 0 , comparray, vs, fin_res);
2113+ compute (&acc_1, 0 , 4 , comparray, vs, fin_res);
19582114 }
19592115 save_res (ii, jj, 0 , fin_res);
19602116 save_res (ii, jj+4 , 4 , fin_res);
@@ -1997,8 +2153,8 @@ class tinyBLAS_Q0_PPC {
19972153 aoffset += lda;
19982154 }
19992155 }
2000- compute< 8 > (&acc_0, 0 , 0 , comparray, vs, fin_res);
2001- compute< 8 > (&acc_1, 4 , 4 , comparray, vs, fin_res);
2156+ compute (&acc_0, 0 , 0 , comparray, vs, fin_res);
2157+ compute (&acc_1, 4 , 4 , comparray, vs, fin_res);
20022158 }
20032159 save_res (ii, jj, 0 , fin_res);
20042160 save_res (ii+4 , jj, 4 , fin_res);
@@ -2046,17 +2202,106 @@ class tinyBLAS_Q0_PPC {
20462202 aoffset += lda;
20472203 }
20482204 }
2049- compute< 8 > (&acc_0, 0 , 0 , comparray, vs, fin_res);
2050- compute< 8 > (&acc_1, 4 , 4 , comparray, vs, fin_res);
2051- compute< 8 > (&acc_2, 0 , 8 , comparray, vs, fin_res);
2052- compute< 8 > (&acc_3, 4 , 12 , comparray, vs, fin_res);
2205+ compute (&acc_0, 0 , 0 , comparray, vs, fin_res);
2206+ compute (&acc_1, 4 , 4 , comparray, vs, fin_res);
2207+ compute (&acc_2, 0 , 8 , comparray, vs, fin_res);
2208+ compute (&acc_3, 4 , 12 , comparray, vs, fin_res);
20532209 }
20542210 save_res (ii, jj, 0 , fin_res);
20552211 save_res (ii+4 , jj, 4 , fin_res);
20562212 save_res (ii, jj+4 , 8 , fin_res);
20572213 save_res (ii+4 , jj+4 , 12 , fin_res);
20582214 }
20592215
2216+ void KERNEL_Q0 (int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
2217+ acc_t acc[4 ];
2218+ for (int i = 0 ; i < mc ; i += 8 ) {
2219+ for (int j = 0 ; j < nc; j += 8 ) {
2220+ vector float fin_res[16 ] = {0 };
2221+ vector float vs[16 ] = {0 };
2222+ for (int64_t kk = 0 ; kk < kc; kk++) {
2223+ for (int x = 0 ; x < 4 ; x++) {
2224+ __builtin_mma_xxsetaccz (&acc[x]);
2225+ }
2226+ int A_block_idx = (i/8 )*(16 *kc) + kk*16 ;
2227+ int B_block_idx = (j/8 )*(16 *kc)+ kk*16 ;
2228+ vec_t *A_block = &vec_A[A_block_idx];
2229+ vec_t *B_block = &vec_B[B_block_idx];
2230+ for (int x = 0 ; x < 8 ; x++) {
2231+ __builtin_mma_xvi8ger4pp (&acc[0 ], A_block[x], B_block[x]);
2232+ __builtin_mma_xvi8ger4pp (&acc[1 ], A_block[x + 8 ], B_block[x]);
2233+ __builtin_mma_xvi8ger4pp (&acc[2 ], A_block[x], B_block[x+8 ]);
2234+ __builtin_mma_xvi8ger4pp (&acc[3 ], A_block[x+8 ], B_block[x+8 ]);
2235+ }
2236+ compute_scale (ii+i, jj+j, l+kk, vs);
2237+ int c_index = (i/8 )*(8 *kc)+ kk*8 ;
2238+ int * c_block = &comparray[c_index];
2239+ compute (&acc[0 ], 0 , 0 , c_block, vs, fin_res);
2240+ compute (&acc[1 ], 4 , 4 , c_block, vs, fin_res);
2241+ compute (&acc[2 ], 0 , 8 , c_block, vs, fin_res);
2242+ compute (&acc[3 ], 4 , 12 , c_block, vs, fin_res);
2243+ }
2244+ if (l == 0 ) {
2245+ save_res (ii+i, jj+j, 0 , fin_res);
2246+ save_res (ii+i+4 , jj+j, 4 , fin_res);
2247+ save_res (ii+i, jj+j+4 , 8 , fin_res);
2248+ save_res (ii+i+4 , jj+j+4 , 12 , fin_res);
2249+ } else {
2250+ add_save_res (ii+i, jj+j, 0 , fin_res);
2251+ add_save_res (ii+i+4 , jj+j, 4 , fin_res);
2252+ add_save_res (ii+i, jj+j+4 , 8 , fin_res);
2253+ add_save_res (ii+i+4 , jj+j+4 , 12 , fin_res);
2254+ }
2255+ }
2256+ }
2257+ }
2258+
2259+ void matmul_tiled (int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
2260+ if (!key_created) {
2261+ if (pthread_key_create (&t_data_key, thread_cleanup) == 0 ) {
2262+ key_created = true ;
2263+ } else {
2264+ return ;
2265+ }
2266+ }
2267+ thread_scratchpad_t * t_data = (thread_scratchpad_t *)pthread_getspecific (t_data_key);
2268+ if (t_data == nullptr ) {
2269+ t_data = new thread_scratchpad_t ;
2270+ t_data->A_pack = new vec_t [mc * kc * 2 ];
2271+ t_data->B_pack = new vec_t [nc * kc * 2 ];
2272+ t_data->comparray = new int [mc * kc];
2273+ pthread_setspecific (t_data_key, t_data);
2274+ }
2275+ vec_t * A_pack = t_data->A_pack ;
2276+ vec_t * B_pack = t_data->B_pack ;
2277+ int * comparray = t_data->comparray ;
2278+
2279+ constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
2280+ int64_t ytiles = m / mc;
2281+ int64_t xtiles = n / nc;
2282+ int64_t tiles = xtiles * ytiles;
2283+ int64_t duty = (tiles + nth - 1 ) / nth;
2284+ int64_t start = duty * ith;
2285+ int64_t end = start + duty;
2286+ if (end > tiles) {
2287+ end = tiles;
2288+ }
2289+ for (int64_t job = start; job < end; ++job) {
2290+ int64_t ii = (job / xtiles) * mc;
2291+ int64_t jj = (job % xtiles) * nc;
2292+ for (int64_t kk = 0 ; kk < k; kk += kc) {
2293+ if constexpr (is_Ablock_q4) {
2294+ packNormalInt4_large (A + ii*lda + kk, lda, mc, 4 , (int8_t *)A_pack, comparray);
2295+ } else {
2296+ packNormal_large<int8_t , vector signed char >(A + ii*lda + kk, lda, mc, 8 , (int8_t *)A_pack, false , comparray);
2297+ }
2298+ packNormal_large<uint8_t , vector unsigned char >(B + jj*ldb + kk, ldb, nc, 8 , (uint8_t *)B_pack, true );
2299+ KERNEL_Q0 (ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
2300+ }
2301+ }
2302+ }
2303+
2304+
20602305 void gemm_small (int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
20612306 int64_t ytiles = (m - m0) / RM;
20622307 int64_t xtiles = (n - n0) / RN;
@@ -2159,6 +2404,7 @@ class tinyBLAS_Q0_PPC {
21592404 const block_q8_0 *const B;
21602405 float *C;
21612406 const int64_t k;
2407+ int64_t kc;
21622408 const int64_t lda;
21632409 const int64_t ldb;
21642410 const int64_t ldc;
0 commit comments