@@ -161,10 +161,17 @@ static const char * split_mode_str(llama_split_mode mode) {
161161 }
162162}
163163
164+ static std::string pair_str (const std::pair<int , int > & p) {
165+ static char buf[32 ];
166+ snprintf (buf, sizeof (buf), " %d,%d" , p.first , p.second );
167+ return buf;
168+ }
169+
164170struct cmd_params {
165171 std::vector<std::string> model;
166172 std::vector<int > n_prompt;
167173 std::vector<int > n_gen;
174+ std::vector<std::pair<int , int >> n_pg;
168175 std::vector<int > n_batch;
169176 std::vector<int > n_ubatch;
170177 std::vector<ggml_type> type_k;
@@ -188,6 +195,7 @@ static const cmd_params cmd_params_defaults = {
188195 /* model */ {" models/7B/ggml-model-q4_0.gguf" },
189196 /* n_prompt */ {512 },
190197 /* n_gen */ {128 },
198+ /* n_pg */ {{512 , 128 }},
191199 /* n_batch */ {2048 },
192200 /* n_ubatch */ {512 },
193201 /* type_k */ {GGML_TYPE_F16},
@@ -215,10 +223,11 @@ static void print_usage(int /* argc */, char ** argv) {
215223 printf (" -m, --model <filename> (default: %s)\n " , join (cmd_params_defaults.model , " ," ).c_str ());
216224 printf (" -p, --n-prompt <n> (default: %s)\n " , join (cmd_params_defaults.n_prompt , " ," ).c_str ());
217225 printf (" -n, --n-gen <n> (default: %s)\n " , join (cmd_params_defaults.n_gen , " ," ).c_str ());
226+ printf (" -pg <pp,tg> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.n_pg , pair_str), " ," ).c_str ());
218227 printf (" -b, --batch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_batch , " ," ).c_str ());
219- printf (" -ub N , --ubatch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_ubatch , " ," ).c_str ());
220- printf (" -ctk <t> , --cache-type-k <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_k , ggml_type_name), " ," ).c_str ());
221- printf (" -ctv <t> , --cache-type-v <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_v , ggml_type_name), " ," ).c_str ());
228+ printf (" -ub, --ubatch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_ubatch , " ," ).c_str ());
229+ printf (" -ctk, --cache-type-k <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_k , ggml_type_name), " ," ).c_str ());
230+ printf (" -ctv, --cache-type-v <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_v , ggml_type_name), " ," ).c_str ());
222231 printf (" -t, --threads <n> (default: %s)\n " , join (cmd_params_defaults.n_threads , " ," ).c_str ());
223232 printf (" -ngl, --n-gpu-layers <n> (default: %s)\n " , join (cmd_params_defaults.n_gpu_layers , " ," ).c_str ());
224233 printf (" -sm, --split-mode <none|layer|row> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.split_mode , split_mode_str), " ," ).c_str ());
@@ -304,6 +313,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
304313 }
305314 auto p = split<int >(argv[i], split_delim);
306315 params.n_gen .insert (params.n_gen .end (), p.begin (), p.end ());
316+ } else if (arg == " -pg" ) {
317+ if (++i >= argc) {
318+ invalid_param = true ;
319+ break ;
320+ }
321+ auto p = split<std::string>(argv[i], ' ,' );
322+ if (p.size () != 2 ) {
323+ invalid_param = true ;
324+ break ;
325+ }
326+ params.n_pg .push_back ({std::stoi (p[0 ]), std::stoi (p[1 ])});
307327 } else if (arg == " -b" || arg == " --batch-size" ) {
308328 if (++i >= argc) {
309329 invalid_param = true ;
@@ -493,6 +513,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
493513 if (params.model .empty ()) { params.model = cmd_params_defaults.model ; }
494514 if (params.n_prompt .empty ()) { params.n_prompt = cmd_params_defaults.n_prompt ; }
495515 if (params.n_gen .empty ()) { params.n_gen = cmd_params_defaults.n_gen ; }
516+ if (params.n_pg .empty ()) { params.n_pg = cmd_params_defaults.n_pg ; }
496517 if (params.n_batch .empty ()) { params.n_batch = cmd_params_defaults.n_batch ; }
497518 if (params.n_ubatch .empty ()) { params.n_ubatch = cmd_params_defaults.n_ubatch ; }
498519 if (params.type_k .empty ()) { params.type_k = cmd_params_defaults.type_k ; }
@@ -632,6 +653,31 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
632653 };
633654 instances.push_back (instance);
634655 }
656+
657+ for (const auto & n_pg : params.n_pg ) {
658+ if (n_pg.first == 0 && n_pg.second == 0 ) {
659+ continue ;
660+ }
661+ cmd_params_instance instance = {
662+ /* .model = */ m,
663+ /* .n_prompt = */ n_pg.first ,
664+ /* .n_gen = */ n_pg.second ,
665+ /* .n_batch = */ nb,
666+ /* .n_ubatch = */ nub,
667+ /* .type_k = */ tk,
668+ /* .type_v = */ tv,
669+ /* .n_threads = */ nt,
670+ /* .n_gpu_layers = */ nl,
671+ /* .split_mode = */ sm,
672+ /* .main_gpu = */ mg,
673+ /* .no_kv_offload= */ nkvo,
674+ /* .flash_attn = */ fa,
675+ /* .tensor_split = */ ts,
676+ /* .use_mmap = */ mmp,
677+ /* .embeddings = */ embd,
678+ };
679+ instances.push_back (instance);
680+ }
635681 }
636682
637683 return instances;
@@ -965,6 +1011,9 @@ struct markdown_printer : public printer {
9651011 if (field == " n_gpu_layers" ) {
9661012 return 3 ;
9671013 }
1014+ if (field == " test" ) {
1015+ return 13 ;
1016+ }
9681017
9691018 int width = std::max ((int )field.length (), 10 );
9701019
@@ -1091,12 +1140,11 @@ struct markdown_printer : public printer {
10911140 value = test::get_backend ();
10921141 } else if (field == " test" ) {
10931142 if (t.n_prompt > 0 && t.n_gen == 0 ) {
1094- snprintf (buf, sizeof (buf), " pp %d" , t.n_prompt );
1143+ snprintf (buf, sizeof (buf), " pp%d" , t.n_prompt );
10951144 } else if (t.n_gen > 0 && t.n_prompt == 0 ) {
1096- snprintf (buf, sizeof (buf), " tg %d" , t.n_gen );
1145+ snprintf (buf, sizeof (buf), " tg%d" , t.n_gen );
10971146 } else {
1098- assert (false );
1099- exit (1 );
1147+ snprintf (buf, sizeof (buf), " pp%d+tg%d" , t.n_prompt , t.n_gen );
11001148 }
11011149 value = buf;
11021150 } else if (field == " t/s" ) {
@@ -1297,6 +1345,7 @@ int main(int argc, char ** argv) {
12971345 llama_kv_cache_clear (ctx);
12981346
12991347 uint64_t t_start = get_time_ns ();
1348+
13001349 if (t.n_prompt > 0 ) {
13011350 test_prompt (ctx, t.n_prompt , 0 , t.n_batch , t.n_threads );
13021351 }
0 commit comments