
Commit f07c2ec

llama : add option to override tensor buffers

1 parent 9fbadae

File tree

9 files changed: +87 -8 lines changed

common/arg.cpp

Lines changed: 38 additions & 0 deletions

@@ -1,5 +1,6 @@
 #include "arg.h"

+#include "common.h"
 #include "log.h"
 #include "sampling.h"

@@ -321,6 +322,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         params.kv_overrides.back().key[0] = 0;
     }

+    if (!params.tensor_buft_overrides.empty()) {
+        params.tensor_buft_overrides.push_back({nullptr, nullptr});
+    }
+
     if (params.reranking && params.embedding) {
         throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
     }
@@ -1477,6 +1482,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             exit(0);
         }
     ));
+    add_opt(common_arg(
+        {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
+        "override tensor buffer type", [](common_params & params, const std::string & value) {
+            static std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+            if (buft_list.empty()) {
+                // enumerate all the devices and add their buffer types to the list
+                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                    auto * dev = ggml_backend_dev_get(i);
+                    auto * buft = ggml_backend_dev_buffer_type(dev);
+                    buft_list[ggml_backend_buft_name(buft)] = buft;
+                }
+            }
+
+            for (const auto & override : string_split<std::string>(value, ',')) {
+                std::string::size_type pos = override.find('=');
+                if (pos == std::string::npos) {
+                    throw std::invalid_argument("invalid value");
+                }
+                std::string tensor_name = override.substr(0, pos);
+                std::string buffer_type = override.substr(pos + 1);
+
+                if (buft_list.find(buffer_type) == buft_list.end()) {
+                    printf("Available buffer types:\n");
+                    for (const auto & it : buft_list) {
+                        printf("  %s\n", ggml_backend_buft_name(it.second));
+                    }
+                    throw std::invalid_argument("unknown buffer type");
+                }
+                // FIXME: this leaks memory
+                params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
+            }
+        }
+    ));
     add_opt(common_arg(
         {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
         "number of layers to store in VRAM",

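With this option in place, overrides are given as a comma-separated list of <tensor name pattern>=<buffer type> pairs, and a buffer type that was not enumerated from the available devices aborts parsing after printing the valid names. A hypothetical invocation might look like this (the binary name, model path, and the buffer type name CPU are assumptions here, since the names that actually exist depend on which backends the build enumerates):

    llama-cli -m model.gguf --override-tensor "ffn_down=CPU,ffn_up=CPU"
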
common/common.cpp

Lines changed: 10 additions & 0 deletions

@@ -1083,22 +1083,32 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     if (!params.devices.empty()) {
         mparams.devices = params.devices.data();
     }
+
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
+
     mparams.main_gpu      = params.main_gpu;
     mparams.split_mode    = params.split_mode;
     mparams.tensor_split  = params.tensor_split;
     mparams.use_mmap      = params.use_mmap;
     mparams.use_mlock     = params.use_mlock;
     mparams.check_tensors = params.check_tensors;
+
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
         GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
         mparams.kv_overrides = params.kv_overrides.data();
     }

+    if (params.tensor_buft_overrides.empty()) {
+        mparams.tensor_buft_overrides = NULL;
+    } else {
+        GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
+        mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
+    }
+
     return mparams;
 }

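
Because common_model_params_to_llama only passes a raw pointer through, code that fills the vector directly (rather than via --override-tensor) has to append the {nullptr, nullptr} terminator itself, exactly as common_params_parse_ex does. A minimal sketch, not from this commit; the pattern string and the use of ggml_backend_cpu_buffer_type() are illustrative:

    #include "common.h"       // common_params, common_model_params_to_llama
    #include "ggml-backend.h" // ggml_backend_cpu_buffer_type

    int main() {
        common_params params;

        // one override entry, then the {nullptr, nullptr} terminator that the
        // GGML_ASSERT above checks for
        params.tensor_buft_overrides.push_back({ "ffn_down", ggml_backend_cpu_buffer_type() });
        params.tensor_buft_overrides.push_back({ nullptr, nullptr });

        llama_model_params mparams = common_model_params_to_llama(params);
        return mparams.tensor_buft_overrides == nullptr; // 0 when the overrides were passed through
    }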

common/common.h

Lines changed: 1 addition & 0 deletions

@@ -256,6 +256,7 @@ struct common_params {
     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
+    std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

     bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
     std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale

include/llama.h

Lines changed: 8 additions & 0 deletions

@@ -275,10 +275,18 @@ extern "C" {
         };
     };

+    struct llama_model_tensor_buft_override {
+        const char * pattern;
+        ggml_backend_buffer_type_t buft;
+    };
+
     struct llama_model_params {
         // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
         ggml_backend_dev_t * devices;

+        // NULL-terminated list of buffer types to use for tensors that match a pattern
+        const struct llama_model_tensor_buft_override * tensor_buft_overrides;
+
         int32_t n_gpu_layers; // number of layers to store in VRAM
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs

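
A consumer-side sketch of the new public API (not part of this commit): build a NULL-terminated override array and hand it to the loader through llama_model_params. The model path is a placeholder and the choice of ggml_backend_cpu_buffer_type() is an assumption for illustration:

    #include "llama.h"
    #include "ggml-backend.h"

    int main(void) {
        // pin every tensor whose name contains "ffn" to the CPU buffer type;
        // the list is NULL-terminated, as the comment in the header requires
        const struct llama_model_tensor_buft_override overrides[] = {
            { "ffn", ggml_backend_cpu_buffer_type() },
            { NULL, NULL },
        };

        struct llama_model_params mparams = llama_model_default_params();
        mparams.tensor_buft_overrides = overrides;

        struct llama_model * model = llama_model_load_from_file("model.gguf", mparams);
        if (model == NULL) {
            return 1;
        }
        llama_model_free(model);
        return 0;
    }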

src/llama-model-loader.cpp

Lines changed: 4 additions & 1 deletion

@@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
         std::vector<std::string> & splits,
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p) {
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
@@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
         }
     }

+    tensor_buft_overrides = param_tensor_buft_overrides_p;
+
     // Load the main GGUF
     struct ggml_context * ctx = NULL;
     struct gguf_init_params params = {

src/llama-model-loader.h

Lines changed: 5 additions & 3 deletions

@@ -77,8 +77,9 @@ struct llama_model_loader {

     llama_mmaps mappings;

-    std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
-    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
+    std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
+    std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
+    const llama_model_tensor_buft_override * tensor_buft_overrides;

     gguf_context_ptr meta;
     std::vector<ggml_context_ptr> contexts;
@@ -95,7 +96,8 @@ struct llama_model_loader {
         std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p);
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);

     template<typename T>
     typename std::enable_if<std::is_integral<T>::value, bool>::type

src/llama-model.cpp

Lines changed: 19 additions & 2 deletions

@@ -1444,9 +1444,25 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
             }

-            ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+            ggml_backend_buffer_type_t buft = nullptr;
+
+            // check overrides
+            if (ml.tensor_buft_overrides) {
+                std::string tensor_name = tn.str();
+                for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
+                    if (tensor_name.find(overrides->pattern) != std::string::npos) {
+                        LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
+                        buft = overrides->buft;
+                        break;
+                    }
+                }
+            }
+
             if (!buft) {
-                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+                buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+                if (!buft) {
+                    throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
+                }
             }

             // avoid using a host buffer when using mmap
@@ -3757,6 +3773,7 @@ const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
 struct llama_model_params llama_model_default_params() {
     struct llama_model_params result = {
         /*.devices                =*/ nullptr,
+        /*.tensor_buft_overrides  =*/ nullptr,
         /*.n_gpu_layers           =*/ 0,
         /*.split_mode             =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu               =*/ 0,

src/llama-quant.cpp

Lines changed: 1 addition & 1 deletion

@@ -527,7 +527,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     }

     std::vector<std::string> splits = {};
-    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
     ml.init_mappings(false); // no prefetching

     llama_model model(llama_model_default_params());

src/llama.cpp

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
     model.t_start_us = tm.t_start_us;

     try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);

         ml.print_info();

