@@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
414414    return  res;
415415}
416416
// Diff hunk for ggml_metal_library_get_pipeline_mul_mv_ext: lines marked "-" are the
// removed version, lines marked "+" are the added version, unmarked lines are unchanged.
//
// Returns the pipeline for the "kernel_mul_mv_ext" kernel selected by the two tensor
// types and r1ptg. The function first asks the library for an existing pipeline under
// `name`; only when that lookup yields nothing does it compile a new one.
//
// Parameters (new version):
//   lib    - library handle passed to the lookup and compile calls
//   tsrc0  - type whose ggml_type_name() forms the first part of the kernel name
//   tsrc1  - type whose ggml_type_name() forms the second part of the kernel name
//   nsg    - added by this change; appended to `name` and stored at FC_MUL_MV + 0
//   nxpsg  - added by this change; appended to `name` and stored at FC_MUL_MV + 1
//   r1ptg  - integer baked into the kernel base name ("_r1_%d")
//
// NOTE(review): the new leading parameters change the signature, so every caller must
// pass nsg and nxpsg - confirm call sites outside this hunk were updated.
417- ggml_metal_pipeline_t  ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t  lib, ggml_type tsrc0, ggml_type tsrc1, int  r1ptg) {
417+ ggml_metal_pipeline_t  ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t  lib, ggml_type tsrc0, ggml_type tsrc1, int  nsg,  int  nxpsg,  int   r1ptg) {
418418    char  base[256 ];
419419    char  name[256 ];
420420
// `base` is the kernel function name; `name` is the key used for the lookup below.
// The new version appends nsg/nxpsg to `name`, presumably so that pipelines compiled
// with different constant values are stored under distinct keys - TODO confirm.
421421    snprintf (base, 256 , " kernel_mul_mv_ext_%s_%s_r1_%d"  , ggml_type_name (tsrc0), ggml_type_name (tsrc1), r1ptg);
422-     snprintf (name, 256 , " %s "  , base);
422+     snprintf (name, 256 , " %s_nsg=%d_nxpsg=%d "  , base, nsg, nxpsg );
423423
424424    ggml_metal_pipeline_t  res = ggml_metal_library_get_pipeline (lib, name);
425425    if  (res) {
426426        return  res;
427427    }
428428
// Old version compiled with no constant values (nullptr). New version builds a `cv`
// object holding nsg and nxpsg at indices FC_MUL_MV + 0 / + 1, passes it to the
// compile call, and frees it afterwards.
// NOTE(review): values are stored via the int16 setter while the parameters are int -
// presumably they always fit in 16 bits; verify against callers.
429-     res = ggml_metal_library_compile_pipeline (lib, base, name, nullptr );
429+     ggml_metal_cv_t  cv = ggml_metal_cv_init ();
430+ 
431+     ggml_metal_cv_set_int16 (cv, nsg,   FC_MUL_MV + 0 );
432+     ggml_metal_cv_set_int16 (cv, nxpsg, FC_MUL_MV + 1 );
433+ 
434+     res = ggml_metal_library_compile_pipeline (lib, base, name, cv);
435+ 
436+     ggml_metal_cv_free (cv);
430437
431438    return  res;
432439}
@@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
608615    };
609616
610617    snprintf (base, 256 , " kernel_mul_mv_%s_%s%s"  , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix);
611-     snprintf (name, 256 , " %s "  , base);
618+     snprintf (name, 256 , " %s_nsg=%d "  , base, nsg );
612619
613620    ggml_metal_pipeline_t  res = ggml_metal_library_get_pipeline (lib, name);
614621    if  (res) {
@@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
824831    };
825832
826833    snprintf (base, 256 , " kernel_mul_mv_id_%s_%s%s"  , ggml_type_name (tsrc0), ggml_type_name (tsrc1), suffix);
827-     snprintf (name, 256 , " %s "  , base);
834+     snprintf (name, 256 , " %s_nsg=%d "  , base, nsg );
828835
829836    ggml_metal_pipeline_t  res = ggml_metal_library_get_pipeline (lib, name);
830837    if  (res) {
@@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
923930            dk,
924931            dv);
925932
926-     snprintf (name, 256 , " kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d"  ,
927-             " flash_attn_ext"  ,
928-             ggml_type_name (op->src [1 ]->type ),
929-             dk,
930-             dv,
933+     snprintf (name, 256 , " %s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d"  ,
934+             base,
931935            has_mask,
932936            has_sinks,
933937            has_bias,
@@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
985989            dk,
986990            dv);
987991
988-     snprintf (name, 256 , " kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d"  ,
989-             " flash_attn_ext_vec"  ,
990-             ggml_type_name (op->src [1 ]->type ),
991-             dk,
992-             dv,
992+     snprintf (name, 256 , " %s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d"  ,
993+             base,
993994            has_mask,
994995            has_sinks,
995996            has_bias,
@@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
10331034    char  name[256 ];
10341035
10351036    snprintf (base, 256 , " kernel_flash_attn_ext_vec_reduce"  );
1036-     snprintf (name, 256 , " kernel_flash_attn_ext_vec_reduce_dv =%d_nwg=%d"  , dv, nwg);
1037+     snprintf (name, 256 , " %s_dv =%d_nwg=%d" , base , dv, nwg);
10371038
10381039    ggml_metal_pipeline_t  res = ggml_metal_library_get_pipeline (lib, name);
10391040    if  (res) {
0 commit comments