@@ -53,6 +53,13 @@ static std::vector<T> split(const std::string & str, char delim) {
5353    return  values;
5454}
5555
// Converts every element of `values` to a string by applying `f` to it.
//
//   values - input elements, visited in order
//   f      - callable invoked as f(element); its result must be implicitly
//            convertible to std::string (e.g. const char * or std::string)
//
// Returns a vector with one string per input element, in the same order.
// An empty input yields an empty vector.
template <typename T, typename F>
static std::vector<std::string> transform_to_str(const std::vector<T> & values, F f) {
    std::vector<std::string> str_values;
    // the final size is known up front, so allocate once instead of growing
    str_values.reserve(values.size());
    std::transform(values.begin(), values.end(), std::back_inserter(str_values), f);
    return str_values;
}
62+ 
5663template <typename  T>
5764static  T avg (const  std::vector<T> & v) {
5865    if  (v.empty ()) {
@@ -126,7 +133,8 @@ struct cmd_params {
126133    std::vector<int > n_prompt;
127134    std::vector<int > n_gen;
128135    std::vector<int > n_batch;
129-     std::vector<bool > f32_kv;
136+     std::vector<ggml_type> type_k;
137+     std::vector<ggml_type> type_v;
130138    std::vector<int > n_threads;
131139    std::vector<int > n_gpu_layers;
132140    std::vector<int > main_gpu;
@@ -142,7 +150,8 @@ static const cmd_params cmd_params_defaults = {
142150    /*  n_prompt      */   {512 },
143151    /*  n_gen         */   {128 },
144152    /*  n_batch       */   {512 },
145-     /*  f32_kv        */   {false },
153+     /*  type_k        */   {GGML_TYPE_F16},
154+     /*  type_v        */   {GGML_TYPE_F16},
146155    /*  n_threads     */   {get_num_physical_cores ()},
147156    /*  n_gpu_layers  */   {99 },
148157    /*  main_gpu      */   {0 },
@@ -162,7 +171,8 @@ static void print_usage(int /* argc */, char ** argv) {
162171    printf ("   -p, --n-prompt <n>                (default: %s)\n "  , join (cmd_params_defaults.n_prompt , " ,"  ).c_str ());
163172    printf ("   -n, --n-gen <n>                   (default: %s)\n "  , join (cmd_params_defaults.n_gen , " ,"  ).c_str ());
164173    printf ("   -b, --batch-size <n>              (default: %s)\n "  , join (cmd_params_defaults.n_batch , " ,"  ).c_str ());
165-     printf ("   --memory-f32 <0|1>                (default: %s)\n "  , join (cmd_params_defaults.f32_kv , " ,"  ).c_str ());
174+     printf ("   -ctk <t>, --cache-type-k <t>      (default: %s)\n "  , join (transform_to_str (cmd_params_defaults.type_k , ggml_type_name), " ,"  ).c_str ());
175+     printf ("   -ctv <t>, --cache-type-v <t>      (default: %s)\n "  , join (transform_to_str (cmd_params_defaults.type_v , ggml_type_name), " ,"  ).c_str ());
166176    printf ("   -t, --threads <n>                 (default: %s)\n "  , join (cmd_params_defaults.n_threads , " ,"  ).c_str ());
167177    printf ("   -ngl, --n-gpu-layers <n>          (default: %s)\n "  , join (cmd_params_defaults.n_gpu_layers , " ,"  ).c_str ());
168178    printf ("   -mg, --main-gpu <i>               (default: %s)\n "  , join (cmd_params_defaults.main_gpu , " ,"  ).c_str ());
@@ -173,9 +183,32 @@ static void print_usage(int /* argc */, char ** argv) {
173183    printf ("   -v, --verbose                     (default: %s)\n "  , cmd_params_defaults.verbose  ? " 1"   : " 0"  );
174184    printf (" \n "  );
175185    printf (" Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n "  );
186+ }
176187
188+ static  ggml_type ggml_type_from_name (const  std::string & s) {
189+     if  (s == " f16"  ) {
190+         return  GGML_TYPE_F16;
191+     }
192+     if  (s == " q8_0"  ) {
193+         return  GGML_TYPE_Q8_0;
194+     }
195+     if  (s == " q4_0"  ) {
196+         return  GGML_TYPE_Q4_0;
197+     }
198+     if  (s == " q4_1"  ) {
199+         return  GGML_TYPE_Q4_1;
200+     }
201+     if  (s == " q5_0"  ) {
202+         return  GGML_TYPE_Q5_0;
203+     }
204+     if  (s == " q5_1"  ) {
205+         return  GGML_TYPE_Q5_1;
206+     }
207+ 
208+     return  GGML_TYPE_COUNT;
177209}
178210
211+ 
179212static  cmd_params parse_cmd_params (int  argc, char  ** argv) {
180213    cmd_params params;
181214    std::string arg;
@@ -224,13 +257,38 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
224257            }
225258            auto  p = split<int >(argv[i], split_delim);
226259            params.n_batch .insert (params.n_batch .end (), p.begin (), p.end ());
227-         } else  if  (arg == " --memory-f32 "  ) {
260+         } else  if  (arg == " -ctk "  || arg ==  " --cache-type-k "  ) {
228261            if  (++i >= argc) {
229262                invalid_param = true ;
230263                break ;
231264            }
232-             auto  p = split<int >(argv[i], split_delim);
233-             params.f32_kv .insert (params.f32_kv .end (), p.begin (), p.end ());
265+             auto  p = split<std::string>(argv[i], split_delim);
266+             std::vector<ggml_type> types;
267+             for  (const  auto  & t : p) {
268+                 ggml_type gt = ggml_type_from_name (t);
269+                 if  (gt == GGML_TYPE_COUNT) {
270+                     invalid_param = true ;
271+                     break ;
272+                 }
273+                 types.push_back (gt);
274+             }
275+             params.type_k .insert (params.type_k .end (), types.begin (), types.end ());
276+         } else  if  (arg == " -ctv"   || arg == " --cache-type-v"  ) {
277+             if  (++i >= argc) {
278+                 invalid_param = true ;
279+                 break ;
280+             }
281+             auto  p = split<std::string>(argv[i], split_delim);
282+             std::vector<ggml_type> types;
283+             for  (const  auto  & t : p) {
284+                 ggml_type gt = ggml_type_from_name (t);
285+                 if  (gt == GGML_TYPE_COUNT) {
286+                     invalid_param = true ;
287+                     break ;
288+                 }
289+                 types.push_back (gt);
290+             }
291+             params.type_v .insert (params.type_v .end (), types.begin (), types.end ());
234292        } else  if  (arg == " -t"   || arg == " --threads"  ) {
235293            if  (++i >= argc) {
236294                invalid_param = true ;
@@ -321,7 +379,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
321379    if  (params.n_prompt .empty ())     { params.n_prompt  = cmd_params_defaults.n_prompt ; }
322380    if  (params.n_gen .empty ())        { params.n_gen  = cmd_params_defaults.n_gen ; }
323381    if  (params.n_batch .empty ())      { params.n_batch  = cmd_params_defaults.n_batch ; }
324-     if  (params.f32_kv .empty ())       { params.f32_kv  = cmd_params_defaults.f32_kv ; }
382+     if  (params.type_k .empty ())       { params.type_k  = cmd_params_defaults.type_k ; }
383+     if  (params.type_v .empty ())       { params.type_v  = cmd_params_defaults.type_v ; }
325384    if  (params.n_gpu_layers .empty ()) { params.n_gpu_layers  = cmd_params_defaults.n_gpu_layers ; }
326385    if  (params.main_gpu .empty ())     { params.main_gpu  = cmd_params_defaults.main_gpu ; }
327386    if  (params.mul_mat_q .empty ())    { params.mul_mat_q  = cmd_params_defaults.mul_mat_q ; }
@@ -336,7 +395,8 @@ struct cmd_params_instance {
336395    int  n_prompt;
337396    int  n_gen;
338397    int  n_batch;
339-     bool  f32_kv;
398+     ggml_type type_k;
399+     ggml_type type_v;
340400    int  n_threads;
341401    int  n_gpu_layers;
342402    int  main_gpu;
@@ -365,7 +425,8 @@ struct cmd_params_instance {
365425
366426        cparams.n_ctx  = n_prompt + n_gen;
367427        cparams.n_batch  = n_batch;
368-         cparams.f16_kv  = !f32_kv;
428+         cparams.type_k  = type_k;
429+         cparams.type_v  = type_v;
369430        cparams.mul_mat_q  = mul_mat_q;
370431
371432        return  cparams;
@@ -380,15 +441,17 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
380441    for  (const  auto  & mg : params.main_gpu )
381442    for  (const  auto  & ts : params.tensor_split )
382443    for  (const  auto  & nb : params.n_batch )
383-     for  (const  auto  & fk : params.f32_kv )
444+     for  (const  auto  & tk : params.type_k )
445+     for  (const  auto  & tv : params.type_v )
384446    for  (const  auto  & mmq : params.mul_mat_q )
385447    for  (const  auto  & nt : params.n_threads ) {
386448        cmd_params_instance instance = {
387449            /*  .model        = */   m,
388450            /*  .n_prompt     = */   n_prompt,
389451            /*  .n_gen        = */   n_gen,
390452            /*  .n_batch      = */   nb,
391-             /*  .f32_kv       = */   fk,
453+             /*  .type_k       = */   tk,
454+             /*  .type_v       = */   tv,
392455            /*  .n_threads    = */   nt,
393456            /*  .n_gpu_layers = */   nl,
394457            /*  .main_gpu     = */   mg,
@@ -410,7 +473,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
410473    for  (const  auto  & mg : params.main_gpu )
411474    for  (const  auto  & ts : params.tensor_split )
412475    for  (const  auto  & nb : params.n_batch )
413-     for  (const  auto  & fk : params.f32_kv )
476+     for  (const  auto  & tk : params.type_k )
477+     for  (const  auto  & tv : params.type_v )
414478    for  (const  auto  & mmq : params.mul_mat_q )
415479    for  (const  auto  & nt : params.n_threads ) {
416480        for  (const  auto  & n_prompt : params.n_prompt ) {
@@ -422,7 +486,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
422486                /*  .n_prompt     = */   n_prompt,
423487                /*  .n_gen        = */   0 ,
424488                /*  .n_batch      = */   nb,
425-                 /*  .f32_kv       = */   fk,
489+                 /*  .type_k       = */   tk,
490+                 /*  .type_v       = */   tv,
426491                /*  .n_threads    = */   nt,
427492                /*  .n_gpu_layers = */   nl,
428493                /*  .main_gpu     = */   mg,
@@ -441,7 +506,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
441506                /*  .n_prompt     = */   0 ,
442507                /*  .n_gen        = */   n_gen,
443508                /*  .n_batch      = */   nb,
444-                 /*  .f32_kv       = */   fk,
509+                 /*  .type_k       = */   tk,
510+                 /*  .type_v       = */   tv,
445511                /*  .n_threads    = */   nt,
446512                /*  .n_gpu_layers = */   nl,
447513                /*  .main_gpu     = */   mg,
@@ -489,7 +555,8 @@ struct test {
489555    uint64_t  model_n_params;
490556    int  n_batch;
491557    int  n_threads;
492-     bool  f32_kv;
558+     ggml_type type_k;
559+     ggml_type type_v;
493560    int  n_gpu_layers;
494561    int  main_gpu;
495562    bool  mul_mat_q;
@@ -508,7 +575,8 @@ struct test {
508575        model_n_params = llama_model_n_params (lmodel);
509576        n_batch = inst.n_batch ;
510577        n_threads = inst.n_threads ;
511-         f32_kv = inst.f32_kv ;
578+         type_k = inst.type_k ;
579+         type_v = inst.type_v ;
512580        n_gpu_layers = inst.n_gpu_layers ;
513581        main_gpu = inst.main_gpu ;
514582        mul_mat_q = inst.mul_mat_q ;
@@ -571,7 +639,7 @@ struct test {
571639            " cuda"  , " opencl"  , " metal"  , " gpu_blas"  , " blas"  ,
572640            " cpu_info"  , " gpu_info"  ,
573641            " model_filename"  , " model_type"  , " model_size"  , " model_n_params"  ,
574-             " n_batch"  , " n_threads"  , " f16_kv "  ,
642+             " n_batch"  , " n_threads"  , " type_k " ,  " type_v "  ,
575643            " n_gpu_layers"  , " main_gpu"  , " mul_mat_q"  , " tensor_split"  ,
576644            " n_prompt"  , " n_gen"  , " test_time"  ,
577645            " avg_ns"  , " stddev_ns"  ,
@@ -621,7 +689,7 @@ struct test {
621689            std::to_string (cuda), std::to_string (opencl), std::to_string (metal), std::to_string (gpu_blas), std::to_string (blas),
622690            cpu_info, gpu_info,
623691            model_filename, model_type, std::to_string (model_size), std::to_string (model_n_params),
624-             std::to_string (n_batch), std::to_string (n_threads), std::to_string (!f32_kv ),
692+             std::to_string (n_batch), std::to_string (n_threads), ggml_type_name (type_k),  ggml_type_name (type_v ),
625693            std::to_string (n_gpu_layers), std::to_string (main_gpu), std::to_string (mul_mat_q), tensor_split_str,
626694            std::to_string (n_prompt), std::to_string (n_gen), test_time,
627695            std::to_string (avg_ns ()), std::to_string (stdev_ns ()),
@@ -805,8 +873,11 @@ struct markdown_printer : public printer {
805873        if  (params.n_batch .size () > 1  || params.n_batch  != cmd_params_defaults.n_batch ) {
806874            fields.push_back (" n_batch"  );
807875        }
808-         if  (params.f32_kv .size () > 1  || params.f32_kv  != cmd_params_defaults.f32_kv ) {
809-             fields.push_back (" f16_kv"  );
876+         if  (params.type_k .size () > 1  || params.type_k  != cmd_params_defaults.type_k ) {
877+             fields.push_back (" type_k"  );
878+         }
879+         if  (params.type_v .size () > 1  || params.type_v  != cmd_params_defaults.type_v ) {
880+             fields.push_back (" type_v"  );
810881        }
811882        if  (params.main_gpu .size () > 1  || params.main_gpu  != cmd_params_defaults.main_gpu ) {
812883            fields.push_back (" main_gpu"  );
0 commit comments