
Commit da6e20c

[refactor] Load all data
- `load_all_data` now takes `size_data` as a parameter instead of reading the member attribute.
- Added sanity checks on the file pointer handles.

These two changes will be useful when calling `load_all_data` multiple times during incremental shard loading.
Parent: 67a5476
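For context, a minimal sketch of the multi-call pattern this refactor is preparing for (hypothetical caller: `n_splits`, `shard_size()`, and the surrounding setup are illustrative, not part of this commit):

    // Incremental shard loading: each pass hands load_all_data() the byte
    // count it is expected to upload, instead of a loader-wide member.
    for (uint16_t idx = 0; idx < n_splits; ++idx) {
        ml.init_mappings();                        // the GGML_ASSERT below requires this first
        const size_t size_data = shard_size(idx);  // hypothetical per-shard byte count
        if (!ml.load_all_data(size_data, ctx, bufs,
                              use_mlock ? &mlock_mmaps : NULL,
                              params.progress_callback,
                              params.progress_callback_user_data)) {
            return false;  // cancelled by progress_callback
        }
    }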

3 files changed: 12 additions, 13 deletions

src/llama-model-loader.cpp (9 additions, 6 deletions)

@@ -904,12 +904,9 @@ void llama_model_loader::load_data_for(struct ggml_tensor * cur) const {
     }
 }
 
-bool llama_model_loader::load_all_data(
-        struct ggml_context * ctx,
-        llama_buf_map & bufs,
-        llama_mlocks * lmlocks,
-        llama_progress_callback progress_callback,
-        void * progress_callback_user_data) {
+bool llama_model_loader::load_all_data(size_t size_data, struct ggml_context * ctx, llama_buf_map & bufs,
+                                       llama_mlocks * lmlocks, llama_progress_callback progress_callback,
+                                       void * progress_callback_user_data) {
     GGML_ASSERT(size_data != 0 && "call init_mappings() first");
 
     std::vector<no_init<uint8_t>> read_buf;

@@ -1049,6 +1046,12 @@ bool llama_model_loader::load_all_data(
             }
         } else {
             const auto & file = files.at(weight->idx);
+            if (file == nullptr) {
+                throw std::runtime_error(
+                    format("file not found for tensor '%s' at split-index %d", ggml_get_name(cur), weight->idx));
+            }
+            LLAMA_LOG_CMAKE_DEBUG("%s: uploading tensor %s from file at split-index %d\n", __func__, ggml_get_name(cur),
+                                  weight->idx);
             if (ggml_backend_buffer_is_host(cur->buffer)) {
                 file->seek(weight->offs, SEEK_SET);
                 file->read_raw(cur->data, n_size);
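The new nullptr guard matters for incremental loading because only a subset of the split files may be open at any point. A minimal sketch of that assumed state (`files` holding `std::unique_ptr<llama_file>` entries, as the `== nullptr` comparison suggests):

    // Assumed file-table shape during an incremental load: slots for splits
    // that are not mapped yet hold null handles.
    std::vector<std::unique_ptr<llama_file>> files(n_splits);
    files[current_split] = std::make_unique<llama_file>(split_path, "rb");
    // Previously, a tensor whose weight->idx pointed at an unmapped split
    // would dereference a null handle; now it throws a descriptive error.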

src/llama-model-loader.h (2 additions, 6 deletions)

@@ -158,12 +158,8 @@ struct llama_model_loader {
     void load_data_for(struct ggml_tensor * cur) const;
 
     // Returns false if cancelled by progress_callback
-    bool load_all_data(
-            struct ggml_context * ctx,
-            llama_buf_map & bufs,
-            llama_mlocks * lmlocks,
-            llama_progress_callback progress_callback,
-            void * progress_callback_user_data);
+    bool load_all_data(size_t size_data, struct ggml_context * ctx, llama_buf_map & bufs, llama_mlocks * lmlocks,
+                       llama_progress_callback progress_callback, void * progress_callback_user_data);
 
     std::string ftype_name() const;
 

src/llama-model.cpp (1 addition, 1 deletion)

@@ -4394,7 +4394,7 @@ bool llama_model::create_backend_buffers(std::size_t
     for (auto & it : ctx_bufs) {
         ggml_context * ctx = it.first;
         auto & bufs = it.second;
-        if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
+        if (!ml.load_all_data(size_data, ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
             return false;
         }
     }
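Threading `size_data` through from the caller, rather than reading a member attribute inside the loader, presumably lets each incremental invocation account for (and report progress against) just the bytes of the shard being loaded, instead of a single loader-wide total.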
