|
1 | 1 | #include "arg.h" |
2 | 2 |
|
| 3 | +#include "common.h" |
3 | 4 | #include "log.h" |
4 | 5 | #include "sampling.h" |
5 | 6 |
|
@@ -321,6 +322,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context |
321 | 322 | params.kv_overrides.back().key[0] = 0; |
322 | 323 | } |
323 | 324 |
|
| 325 | + if (!params.tensor_buft_overrides.empty()) { |
| 326 | + params.tensor_buft_overrides.push_back({nullptr, nullptr}); |
| 327 | + } |
| 328 | + |
324 | 329 | if (params.reranking && params.embedding) { |
325 | 330 | throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); |
326 | 331 | } |
@@ -1477,6 +1482,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex |
1477 | 1482 | exit(0); |
1478 | 1483 | } |
1479 | 1484 | )); |
| 1485 | + add_opt(common_arg( |
| 1486 | + {"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...", |
| 1487 | + "override tensor buffer type", [](common_params & params, const std::string & value) { |
| 1488 | + static std::map<std::string, ggml_backend_buffer_type_t> buft_list; |
| 1489 | + if (buft_list.empty()) { |
| 1490 | + // enumerate all the devices and add their buffer types to the list |
| 1491 | + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { |
| 1492 | + auto * dev = ggml_backend_dev_get(i); |
| 1493 | + auto * buft = ggml_backend_dev_buffer_type(dev); |
| 1494 | + buft_list[ggml_backend_buft_name(buft)] = buft; |
| 1495 | + } |
| 1496 | + } |
| 1497 | + |
| 1498 | + for (const auto & override : string_split<std::string>(value, ',')) { |
| 1499 | + std::string::size_type pos = override.find('='); |
| 1500 | + if (pos == std::string::npos) { |
| 1501 | + throw std::invalid_argument("invalid value"); |
| 1502 | + } |
| 1503 | + std::string tensor_name = override.substr(0, pos); |
| 1504 | + std::string buffer_type = override.substr(pos + 1); |
| 1505 | + |
| 1506 | + if (buft_list.find(buffer_type) == buft_list.end()) { |
| 1507 | + printf("Available buffer types:\n"); |
| 1508 | + for (const auto & it : buft_list) { |
| 1509 | + printf(" %s\n", ggml_backend_buft_name(it.second)); |
| 1510 | + } |
| 1511 | + throw std::invalid_argument("unknown buffer type"); |
| 1512 | + } |
| 1513 | + // FIXME: this leaks memory |
| 1514 | + params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); |
| 1515 | + } |
| 1516 | + } |
| 1517 | + )); |
1480 | 1518 | add_opt(common_arg( |
1481 | 1519 | {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", |
1482 | 1520 | "number of layers to store in VRAM", |
|
0 commit comments