Commit 296ba81

[refactor][mbuffer] File load from variant
- Add code to load a GGUF file from a variant backing store (memory or disk).
- Add structs that simplify loading a file and keep track of the pointers, which now live in the same struct.
1 parent 427f1b8

6 files changed (+116, -23 lines)
src/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,8 @@ add_library(llama
     llama-memory-hybrid.cpp
     llama-memory-recurrent.cpp
     llama-mmap.cpp
+    llama-model-load-input.cpp
+    llama-model-load.cpp
     llama-model-loader.cpp
     llama-model-saver.cpp
     llama-model.cpp
src/llama-model-load-input.cpp

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+#include "llama-model-load-input.h"
+
+namespace load_input_variant {
+
+const char * identifier(load_input_t & load_input) {
+    if (std::holds_alternative<fname_load_input>(load_input)) {
+        const auto & file_input = std::get<fname_load_input>(load_input);
+        return file_input.fname.c_str();
+    }
+    static const char * buffer_id_str = "buffer";
+    return buffer_id_str;
+}
+
+fname_load_input split_name_from_variant(load_input_t & load_input) {
+    auto file_input = std::get<fname_load_input>(load_input);
+    return file_input;
+}
+
+bool variant_supports_split_load(load_input_t & load_input) {
+    return std::holds_alternative<fname_load_input>(load_input);
+}
+
+} // namespace load_input_variant
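
The helpers above are thin wrappers over std::variant queries: identifier() returns the file name for disk-backed inputs and a fixed "buffer" tag otherwise. A minimal self-contained sketch of the same dispatch pattern, using hypothetical value-holding structs in place of the commit's reference-holding ones:

#include <cstdio>
#include <string>
#include <variant>

// Hypothetical stand-ins for fname_load_input / buffer_load_input.
struct from_file   { std::string fname; };
struct from_buffer { };

using input_t = std::variant<from_file, from_buffer>;

// Mirrors identifier(): file name for disk inputs, a fixed tag otherwise.
static const char * identifier(const input_t & in) {
    if (std::holds_alternative<from_file>(in)) {
        return std::get<from_file>(in).fname.c_str();
    }
    return "buffer";
}

int main() {
    input_t disk = from_file{"model.gguf"};
    input_t mem  = from_buffer{};
    std::printf("%s\n%s\n", identifier(disk), identifier(mem)); // model.gguf, then buffer
}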

src/llama-model-load-input.h

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+#pragma once
+
+#include <stdint.h>
+
+#include <memory>
+#include <string>
+#include <variant>
+#include <vector>
+
+namespace load_input_variant {
+
+struct fname_load_input {
+    const std::string & fname;
+    std::vector<std::string> & splits; // optional, only needed if the splits do not follow the naming scheme
+};
+
+struct buffer_load_input {
+    std::unique_ptr<std::basic_streambuf<uint8_t>> & streambuf;
+};
+
+} // namespace load_input_variant
+
+using load_input_t = std::variant<load_input_variant::fname_load_input, load_input_variant::buffer_load_input>;
+
+namespace load_input_variant {
+const char * identifier(load_input_t & load_input);
+
+fname_load_input split_name_from_variant(load_input_t & load_input);
+
+bool variant_supports_split_load(load_input_t & load_input);
+} // namespace load_input_variant
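
Both structs hold references rather than owned values, so the caller's string, split list, and stream buffer must outlive the load_input_t built from them. A hedged sketch of constructing each alternative against the header above (the locals are assumptions for illustration, not part of the commit):

#include <memory>
#include <streambuf>
#include <string>
#include <vector>

#include "llama-model-load-input.h"

void make_inputs() {
    std::string fname = "model.gguf";  // hypothetical path, owned by the caller
    std::vector<std::string> splits;   // empty: splits follow the default naming scheme
    load_input_t disk = load_input_variant::fname_load_input{fname, splits};

    // Filled elsewhere with GGUF bytes; the struct keeps a reference to the
    // unique_ptr so that a later consumer can take ownership of it.
    std::unique_ptr<std::basic_streambuf<uint8_t>> buf;
    load_input_t mem = load_input_variant::buffer_load_input{buf};

    (void)disk;
    (void)mem;
}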

src/llama-model-load.cpp

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+#include "llama-model-load.h"
+
+#include <memory>
+#include <stdexcept>
+#include <variant>
+
+#include "llama-model-loader.h"
+
+gguf_file_load::gguf_file_load(struct ggml_context ** ctx, load_input_t load_input) :
+    params({
+        /*.no_alloc = */ true,
+        /*.ctx = */ ctx,
+    }) {
+    using namespace load_input_variant;
+    if (std::holds_alternative<fname_load_input>(load_input)) {
+        const auto & file_input = std::get<fname_load_input>(load_input);
+        meta.reset(gguf_init_from_file(file_input.fname.c_str(), params));
+        if (!meta) {
+            throw std::runtime_error(format("%s: failed to load model from %s", __func__, file_input.fname.c_str()));
+        }
+        file = std::make_unique<llama_file_disk>(file_input.fname.c_str(), "ro");
+    } else {
+        const auto & buffer_input = std::get<buffer_load_input>(load_input);
+        meta.reset(gguf_init_from_buffer(*buffer_input.streambuf, params));
+        if (!meta) {
+            throw std::runtime_error(format("%s: failed to load model from buffer", __func__));
+        }
+        file = std::make_unique<llama_file_buffer_ro>(std::move(buffer_input.streambuf));
+    }
+}
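
The constructor does all the work up front: for either alternative it parses the GGUF metadata into meta and wraps the underlying bytes in a llama_file (disk-backed or buffer-backed). A hedged sketch of the in-memory path, assuming a hypothetical make_gguf_streambuf() helper that yields a streambuf over GGUF bytes:

#include <memory>
#include <streambuf>

#include "llama-model-load.h"

// Hypothetical helper (not in this commit): yields a streambuf over GGUF bytes.
std::unique_ptr<std::basic_streambuf<uint8_t>> make_gguf_streambuf();

void load_from_memory() {
    std::unique_ptr<std::basic_streambuf<uint8_t>> buf = make_gguf_streambuf();

    struct ggml_context * ctx = NULL;
    gguf_file_load mem_gguf(&ctx, load_input_variant::buffer_load_input{buf});
    // On success: mem_gguf.meta holds the parsed metadata and mem_gguf.file is
    // a llama_file_buffer_ro that took ownership of buf (buf is now null).
    // On failure the constructor throws, so no partially-initialized state escapes.
}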

src/llama-model-load.h

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+#pragma once
+
+#include <cstdint>
+#include <set>
+
+#include "ggml-cpp.h"
+#include "llama-mmap.h"
+#include "llama-model-load-input.h"
+
+struct llama_model_loader;
+
+/// @brief Immediately loads and stores relevant data in the struct fields.
+struct gguf_file_load {
+    struct gguf_init_params params;
+    gguf_context_ptr meta;
+    std::unique_ptr<llama_file> file = nullptr;
+
+    gguf_file_load(struct ggml_context ** ctx, load_input_t load_input);
+};

src/llama-model-loader.cpp

Lines changed: 11 additions & 23 deletions

@@ -1,6 +1,8 @@
 #include "llama-model-loader.h"

 #include "ggml.h"
+#include "llama-model-load-input.h"
+#include "llama-model-load.h"

 #include <array>
 #include <cinttypes>
@@ -485,22 +487,14 @@ llama_model_loader::llama_model_loader(

     tensor_buft_overrides = param_tensor_buft_overrides_p;

-    // Load the main GGUF
     struct ggml_context * ctx = NULL;
-    struct gguf_init_params params = {
-        /*.no_alloc = */ true,
-        /*.ctx = */ &ctx,
-    };
-
-    meta.reset(gguf_init_from_file(fname.c_str(), params));
-    if (!meta) {
-        throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str()));
-    }
+    gguf_file_load main_gguf(&ctx, load_input_variant::fname_load_input{fname, splits});
+    meta = std::move(main_gguf.meta);

     get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
     llm_kv = LLM_KV(llm_arch_from_string(arch_name));

-    files.emplace_back(new llama_file_disk(fname.c_str(), "rb"));
+    files.emplace_back(std::move(main_gguf.file));
     contexts.emplace_back(ctx);

     // Save tensors data offset of the main file.
@@ -547,28 +541,22 @@ llama_model_loader::llama_model_loader(
     for (idx = 1; idx < n_split; idx++) {
         const char * fname_split = splits[idx].c_str();

-        struct gguf_init_params split_params = {
-            /*.no_alloc = */ true,
-            /*.ctx = */ &ctx,
-        };
-        gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
-        if (!ctx_gguf) {
-            throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split));
-        }
+        gguf_file_load split_gguf(&ctx, load_input_variant::fname_load_input{fname_split, splits});
+        gguf_context_ptr & split_meta = split_gguf.meta;

         // check idx
         {
-            const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
+            const int kid = gguf_find_key(split_meta.get(), kv_split_no.c_str());
             if (kid < 0) {
                 throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
             }
-            int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
+            int idx_gguf = gguf_get_val_u16(split_meta.get(), kid);
             if (idx_gguf != idx) {
                 throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
             }
         }

-        files.emplace_back(new llama_file_disk(fname_split, "rb"));
+        files.emplace_back(std::move(split_gguf.file));
         contexts.emplace_back(ctx);

         // Save tensors data offset info of the shard.
@@ -580,7 +568,7 @@ llama_model_loader::llama_model_loader(
            }
            n_elements += ggml_nelements(cur);
            n_bytes += ggml_nbytes(cur);
-            weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur));
+            weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, split_meta.get(), cur));
        }
    }