@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
3434}
3535
3636void  ggml_metal_pipelines_free (ggml_metal_pipelines_t  ppls) {
37+     if  (!ppls) {
38+         return ;
39+     }
40+ 
3741    for  (auto  it = ppls->data .begin (); it != ppls->data .end (); ++it) {
3842        ggml_metal_pipeline_free (it->second );
3943    }
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
467471    //  use custom matrix x vector kernel
468472    switch  (tsrc0) {
469473        case  GGML_TYPE_F32:
474+         case  GGML_TYPE_F16:
475+         case  GGML_TYPE_BF16:
470476            {
471-                 GGML_ASSERT (op->src [1 ]->type  == GGML_TYPE_F32);
472- 
473-                 nsg = 1 ;
474-                 nr0 = 1 ;
475-                 nr1 = 4 ;
476477                if  (ne00 == 4 ) {
478+                     nsg = 1 ;
477479                    nr0 = 32 ;
480+                     nr1 = 4 ;
478481                    suffix = " _c4"  ;
479-                 }
480-             } break ;
481-         case  GGML_TYPE_F16:
482-         case  GGML_TYPE_BF16:
483-             {
484-                 nsg = 1 ;
485-                 nr0 = 1 ;
486-                 if  (op->src [1 ]->type  == GGML_TYPE_F32) {
487-                     if  (ne00 == 4 ) {
488-                         nr0 = 32 ;
489-                         nr1 = 4 ;
490-                         suffix = " _c4"  ;
491-                     } else  if  (ne11 * ne12 < 4 ) {
492-                         suffix = " _1row"  ;
493-                     } else  if  (ne00 >= 128  && ne01 >= 8  && ne00%4  == 0 ) {
494-                         suffix = " _l4"  ;
495-                         nr1 = ne11;
496-                     } else  {
497-                         nr1 = 4 ;
498-                     }
482+                 } else  if  (ne00 % 4  == 0 ) {
483+                     nsg = N_SG_F;
484+                     nr0 = N_R0_F;
485+                     nr1 = 1 ;
486+                     smem = 32 *sizeof (float )*N_R0_F;
487+                     suffix = " _4"  ;
499488                } else  {
500-                     nr1 = 4 ;
489+                     nsg = N_SG_F;
490+                     nr0 = N_R0_F;
491+                     nr1 = 1 ;
492+                     smem = 32 *sizeof (float )*N_R0_F;
501493                }
502494            } break ;
503495        case  GGML_TYPE_Q4_0:
@@ -623,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
623615        return  res;
624616    }
625617
626-     res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
618+     ggml_metal_cv_t  cv = ggml_metal_cv_init ();
619+ 
620+     ggml_metal_cv_set_int16 (cv, nsg, FC_MUL_MV + 0 );
621+ 
622+     res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
623+ 
624+     ggml_metal_cv_free (cv);
627625
628626    ggml_metal_pipeline_set_nr0  (res, nr0);
629627    ggml_metal_pipeline_set_nr1  (res, nr1);
@@ -689,25 +687,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
689687    const  ggml_type tsrc0 = op->src [0 ]->type ;
690688    const  ggml_type tsrc1 = op->src [1 ]->type ;
691689
690+     const  char  * suffix = " "  ;
691+ 
692692        //  use custom matrix x vector kernel
693693    switch  (tsrc0) {
694694        case  GGML_TYPE_F32:
695-             {
696-                 GGML_ASSERT (op->src [1 ]->type  == GGML_TYPE_F32);
697-                 nsg = 1 ;
698-                 nr0 = 1 ;
699-             } break ;
700695        case  GGML_TYPE_F16:
701-             {
702-                 GGML_ASSERT (op->src [1 ]->type  == GGML_TYPE_F32);
703-                 nsg = 1 ;
704-                 nr0 = 1 ;
705-             } break ;
706696        case  GGML_TYPE_BF16:
707697            {
708-                 GGML_ASSERT (op->src [1 ]->type  == GGML_TYPE_F32);
709-                 nsg = 1 ;
710-                 nr0 = 1 ;
698+                 if  (ne00 % 4  == 0 ) {
699+                     nsg = N_SG_F;
700+                     nr0 = N_R0_F;
701+                     nr1 = 1 ;
702+                     smem = 32 *sizeof (float )*N_R0_F;
703+                     suffix = " _4"  ;
704+                 } else  {
705+                     nsg = N_SG_F;
706+                     nr0 = N_R0_F;
707+                     nr1 = 1 ;
708+                     smem = 32 *sizeof (float )*N_R0_F;
709+                 }
711710            } break ;
712711        case  GGML_TYPE_Q4_0:
713712            {
@@ -824,15 +823,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824823            }
825824    };
826825
827-     snprintf (base, 256 , " kernel_mul_mv_id_%s_%s"  , ggml_type_name (tsrc0), ggml_type_name (tsrc1));
826+     snprintf (base, 256 , " kernel_mul_mv_id_%s_%s%s "  , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix );
828827    snprintf (name, 256 , " %s"  , base);
829828
830829    ggml_metal_pipeline_t  res = ggml_metal_library_get_pipeline (lib, name);
831830    if  (res) {
832831        return  res;
833832    }
834833
835-     res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
834+     ggml_metal_cv_t  cv = ggml_metal_cv_init ();
835+ 
836+     ggml_metal_cv_set_int16 (cv, nsg, FC_MUL_MV + 0 );
837+ 
838+     res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
839+ 
840+     ggml_metal_cv_free (cv);
836841
837842    ggml_metal_pipeline_set_nr0  (res, nr0);
838843    ggml_metal_pipeline_set_nr1  (res, nr1);
0 commit comments