@@ -2695,11 +2695,13 @@ class tinyBLAS_PPC {
26952695 const TA *A, int64_t lda,
26962696 const TB *B, int64_t ldb,
26972697 TC *C, int64_t ldc,
2698- int ith, int nth)
2699- : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2698+ int ith, int nth, int64_t m_orig, bool is_transposed)
2699+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth), is_transposed(is_transposed){
2700+ m_orig = 0 ;
27002701 }
27012702
27022703 void matmul (int64_t m, int64_t n) {
2704+ m_orig = m;
27032705 mnpack (0 , m, 0 , n);
27042706 }
27052707
@@ -2957,7 +2959,13 @@ class tinyBLAS_PPC {
29572959 acc_t acc_0;
29582960 __builtin_mma_xxsetaccz (&acc_0);
29592961 for (int l = 0 ; l < k; l+=4 ) {
2960- packTranspose<vector float >(A+(ii*lda)+l, lda, 4 , 4 , (TA*)vec_A);
2962+ if (is_transposed) {
2963+ for (int x = 0 ; x< 4 ; x++) {
2964+ vec_A[x] = (vec_t )vec_xl (0 , (float *)A+ (l+x)*m_orig+ii);
2965+ }
2966+ } else {
2967+ packTranspose<vector float >(A+(ii*lda)+l, lda, 4 , 4 , (TA*)vec_A);
2968+ }
29612969 packTranspose<vector float >(B+(jj*ldb)+l, ldb, 4 , 4 , (TA*)vec_B);
29622970 __builtin_mma_xvf32gerpp (&acc_0, vec_A[0 ], vec_B[0 ]);
29632971 __builtin_mma_xvf32gerpp (&acc_0, vec_A[1 ], vec_B[1 ]);
@@ -2973,7 +2981,13 @@ class tinyBLAS_PPC {
29732981 __builtin_mma_xxsetaccz (&acc_0);
29742982 __builtin_mma_xxsetaccz (&acc_1);
29752983 for (int64_t l = 0 ; l < k; l+=4 ) {
2976- packTranspose<vector float >(A+(ii*lda)+l, lda, 4 , 4 , (TA*)vec_A);
2984+ if (is_transposed) {
2985+ for (int x =0 ; x< 4 ; x++) {
2986+ vec_A[x] = (vec_t ) vec_xl (0 , (float *)A+(l+x)*m_orig+ii);
2987+ }
2988+ }else {
2989+ packTranspose<vector float >(A+(ii*lda)+l, lda, 4 , 4 , (TA*)vec_A);
2990+ }
29772991 packTranspose<vector float >(B+(jj*ldb)+l, ldb, 8 , 4 , (TA*)vec_B);
29782992 __builtin_mma_xvf32gerpp (&acc_0, vec_A[0 ], (vec_t )vec_B[0 ]);
29792993 __builtin_mma_xvf32gerpp (&acc_1, vec_A[0 ], (vec_t )vec_B[1 ]);
@@ -2994,7 +3008,14 @@ class tinyBLAS_PPC {
29943008 __builtin_mma_xxsetaccz (&acc_0);
29953009 __builtin_mma_xxsetaccz (&acc_1);
29963010 for (int64_t l = 0 ; l < k; l+=4 ) {
3011+ if (is_transposed) {
3012+ for (int x = 0 ; x <4 ; x++) {
3013+ vec_A[2 *x] = (vec_t )vec_xl (0 , (float *)A+(l+x)*m_orig+ii);
3014+ vec_A[2 *x+1 ] = (vec_t )vec_xl (0 , (float *)A+(l+x)*m_orig+ii+4 );
3015+ }
3016+ } else {
29973017 packTranspose<vector float >(A+(ii*lda)+l, lda, 8 , 4 , (TA*)vec_A);
3018+ }
29983019 packTranspose<vector float >(B+(jj*ldb)+l, ldb, 4 , 4 , (TA*)vec_B);
29993020 __builtin_mma_xvf32gerpp (&acc_0, (vec_t )vec_A[0 ], vec_B[0 ]);
30003021 __builtin_mma_xvf32gerpp (&acc_1, (vec_t )vec_A[1 ], vec_B[0 ]);
@@ -3017,7 +3038,14 @@ class tinyBLAS_PPC {
30173038 __builtin_mma_xxsetaccz (&acc_2);
30183039 __builtin_mma_xxsetaccz (&acc_3);
30193040 for (int l = 0 ; l < k; l+=8 ) {
3041+ if (is_transposed) {
3042+ for (int x = 0 ; x <8 ; x++) {
3043+ vec_A[2 *x] = (vec_t )vec_xl (0 , (float *)A+(l+x)*m_orig+ii);
3044+ vec_A[2 *x+1 ] = (vec_t )vec_xl (0 , (float *)A+(l+x)*m_orig+ii+4 );
3045+ }
3046+ } else {
30203047 packTranspose<vector float >(A+(ii*lda)+l, lda, 8 , 8 , (TA*)vec_A);
3048+ }
30213049 packTranspose<vector float >(B+(jj*ldb)+l, ldb, 8 , 8 , (TA*)vec_B);
30223050 for (int x = 0 ; x < 16 ; x+=2 ) {
30233051 __builtin_mma_xvf32gerpp (&acc_0, (vec_t )vec_A[x], vec_B[x]);
@@ -3205,24 +3233,31 @@ class tinyBLAS_PPC {
32053233 * broadcasted, instead of using packing routine to prepack the
32063234 * matrix elements.
32073235 */
3208- if (RM == 1 ) {
3209- TA* a = const_cast <TA*>(A+(ii)*lda+l);
3236+ if (is_transposed) {
3237+ for (int x = 0 ; x< 4 ; x++) {
3238+ vec_A[x] = (vec_t )vec_xl (0 , (float *)A+(l+x)*m_orig+ii);
3239+ }
32103240 packTranspose<vector float >(B+(jj*ldb)+l, ldb, RN, 4 , (TA*)vec_B);
3211- vec_A[0 ] = (vec_t )vec_xl (0 ,a);
3212- vec_A[1 ] = (vec_t )vec_splats (*((TA*)&vec_A+1 ));
3213- vec_A[2 ] = (vec_t )vec_splats (*((TA*)&vec_A+2 ));
3214- vec_A[3 ] = (vec_t )vec_splats (*((TA*)&vec_A+3 ));
3215- } else if (RN == 1 ) {
3216- packTranspose<vector float >(A+(ii*lda)+l, lda, RM, 4 , (TA*)vec_A);
3217- TB* b = const_cast <TB*>(B+(jj)*ldb+l);
3218- vec_B[0 ] = (vec_t )vec_xl (0 ,b);
3219- vec_B[1 ] = (vec_t )vec_splats (*((TB*)&vec_B+1 ));
3220- vec_B[2 ] = (vec_t )vec_splats (*((TB*)&vec_B+2 ));
3221- vec_B[3 ] = (vec_t )vec_splats (*((TB*)&vec_B+3 ));
3222- } else {
3223- packTranspose<vector float >(A+(ii*lda)+l, lda, RM, 4 , (TA*)vec_A);
3224- packTranspose<vector float >(B+(jj*ldb)+l, ldb, RN, 4 , (TA*)vec_B);
3225- }
3241+ } else {
3242+ if (RM == 1 ) {
3243+ TA* a = const_cast <TA*>(A+(ii)*lda+l);
3244+ packTranspose<vector float >(B+(jj*ldb)+l, ldb, RN, 4 , (TA*)vec_B);
3245+ vec_A[0 ] = (vec_t )vec_xl (0 ,a);
3246+ vec_A[1 ] = (vec_t )vec_splats (*((TA*)&vec_A+1 ));
3247+ vec_A[2 ] = (vec_t )vec_splats (*((TA*)&vec_A+2 ));
3248+ vec_A[3 ] = (vec_t )vec_splats (*((TA*)&vec_A+3 ));
3249+ } else if (RN == 1 ) {
3250+ packTranspose<vector float >(A+(ii*lda)+l, lda, RM, 4 , (TA*)vec_A);
3251+ TB* b = const_cast <TB*>(B+(jj)*ldb+l);
3252+ vec_B[0 ] = (vec_t )vec_xl (0 ,b);
3253+ vec_B[1 ] = (vec_t )vec_splats (*((TB*)&vec_B+1 ));
3254+ vec_B[2 ] = (vec_t )vec_splats (*((TB*)&vec_B+2 ));
3255+ vec_B[3 ] = (vec_t )vec_splats (*((TB*)&vec_B+3 ));
3256+ } else {
3257+ packTranspose<vector float >(A+(ii*lda)+l, lda, RM, 4 , (TA*)vec_A);
3258+ packTranspose<vector float >(B+(jj*ldb)+l, ldb, RN, 4 , (TA*)vec_B);
3259+ }
3260+ }
32263261 __builtin_mma_xvf32gerpp (&acc_0, vec_A[0 ], vec_B[0 ]);
32273262 __builtin_mma_xvf32gerpp (&acc_0, vec_A[1 ], vec_B[1 ]);
32283263 __builtin_mma_xvf32gerpp (&acc_0, vec_A[2 ], vec_B[2 ]);
@@ -3274,6 +3309,8 @@ class tinyBLAS_PPC {
32743309 const int64_t ldc;
32753310 const int ith;
32763311 const int nth;
3312+ int64_t m_orig;
3313+ bool is_transposed;
32773314};
32783315#endif
32793316} // namespace
@@ -3310,13 +3347,16 @@ class tinyBLAS_PPC {
33103347 */
33113348bool llamafile_sgemm (const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
33123349 const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
3313- int64_t ldc, int Atype, int Btype, int Ctype) {
3314-
3350+ int64_t ldc, int Atype, int Btype, int Ctype, bool is_transposed ) {
3351+ printf ( " m=%ld n=%ld k=%ld lda=%ld ldb=%ld ldc=%ld \n " , m, n, k, lda, ldb, ldc);
33153352 assert (m >= 0 );
33163353 assert (n >= 0 );
33173354 assert (k >= 0 );
3318- assert (lda >= k);
3319- assert (ldb >= k);
3355+ /* if (is_transposed)
3356+ assert(lda >= m);
3357+ else*/
3358+ // assert(lda >= k);
3359+ // assert(ldb >= k);
33203360 assert (ldc >= m);
33213361 assert (params->nth > 0 );
33223362 assert (params->ith < params->nth );
@@ -3366,12 +3406,58 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
33663406#elif defined(__MMA__)
33673407 if (k % 8 )
33683408 return false ;
3409+ // if (is_transposed)
3410+ // printf("A was transposed during GGUF; m = %d n = %d k = %d\n", m, n, k);
3411+ float * Ap = (float *)A;
3412+ float * Bp = (float *)B;
3413+ float * Cp = (float *)C;
3414+ printf (" Matrix AT in column major\n " );
3415+ for (int r = 0 ; r < k; r ++) {
3416+ printf (" | " );
3417+ for (int c = 0 ; c< m; c++) {
3418+ printf (" %.2f " , Ap[c*k + r]);
3419+ }
3420+ printf (" |\n " );
3421+ }
3422+ printf (" A memory layout n" );
3423+ for (int i = 0 ; i < (m*k); i++){
3424+ printf (" %.2f " , *(Ap++));
3425+ }
3426+ printf (" \n " );
3427+ printf (" B in column major\n " );
3428+ for (int r = 0 ; r < k; r ++) {
3429+ printf (" | " );
3430+ for (int c = 0 ; c< n; c++) {
3431+ printf (" %.2f " , Bp[c*k + r]);
3432+ }
3433+ printf (" |\n " );
3434+ }
3435+
3436+ printf (" B memory layout n" );
3437+ for (int i = 0 ; i < (n*k); i++){
3438+ printf (" %.2f " , *(Bp++));
3439+ }
3440+ printf (" \n " );
33693441 tinyBLAS_PPC<float , float , float > tb{
33703442 k, (const float *)A, lda,
33713443 (const float *)B, ldb,
33723444 (float *)C, ldc,
3373- params->ith , params->nth };
3445+ params->ith , params->nth , m, is_transposed };
33743446 tb.matmul (m, n);
3447+ printf (" C Matrix\n " );
3448+ for (int r = 0 ; r < m; r ++) {
3449+ printf (" | " );
3450+ for (int c = 0 ; c< n; c++) {
3451+ printf (" %.2f " , Cp[c*m + r]);
3452+ }
3453+ printf (" |\n " );
3454+ }
3455+
3456+ for (int i = 0 ; i < (m*n); i++){
3457+ printf (" %.2f " , *(Cp++));
3458+ }
3459+ printf (" \n " );
3460+ // printf("completd llamafile_Sgemm\n");
33753461 return true ;
33763462#else
33773463 return false ;
0 commit comments