@@ -1598,8 +1598,8 @@ class tinyBLAS_Q0_PPC {
15981598 }
15991599 }
16001600
1601- template <int size>
1602- inline void compute (acc_t * ACC, int c_idx, int s_idx, std::array< int , size>& comparray, vector float * vs, vector float * fin_res) {
1601+ // template<int size>
1602+ inline void compute (acc_t * ACC, int c_idx, int s_idx, int * comparray, vector float * vs, vector float * fin_res) {
16031603 vector signed int vec_C[4 ];
16041604 vector float CA[4 ] = {0 };
16051605 vector float res[4 ] = {0 };
@@ -1660,8 +1660,9 @@ class tinyBLAS_Q0_PPC {
16601660 vec_xst (t8, 0 , vecOffset+48 );
16611661 }
16621662
1663- template <int size>
1664- void packNormalInt4 (const TA* a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int , size>& comparray) {
1663+ // template<int size>
1664+ // void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
1665+ void packNormalInt4 (const TA* a, int64_t lda, int rows, int cols, int8_t * vec, int * comparray) {
16651666 int64_t i, j;
16661667 TA *aoffset = NULL ;
16671668 int8_t *vecOffset = NULL ;
@@ -1916,7 +1917,7 @@ class tinyBLAS_Q0_PPC {
19161917
19171918
19181919 void KERNEL_4x8 (int64_t ii, int64_t jj) {
1919- vec_t vec_A[8 ], vec_B[16 ] = {0 };
1920+ /* vec_t vec_A[8], vec_B[16] = {0};
19201921 acc_t acc_0, acc_1;
19211922 std::array<int, 4> comparray {};
19221923 vector float fin_res[8] = {0};
@@ -1957,11 +1958,11 @@ class tinyBLAS_Q0_PPC {
19571958 compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
19581959 }
19591960 save_res(ii, jj, 0, fin_res);
1960- save_res (ii, jj+4 , 4 , fin_res);
1961+ save_res(ii, jj+4, 4, fin_res);*/
19611962 }
19621963
19631964 void KERNEL_8x4 (int64_t ii, int64_t jj) {
1964- vec_t vec_A[16 ], vec_B[8 ] = {0 };
1965+ /* vec_t vec_A[16], vec_B[8] = {0};
19651966 acc_t acc_0, acc_1;
19661967 std::array<int, 8> comparray {};
19671968 vector float fin_res[8] = {0};
@@ -2001,55 +2002,61 @@ class tinyBLAS_Q0_PPC {
20012002 compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
20022003 }
20032004 save_res(ii, jj, 0, fin_res);
2004- save_res (ii+4 , jj, 4 , fin_res);
2005+ save_res(ii+4, jj, 4, fin_res);*/
20052006 }
20062007
20072008 void KERNEL_8x8 (int64_t ii, int64_t jj) {
2008- vec_t vec_A[16 ], vec_B[16 ] = {0 };
2009+ vec_t vec_A[16 *k ], vec_B[16 *k ] = {0 };
20092010 acc_t acc_0, acc_1, acc_2, acc_3;
2010- std::array< int , 8 > comparray {} ;
2011+ int comparray [ 8 *k] ;
20112012 vector float fin_res[16 ] = {0 };
2012- vector float vs[16 ] = {0 };
2013+ vector float vs[16 *k ] = {0 };
20132014 bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2015+ for (int l = 0 ; l< k; l++) {
2016+ // prepack A
2017+ if (isAblock_q4) {
2018+ packNormalInt4 ((A+(ii*lda)+l), lda, 8 , 4 , (int8_t *)(vec_A + 16 *l), comparray + 8 *l);
2019+ } else {
2020+ packNormal<int8_t , vector signed char >((const block_q8_0*)(A+(ii*lda)+l), lda, 8 , 8 , (int8_t *)(vec_A + 16 *l), false );
2021+ auto aoffset = A+(ii*lda)+l;
2022+ for (int i = 0 ; i < 8 ; i++) {
2023+ comparray[16 *l + i] = 0 ;
2024+ int ca = 0 ;
2025+ auto *at = aoffset->qs ;
2026+ for (int j = 0 ; j < 32 ; j++)
2027+ ca += (int )*at++;
2028+ comparray[16 *l + i] = ca;
2029+ aoffset += lda;
2030+ }
2031+ }
2032+ }
2033+ for (int l = 0 ; l < k; l++) {
2034+ // prepack B
2035+ packNormal<uint8_t , vector unsigned char >((B+(jj*ldb)+l), ldb, 8 , 8 , (uint8_t *)(vec_B + 16 *l), true );
2036+
2037+ }
20142038 for (int l = 0 ; l < k; l++) {
20152039 __builtin_mma_xxsetaccz (&acc_0);
20162040 __builtin_mma_xxsetaccz (&acc_1);
20172041 __builtin_mma_xxsetaccz (&acc_2);
20182042 __builtin_mma_xxsetaccz (&acc_3);
2019- if (std::is_same_v<TA, block_q4_0>) {
2020- packNormalInt4<8 >((A+(ii*lda)+l), lda, 8 , 4 , (int8_t *)vec_A, comparray);
2021- } else {
2022- packNormal<int8_t , vector signed char >((const block_q8_0*)(A+(ii*lda)+l), lda, 8 , 8 , (int8_t *)vec_A, false );
2023- }
2024- packNormal<uint8_t , vector unsigned char >((B+(jj*ldb)+l), ldb, 8 , 8 , (uint8_t *)vec_B, true );
20252043 for (int x = 0 ; x < 8 ; x++) {
2026- __builtin_mma_xvi8ger4pp (&acc_0, vec_A[x], vec_B[x]);
2027- __builtin_mma_xvi8ger4pp (&acc_1, vec_A[x+8 ], vec_B[x]);
2028- __builtin_mma_xvi8ger4pp (&acc_2, vec_A[x], vec_B[x+8 ]);
2029- __builtin_mma_xvi8ger4pp (&acc_3, vec_A[x+8 ], vec_B[x+8 ]);
2044+ __builtin_mma_xvi8ger4pp (&acc_0, vec_A[16 *l + x], vec_B[16 *l + x]);
2045+ __builtin_mma_xvi8ger4pp (&acc_1, vec_A[16 *l + x+8 ], vec_B[16 *l + x]);
2046+ __builtin_mma_xvi8ger4pp (&acc_2, vec_A[16 *l + x], vec_B[16 *l + x+8 ]);
2047+ __builtin_mma_xvi8ger4pp (&acc_3, vec_A[16 *l + x+8 ], vec_B[16 *l + x+8 ]);
20302048 }
20312049 for (int I = 0 ; I<8 ; I++) {
2050+ // float a_scale = unhalf((A+((ii+I)*lda)+l)->d);// * unhalf((B+((jj+J)*ldb)+l)->d));
20322051 for (int J = 0 ; J<4 ; J++) {
20332052 *((float *)&vs[I]+J) = (unhalf ((A+((ii+I)*lda)+l)->d ) * unhalf ((B+((jj+J)*ldb)+l)->d ));
20342053 *((float *)&vs[I+8 ]+J) = (unhalf ((A+((ii+I)*lda)+l)->d ) * unhalf ((B+((jj+J+4 )*ldb)+l)->d ));
20352054 }
20362055 }
2037- if (!isAblock_q4) {
2038- auto aoffset = A+(ii*lda)+l;
2039- for (int i = 0 ; i < 8 ; i++) {
2040- comparray[i] = 0 ;
2041- int ca = 0 ;
2042- auto *at = aoffset->qs ;
2043- for (int j = 0 ; j < 32 ; j++)
2044- ca += (int )*at++;
2045- comparray[i] = ca;
2046- aoffset += lda;
2047- }
2048- }
2049- compute<8 >(&acc_0, 0 , 0 , comparray, vs, fin_res);
2050- compute<8 >(&acc_1, 4 , 4 , comparray, vs, fin_res);
2051- compute<8 >(&acc_2, 0 , 8 , comparray, vs, fin_res);
2052- compute<8 >(&acc_3, 4 , 12 , comparray, vs, fin_res);
2056+ compute (&acc_0, 0 , 0 , comparray+ 8 *l, vs, fin_res);
2057+ compute (&acc_1, 4 , 4 , comparray+ 8 *l, vs, fin_res);
2058+ compute (&acc_2, 0 , 8 , comparray+ 8 *l, vs, fin_res);
2059+ compute (&acc_3, 4 , 12 , comparray+ 8 *l, vs, fin_res);
20532060 }
20542061 save_res (ii, jj, 0 , fin_res);
20552062 save_res (ii+4 , jj, 4 , fin_res);
@@ -2074,7 +2081,8 @@ class tinyBLAS_Q0_PPC {
20742081 for (int64_t job = start; job < end; ++job) {
20752082 int64_t ii = m0 + job / xtiles * RM;
20762083 int64_t jj = n0 + job % xtiles * RN;
2077- std::array<int , 4 > comparray{};
2084+ // std::array<int, 4> comparray{};
2085+ int comparray[4 ];// {};
20782086 vector float res[4 ] = {0 };
20792087 vector float fin_res[4 ] = {0 };
20802088 vector float vs[4 ] = {0 };
@@ -2086,7 +2094,8 @@ class tinyBLAS_Q0_PPC {
20862094 __builtin_prefetch ((B+(jj*ldb)+(l+1 ))->qs , 0 , 1 ); // prefetch one loop ahead
20872095 __builtin_mma_xxsetaccz (&acc_0);
20882096 if (isAblock_q4) {
2089- packNormalInt4<4 >((A+(ii*lda)+l), lda, RM, 4 , (int8_t *)vec_A, comparray);
2097+ // packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2098+ packNormalInt4 ((A+(ii*lda)+l), lda, RM, 4 , (int8_t *)vec_A, comparray);
20902099 } else {
20912100 packNormal<int8_t , vector signed char >((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8 , (int8_t *)vec_A, false );
20922101 }
0 commit comments