1- //  -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2- //  vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
3- // 
41//  Copyright 2024 Mozilla Foundation
52// 
63//  Permission is hereby granted, free of charge, to any person obtaining
@@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM {
585582};
586583#endif  //  __ARM_FEATURE_DOTPROD
587584
588- #if  defined(__AVX2__) || defined(__AVX512F__)
585+ #if  defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) 
589586template  <typename  TA, typename  TB, typename  TC>
590- class  tinyBLAS_Q0_AVX2  {
587+ class  tinyBLAS_Q0_AVX  {
591588  public: 
592-     tinyBLAS_Q0_AVX2 (int64_t  k,
593-                       const  TA *A, int64_t  lda,
594-                       const  TB *B, int64_t  ldb,
595-                       TC *C, int64_t  ldc,
596-                       int  ith, int  nth)
589+     tinyBLAS_Q0_AVX (int64_t  k,
590+                     const  TA *A, int64_t  lda,
591+                     const  TB *B, int64_t  ldb,
592+                     TC *C, int64_t  ldc,
593+                     int  ith, int  nth)
597594        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
        // Only records the three operand pointers (A, B, C), k, lda/ldb/ldc and
        // ith/nth in the same-named members; no other work happens here.
        // NOTE(review): ith/nth look like a worker index and worker count --
        // confirm against the caller, which is outside this view.
598595    }
599596
@@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 {
728725            __m256 Cv[RN][RM] = {};
729726            for  (int64_t  l = 0 ; l < k; ++l)
730727                for  (int64_t  j = 0 ; j < RN; ++j)
731-                     for  (int64_t  i = 0 ; i < RM; ++i)
728+                     for  (int64_t  i = 0 ; i < RM; ++i) {
729+ #if  defined(__AVX2__)
730+                         __m256 udTmp = updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
731+                                                               load (A + lda * (ii + i) + l)),
732+                                              _mm256_sign_epi8 (load (B + ldb * (jj + j) + l),
733+                                                               load (A + lda * (ii + i) + l)));
734+ #else 
735+                         __m128i ali0 = load0 (A + lda * (ii + i) + l);
736+                         __m128i ali1 = load1 (A + lda * (ii + i) + l);
737+                         __m128i blj0 = load0 (B + ldb * (jj + j) + l);
738+                         __m128i blj1 = load1 (B + ldb * (jj + j) + l);
739+ 
740+                         __m128i sepAA0 = _mm_sign_epi8 (ali0, ali0);
741+                         __m128i sepAA1 = _mm_sign_epi8 (ali1, ali1);
742+                         __m128i sepBA0 = _mm_sign_epi8 (blj0, ali0);
743+                         __m128i sepBA1 = _mm_sign_epi8 (blj1, ali1);
744+ 
745+                         //  updot
746+                         const  __m128i oneFill = _mm_set1_epi16 (1 );
747+                         __m128i mad0 = _mm_maddubs_epi16 (sepAA0, sepBA0);
748+                         __m128i mad1 = _mm_maddubs_epi16 (sepAA1, sepBA1);
749+                         __m256 udTmp = _mm256_cvtepi32_ps (MM256_SET_M128I (_mm_madd_epi16 (oneFill, mad1), _mm_madd_epi16 (oneFill, mad0)));
750+ #endif 
732751                        Cv[j][i] = madd (_mm256_set1_ps (unhalf (A[lda * (ii + i) + l].d ) *
733752                                                       unhalf (B[ldb * (jj + j) + l].d )),
734-                                         updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
735-                                                                load (A + lda * (ii + i) + l)),
736-                                               _mm256_sign_epi8 (load (B + ldb * (jj + j) + l),
737-                                                                load (A + lda * (ii + i) + l))),
738-                                         Cv[j][i]);
753+                                                        udTmp,
754+                                                        Cv[j][i]);
755+                     }
739756            for  (int64_t  j = 0 ; j < RN; ++j)
740757                for  (int64_t  i = 0 ; i < RM; ++i)
741758                    C[ldc * (jj + j) + (ii + i)] = hsum (Cv[j][i]);
@@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
746763        return  _mm256_loadu_si256 ((const  __m256i *)b->qs );
747764    }
748765
766+     inline  __m128i load0 (const  block_q8_0 *b) {
             // Unaligned 128-bit load of the first 16 bytes of b->qs.
             // Used by the non-AVX2 branch of the kernel loop, which works on
             // two 128-bit halves instead of one 256-bit vector.
767+         return  _mm_loadu_si128 ((const  __m128i *)b->qs );
768+     }
769+ 
770+     inline  __m128i load1 (const  block_q8_0 *b) {
             // Unaligned 128-bit load of the second 16 bytes of b->qs: the
             // "+ 1" advances the __m128i pointer by one 16-byte element.
771+         return  _mm_loadu_si128 (((const  __m128i *)b->qs ) + 1 );
772+     }
773+ 
749774    inline  __m256i load (const  block_q4_0 *b) {
        // Passes b->qs through denibble() and subtracts 8 from every byte of
        // the 256-bit result.
        // NOTE(review): denibble() is defined outside this view; presumably it
        // expands the packed 4-bit values to one value per byte -- confirm.
750775        return  _mm256_sub_epi8 (denibble (b->qs ), _mm256_set1_epi8 (8 ));
751776    }
752777
778+     inline  __m128i load0 (const  block_q4_0 *b) {
             // Reads 16 bytes of b->qs, keeps the low 4 bits of every byte
             // (mask 15) and subtracts 8 from each resulting byte.
779+         const  __m128i x = _mm_loadu_si128 ((const  __m128i *)(b->qs ));
780+         return  _mm_sub_epi8 (_mm_and_si128 (_mm_set1_epi8 (15 ), x), _mm_set1_epi8 (8 ));
781+     }
782+ 
783+     inline  __m128i load1 (const  block_q4_0 *b) {
             // Reads the same 16 bytes of b->qs as load0, but extracts the high
             // 4 bits of every byte: each 16-bit lane is shifted right by 4 and
             // the result is masked with 15 per byte, which discards the bits
             // that crossed over from the neighbouring byte. 8 is then
             // subtracted from each byte.
784+         const  __m128i x = _mm_loadu_si128 ((const  __m128i *)(b->qs ));
785+         return  _mm_sub_epi8 (_mm_and_si128 (_mm_set1_epi8 (15 ), _mm_srli_epi16 (x, 4 )), _mm_set1_epi8 (8 ));
786+     }
787+ 
753788    inline  __m256 updot (__m256i u, __m256i s) {
754789        __m256i res;
755790#if  defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 {
777812    const  int  ith;
778813    const  int  nth;
779814};
780- #endif  //  __AVX2__ 
815+ #endif  //  __AVX__ 
781816
782817} //  namespace
783818
@@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
928963    case  GGML_TYPE_Q8_0: {
929964        if  (Btype != GGML_TYPE_Q8_0)
930965           return  false ;
931- #if  defined(__AVX2__) || defined(__AVX512F__)
932-         tinyBLAS_Q0_AVX2 <block_q8_0, block_q8_0, float > tb{
966+ #if  defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) 
967+         tinyBLAS_Q0_AVX <block_q8_0, block_q8_0, float > tb{
933968            k, (const  block_q8_0 *)A, lda,
934969            (const  block_q8_0 *)B, ldb,
935970            (float  *)C, ldc,
@@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
952987    case  GGML_TYPE_Q4_0: {
953988        if  (Btype != GGML_TYPE_Q8_0)
954989            return  false ;
955- #if  defined(__AVX2__) || defined(__AVX512F__)
956-         tinyBLAS_Q0_AVX2 <block_q4_0, block_q8_0, float > tb{
990+ #if  defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) 
991+         tinyBLAS_Q0_AVX <block_q4_0, block_q8_0, float > tb{
957992            k, (const  block_q4_0 *)A, lda,
958993            (const  block_q8_0 *)B, ldb,
959994            (float  *)C, ldc,
0 commit comments