@@ -119,6 +119,30 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
 #if defined(__MMA__)
 typedef vector unsigned char vec_t;
 typedef __vector_quad acc_t;
+
+#include <pthread.h>
+
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
+
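+// Per-thread scratch buffers for the tiled path: each worker thread lazily
+// allocates its own packed-A/packed-B buffers plus a row-sum array, and
+// thread_cleanup() releases them through the pthread key destructor when the
+// thread exits, so the buffers can be reused across matmul calls.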
+static pthread_key_t t_data_key;
+typedef struct {
+    vec_t * A_pack;
+    vec_t * B_pack;
+    int   * comparray;
+} thread_scratchpad_t;
+void thread_cleanup(void * arg) {
+    thread_scratchpad_t * data = (thread_scratchpad_t *)arg;
+    if (data) {
+        delete[] data->A_pack;
+        delete[] data->B_pack;
+        delete[] data->comparray;
+
+        delete data;
+    }
+}
+static bool key_created = false;
+
 #endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD
@@ -621,6 +645,7 @@ class tinyBLAS_Q0_ARM {
     const block_q8_0 *const B;
     float *const C;
     const int64_t k;
+    int64_t kc;
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;
@@ -1582,34 +1607,73 @@ class tinyBLAS_Q0_PPC {
                     float *C, int64_t ldc,
                     int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+        kc = 64;
     }
 
     void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
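+        // Tiled fast path: only taken when m, n and k are exact multiples of
+        // the 64-element tile sizes (mc, nc, kc); anything else falls back to
+        // the existing mnpack() path.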
+        int mc = 64; int nc = 64;
+        const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
+        if (is_aligned) {
+            matmul_tiled(m, n, mc, nc, kc);
+        } else {
+            mnpack(0, m, 0, n);
+        }
     }
 
   private:
 
     inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
-       for (int I = 0; I < RM; I++) {
-          for (int J = 0; J < RN; J++) {
-             *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
-          }
-       }
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
+            }
+        }
+    }
+
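+    // Same as save_res(), but accumulates into C instead of overwriting it.
+    // The tiled kernel uses this for every k-slice after the first, so partial
+    // results from earlier slices are summed rather than discarded.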
+    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                float* c_ptr = (float*)(C+ii+((jj+J)*ldc)+I);
+                *c_ptr += *((float*)&fin_res[idx+I]+J);
+            }
+        }
     }
 
     template<int size>
     inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
-       vector signed int vec_C[4];
-       vector float CA[4] = {0};
-       vector float res[4] = {0};
-       __builtin_mma_disassemble_acc(vec_C, ACC);
-       for (int i = 0; i < 4; i++) {
-          CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
-          res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
-          fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
-       }
+        vector signed int vec_C[4];
+        vector float CA[4] = {0};
+        vector float res[4] = {0};
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int i = 0; i < 4; i++) {
+            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+        }
+    }
+
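+    // Variant of compute() that takes the row sums as a plain int array (as
+    // filled by the *_large packing routines) instead of a std::array.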
+    inline void compute_array(acc_t* ACC, int c_idx, int s_idx, int* comparray, vector float* vs, vector float* fin_res) {
+        vector signed int vec_C[4];
+        vector float CA[4] = {0};
+        vector float res[4] = {0};
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int i = 0; i < 4; i++) {
+            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+        }
+    }
+
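+    // Builds the 16 scale vectors for an 8x8 tile at quant block `blk`:
+    // vs[I] holds a_scale(row ii+I) * b_scale(columns jj..jj+3), and
+    // vs[I+8] the same row against columns jj+4..jj+7.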
+    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs) {
+        for (int I = 0; I < 8; I++) {
+            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
+            for (int J = 0; J < 4; J++) {
+                *((float*)&vs[I]+J)   = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
+                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
+            }
+        }
     }
+
     /* This function processes quantized data from block_q4_0 elements.
      * First we extract the two int4 values packed in each int8_t into two separate int8 lanes.
      * Then we subtract 8 from each resultant element to convert the unsigned 4-bit values to signed int8.
@@ -1630,6 +1694,18 @@ class tinyBLAS_Q0_PPC {
         *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
     }
 
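+    // q8_0 counterpart of process_q4_elements(): sums the 32 int8 quants of one
+    // block. The per-row sums feed the -128.0 correction term applied in
+    // compute()/compute_array().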
+    inline void process_q8_elements(const int8_t* qs, int* ca) {
+        vector signed char c1 = vec_xl(0, qs);
+        vector signed char c2 = vec_xl(16, qs);
+        vector signed int vsum1 = {0};
+        vector signed int vsum2 = {0};
+        vsum1 = vec_sum4s(c1, vsum1);
+        vsum2 = vec_sum4s(c2, vsum2);
+        vector signed int vsum = vec_add(vsum1, vsum2);
+        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
     template<typename V1, typename V2>
     inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
         vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
@@ -1877,6 +1953,98 @@ class tinyBLAS_Q0_PPC {
             }
         }
     }
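+    // Packs an (8 rows x kc quant blocks) panel of q8_0 data into the interleaved
+    // layout consumed by the MMA kernel: four vector_permute_store() calls, i.e.
+    // 256 bytes, per quant block of 8 rows. When `comparray` is given, the
+    // per-row block sums are recorded as well.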
+    template<typename VA, typename VB>
+    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray = nullptr) {
+        int64_t i, j;
+        block_q8_0 *aoffset = NULL;
+        VA *vecOffset = NULL;
+        block_q8_0* aoffsets[8];
+        __vector_pair arr[8];
+        VB c[8][2] = {0};
+        VB c1[8] = {0}; VB c2[8] = {0};
+        aoffset = const_cast<block_q8_0*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        int index = 0;
+        if (j > 0) {
+            do {
+                for (int it = 0; it < 8; it++)
+                    aoffsets[it] = aoffset + it*lda;
+                aoffset += 8 * lda;
+                for (int blk = 0; blk < kc; blk++) {
+                    for (int it = 0; it < 8; it++) {
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
+                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        c1[it] = c[it][0];
+                        c2[it] = c[it][1];
+                        if (comparray) {
+                            process_q8_elements((aoffsets[it]+blk)->qs, &comparray[index + 8*blk + it]);
+                        }
+                    }
+                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
+                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+                    vecOffset += 256;
+                }
+                j--;
+                index += 8*kc;
+            } while (j > 0);
+        }
+    }
+
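+    // q4_0 variant of packNormal_large(): loads the packed nibbles of 8 rows,
+    // expands them to signed int8 via process_q4_elements() (which also records
+    // the row sums), and stores the same 256-byte-per-block interleaved layout.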
+    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int* comparray) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        int8_t *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        int index = 0;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+                aoffset1 = aoffset;
+                aoffset2 = aoffset1 + lda;
+                aoffset3 = aoffset2 + lda;
+                aoffset4 = aoffset3 + lda;
+                aoffset5 = aoffset4 + lda;
+                aoffset6 = aoffset5 + lda;
+                aoffset7 = aoffset6 + lda;
+                aoffset8 = aoffset7 + lda;
+                aoffset += 8 * lda;
+                for (int blk = 0; blk < kc; blk++) {
+                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
+                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
+                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
+                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
+                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
+                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
+                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
+                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
+
+                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
+                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
+                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
+                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
+                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
+                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
+                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
+                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
+                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                    vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
+                    vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
+                    vecOffset += 256;
+                }
+                j--;
+                index += 8*kc;
+            } while (j > 0);
+        }
+    }
 
     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int m_rem = MIN(m - m0, 16);
@@ -2057,6 +2225,95 @@ class tinyBLAS_Q0_PPC {
         save_res(ii+4, jj+4, 12, fin_res);
     }
 
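+    // Computes one mc x nc tile of C for a single kc-deep slice of the packed
+    // A/B panels. The tile is walked in 8x8 sub-tiles using four MMA accumulators;
+    // results are stored with save_res() for the first k-slice (l == 0) and
+    // accumulated with add_save_res() for all later slices.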
+    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
+        acc_t acc[4];
+        for (int i = 0; i < mc; i += 8) {
+            for (int j = 0; j < nc; j += 8) {
+                vector float fin_res[16] = {0};
+                vector float vs[16] = {0};
+                for (int64_t kk = 0; kk < kc; kk++) {
+                    for (int x = 0; x < 4; x++) {
+                        __builtin_mma_xxsetaccz(&acc[x]);
+                    }
+                    int A_block_idx = (i/8)*(16*kc) + kk*16;
+                    int B_block_idx = (j/8)*(16*kc) + kk*16;
+                    vec_t *A_block = &vec_A[A_block_idx];
+                    vec_t *B_block = &vec_B[B_block_idx];
+                    for (int x = 0; x < 8; x++) {
+                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x],   B_block[x]);
+                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x+8], B_block[x]);
+                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x],   B_block[x+8]);
+                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
+                    }
+                    compute_scale(ii+i, jj+j, l+kk, vs);
+                    int c_index = (i/8)*(8*kc) + kk*8;
+                    int *c_block = &comparray[c_index];
+                    compute_array(&acc[0], 0, 0,  c_block, vs, fin_res);
+                    compute_array(&acc[1], 4, 4,  c_block, vs, fin_res);
+                    compute_array(&acc[2], 0, 8,  c_block, vs, fin_res);
+                    compute_array(&acc[3], 4, 12, c_block, vs, fin_res);
+                }
+                if (l == 0) {
+                    save_res(ii+i,   jj+j,   0,  fin_res);
+                    save_res(ii+i+4, jj+j,   4,  fin_res);
+                    save_res(ii+i,   jj+j+4, 8,  fin_res);
+                    save_res(ii+i+4, jj+j+4, 12, fin_res);
+                } else {
+                    add_save_res(ii+i,   jj+j,   0,  fin_res);
+                    add_save_res(ii+i+4, jj+j,   4,  fin_res);
+                    add_save_res(ii+i,   jj+j+4, 8,  fin_res);
+                    add_save_res(ii+i+4, jj+j+4, 12, fin_res);
+                }
+            }
+        }
+    }
+
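+    // Cache-blocked driver for the aligned case: mc x nc tiles of C are divided
+    // among the threads, and for every kc-slice the A and B panels are packed into
+    // the per-thread scratchpad before calling KERNEL_Q0(). With mc = nc = kc = 64
+    // each packed panel holds 2 * 64 * 64 vec_t, i.e. 8192 * 16 bytes = 128 KiB per thread.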
+    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+        if (!key_created) {
+            if (pthread_key_create(&t_data_key, thread_cleanup) == 0) {
+                key_created = true;
+            } else {
+                return;
+            }
+        }
+        thread_scratchpad_t* t_data = (thread_scratchpad_t*)pthread_getspecific(t_data_key);
+        if (t_data == nullptr) {
+            t_data = new thread_scratchpad_t;
+            t_data->A_pack    = new vec_t[mc * kc * 2];
+            t_data->B_pack    = new vec_t[nc * kc * 2];
+            t_data->comparray = new int[mc * kc];
+            pthread_setspecific(t_data_key, t_data);
+        }
+        vec_t* A_pack = t_data->A_pack;
+        vec_t* B_pack = t_data->B_pack;
+        int*   comparray = t_data->comparray;
+
+        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
+        int64_t ytiles = m / mc;
+        int64_t xtiles = n / nc;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles) {
+            end = tiles;
+        }
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = (job / xtiles) * mc;
+            int64_t jj = (job % xtiles) * nc;
+            for (int64_t kk = 0; kk < k; kk += kc) {
+                if constexpr (is_Ablock_q4) {
+                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
+                } else {
+                    packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
+                }
+                packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
+                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
+            }
+        }
+    }
+
     void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
@@ -2159,6 +2416,7 @@ class tinyBLAS_Q0_PPC {
     const block_q8_0 *const B;
     float *C;
     const int64_t k;
+    int64_t kc;
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;