@@ -749,6 +749,39 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
749
749
// utils
750
750
//
751
751
752
+ // Helper function to parse tensor buffer override strings: splits `value` into comma-separated <tensor name pattern>=<buffer type> entries and appends one entry per item to `overrides`. Throws std::invalid_argument if an item has no equals sign or names a buffer type that none of the enumerated devices reports (the known names are printed first).
753
+ static void parse_tensor_buffer_overrides (const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
754
+ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
755
+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
756
+ auto * dev = ggml_backend_dev_get (i);
757
+ auto * buft = ggml_backend_dev_buffer_type (dev);
758
+ if (buft) {
759
+ buft_list[ggml_backend_buft_name (buft)] = buft;
760
+ }
761
+ }
762
+
763
+ for (const auto & override : string_split<std::string>(value, ' ,' )) {
764
+ std::string::size_type pos = override .find (' =' );
765
+ if (pos == std::string::npos) {
766
+ throw std::invalid_argument (" invalid value" );
767
+ }
768
+ std::string tensor_name = override .substr (0 , pos);
769
+ std::string buffer_type = override .substr (pos + 1 );
770
+
771
+ if (buft_list.find (buffer_type) == buft_list.end ()) {
772
+ printf (" Available buffer types:\n " );
773
+ for (const auto & it : buft_list) {
774
+ printf (" %s\n " , ggml_backend_buft_name (it.second ));
775
+ }
776
+ throw std::invalid_argument (" unknown buffer type" );
777
+ }
778
+ // keep the pattern strings alive by storing them in a function-local static std::list: the override entry below holds only the c_str() pointer, and std::list does not relocate existing elements on push_back, so earlier pointers stay valid
779
+ static std::list<std::string> buft_overrides;
780
+ buft_overrides.push_back (tensor_name);
781
+ overrides.push_back ({buft_overrides.back ().c_str (), buft_list.at (buffer_type)});
782
+ }
783
+ }
784
+
752
785
struct handle_model_result {
753
786
bool found_mmproj = false ;
754
787
common_params_model mmproj;
@@ -993,6 +1026,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
993
1026
params.tensor_buft_overrides .push_back ({nullptr , nullptr });
994
1027
}
995
1028
1029
+ if (!params.speculative .tensor_buft_overrides .empty ()) {
1030
+ params.speculative .tensor_buft_overrides .push_back ({nullptr , nullptr });
1031
+ }
1032
+
996
1033
if (!params.chat_template .empty () && !common_chat_verify_template (params.chat_template , params.use_jinja )) {
997
1034
throw std::runtime_error (string_format (
998
1035
" error: the supplied chat template is not supported: %s%s\n " ,
@@ -2349,40 +2386,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2349
2386
add_opt (common_arg (
2350
2387
{" --override-tensor" , " -ot" }, " <tensor name pattern>=<buffer type>,..." ,
2351
2388
" override tensor buffer type" , [](common_params & params, const std::string & value) {
2352
- /* static */ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
2353
- if (buft_list.empty ()) {
2354
- // enumerate all the devices and add their buffer types to the list
2355
- for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
2356
- auto * dev = ggml_backend_dev_get (i);
2357
- auto * buft = ggml_backend_dev_buffer_type (dev);
2358
- if (buft) {
2359
- buft_list[ggml_backend_buft_name (buft)] = buft;
2360
- }
2361
- }
2362
- }
2363
-
2364
- for (const auto & override : string_split<std::string>(value, ' ,' )) {
2365
- std::string::size_type pos = override .find (' =' );
2366
- if (pos == std::string::npos) {
2367
- throw std::invalid_argument (" invalid value" );
2368
- }
2369
- std::string tensor_name = override .substr (0 , pos);
2370
- std::string buffer_type = override .substr (pos + 1 );
2371
-
2372
- if (buft_list.find (buffer_type) == buft_list.end ()) {
2373
- printf (" Available buffer types:\n " );
2374
- for (const auto & it : buft_list) {
2375
- printf (" %s\n " , ggml_backend_buft_name (it.second ));
2376
- }
2377
- throw std::invalid_argument (" unknown buffer type" );
2378
- }
2379
- // keep strings alive and avoid leaking memory by storing them in a static vector
2380
- static std::list<std::string> buft_overrides;
2381
- buft_overrides.push_back (tensor_name);
2382
- params.tensor_buft_overrides .push_back ({buft_overrides.back ().c_str (), buft_list.at (buffer_type)});
2383
- }
2389
+ parse_tensor_buffer_overrides (value, params.tensor_buft_overrides );
2384
2390
}
2385
2391
));
2392
+ add_opt (common_arg (
2393
+ {" --override-tensor-draft" , " -otd" }, " <tensor name pattern>=<buffer type>,..." ,
2394
+ " override tensor buffer type for draft model" , [](common_params & params, const std::string & value) {
2395
+ parse_tensor_buffer_overrides (value, params.speculative .tensor_buft_overrides );
2396
+ }
2397
+ ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
2386
2398
add_opt (common_arg (
2387
2399
{" --cpu-moe" , " -cmoe" },
2388
2400
" keep all Mixture of Experts (MoE) weights in the CPU" ,
@@ -2405,6 +2417,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2405
2417
}
2406
2418
}
2407
2419
).set_env (" LLAMA_ARG_N_CPU_MOE" ));
2420
+ add_opt (common_arg (
2421
+ {" --cpu-moe-draft" , " -cmoed" },
2422
+ " keep all Mixture of Experts (MoE) weights in the CPU for the draft model" ,
2423
+ [](common_params & params) {
2424
+ params.speculative .tensor_buft_overrides .push_back ({" \\ .ffn_(up|down|gate)_exps" , ggml_backend_cpu_buffer_type ()});
2425
+ }
2426
+ ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env (" LLAMA_ARG_CPU_MOE_DRAFT" ));
2427
+ add_opt (common_arg (
2428
+ {" --n-cpu-moe-draft" , " -ncmoed" }, " N" ,
2429
+ " keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model" ,
2430
+ [](common_params & params, int value) {
2431
+ if (value < 0 ) {
2432
+ throw std::invalid_argument (" invalid value" );
2433
+ }
2434
+ for (int i = 0 ; i < value; ++i) {
2435
+ static std::list<std::string> buft_overrides_draft;
2436
+ buft_overrides_draft.push_back (string_format (" blk\\ .%d\\ .ffn_(up|down|gate)_exps" , i));
2437
+ params.speculative .tensor_buft_overrides .push_back ({buft_overrides_draft.back ().c_str (), ggml_backend_cpu_buffer_type ()});
2438
+ }
2439
+ }
2440
+ ).set_examples ({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env (" LLAMA_ARG_N_CPU_MOE_DRAFT" ));
2408
2441
add_opt (common_arg (
2409
2442
{" -ngl" , " --gpu-layers" , " --n-gpu-layers" }, " N" ,
2410
2443
" number of layers to store in VRAM" ,
0 commit comments