@@ -2009,17 +2009,37 @@ class tinyBLAS_Q0_PPC {
20092009 acc_t acc_0, acc_1, acc_2, acc_3;
20102010 std::array<int , 8 > comparray {};
20112011 vector float fin_res[16 ] = {0 };
2012- vector float vs[16 ] = {0 };
2012+ vector float vs[16 * k] = {0 };
2013+ // scale factor computation
2014+ for (int l = 0 ; l < k; l++) {
2015+ for (int I = 0 ; I<8 ; I++) {
2016+ float a_scale = unhalf ((A+((ii+I)*lda)+l)->d );;
2017+ for (int J = 0 ; J<4 ; J++) {
2018+ *((float *)&vs[(16 *l)+ I]+J) = (a_scale * unhalf ((B+((jj+J)*ldb)+l)->d ));
2019+ *((float *)&vs[(16 *l) + I+8 ]+J) = (a_scale * unhalf ((B+((jj+J+4 )*ldb)+l)->d ));
2020+ }
2021+ }
2022+ }
20132023 bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
20142024 for (int l = 0 ; l < k; l++) {
20152025 __builtin_mma_xxsetaccz (&acc_0);
20162026 __builtin_mma_xxsetaccz (&acc_1);
20172027 __builtin_mma_xxsetaccz (&acc_2);
20182028 __builtin_mma_xxsetaccz (&acc_3);
2019- if (std::is_same_v<TA, block_q4_0> ) {
2029+ if (isAblock_q4 ) {
20202030 packNormalInt4<8 >((A+(ii*lda)+l), lda, 8 , 4 , (int8_t *)vec_A, comparray);
20212031 } else {
20222032 packNormal<int8_t , vector signed char >((const block_q8_0*)(A+(ii*lda)+l), lda, 8 , 8 , (int8_t *)vec_A, false );
2033+ auto aoffset = A+(ii*lda)+l;
2034+ for (int i = 0 ; i < 8 ; i++) {
2035+ comparray[i] = 0 ;
2036+ int ca = 0 ;
2037+ auto *at = aoffset->qs ;
2038+ for (int j = 0 ; j < 32 ; j++)
2039+ ca += (int )*at++;
2040+ comparray[i] = ca;
2041+ aoffset += lda;
2042+ }
20232043 }
20242044 packNormal<uint8_t , vector unsigned char >((B+(jj*ldb)+l), ldb, 8 , 8 , (uint8_t *)vec_B, true );
20252045 for (int x = 0 ; x < 8 ; x++) {
@@ -2028,28 +2048,17 @@ class tinyBLAS_Q0_PPC {
20282048 __builtin_mma_xvi8ger4pp (&acc_2, vec_A[x], vec_B[x+8 ]);
20292049 __builtin_mma_xvi8ger4pp (&acc_3, vec_A[x+8 ], vec_B[x+8 ]);
20302050 }
2031- for (int I = 0 ; I<8 ; I++) {
2051+ /* for (int I = 0; I<8; I++) {
2052+ //float a_scale = unhalf((A+((ii+I)*lda)+l)->d);// * unhalf((B+((jj+J)*ldb)+l)->d));
20322053 for (int J = 0; J<4; J++) {
20332054 *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
20342055 *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
20352056 }
2036- }
2037- if (!isAblock_q4) {
2038- auto aoffset = A+(ii*lda)+l;
2039- for (int i = 0 ; i < 8 ; i++) {
2040- comparray[i] = 0 ;
2041- int ca = 0 ;
2042- auto *at = aoffset->qs ;
2043- for (int j = 0 ; j < 32 ; j++)
2044- ca += (int )*at++;
2045- comparray[i] = ca;
2046- aoffset += lda;
2047- }
2048- }
2049- compute<8 >(&acc_0, 0 , 0 , comparray, vs, fin_res);
2050- compute<8 >(&acc_1, 4 , 4 , comparray, vs, fin_res);
2051- compute<8 >(&acc_2, 0 , 8 , comparray, vs, fin_res);
2052- compute<8 >(&acc_3, 4 , 12 , comparray, vs, fin_res);
2057+ }*/
2058+ compute<8 >(&acc_0, 0 , 0 , comparray, vs + 16 *l, fin_res);
2059+ compute<8 >(&acc_1, 4 , 4 , comparray, vs + 16 *l, fin_res);
2060+ compute<8 >(&acc_2, 0 , 8 , comparray, vs + 16 *l, fin_res);
2061+ compute<8 >(&acc_3, 4 , 12 , comparray, vs+ 16 *l, fin_res);
20532062 }
20542063 save_res (ii, jj, 0 , fin_res);
20552064 save_res (ii+4 , jj, 4 , fin_res);
0 commit comments