Commit f8942e7

[common] Pure interface for files
Convert llama_file to a pure virtual class that can be overridden by multiple implementations (disk, single memory buffer, ...).
1 parent 73e53dc commit f8942e7
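
The diffs below introduce the abstract llama_file interface in src/llama-mmap.h and move the existing stdio-backed implementation into llama_file_disk. To illustrate the "single memory buffer" case mentioned above, here is a minimal hypothetical sketch of an in-memory implementation written against the new interface; the class name llama_file_buffer and its read-only behavior are illustrative assumptions, not part of this commit.

// Hypothetical sketch (not in this commit): an in-memory, read-only backing
// for the new llama_file interface, as one possible non-disk implementation.
#include <cstdint>
#include <cstdio>    // SEEK_SET / SEEK_CUR / SEEK_END
#include <cstring>
#include <stdexcept>
#include <vector>

#include "llama-mmap.h" // llama_file interface (see header diff below)

struct llama_file_buffer : public llama_file {
    explicit llama_file_buffer(std::vector<uint8_t> data) : buf(std::move(data)) {}

    size_t tell() const override { return pos; }
    size_t size() const override { return buf.size(); }
    int file_id() const override { return -1; } // no underlying file descriptor

    void seek(size_t offset, int whence) const override {
        if (whence == SEEK_SET)      { pos = offset; }
        else if (whence == SEEK_CUR) { pos += offset; }
        else                         { pos = buf.size() + offset; } // SEEK_END
    }

    void read_raw(void * ptr, size_t len) const override {
        if (pos + len > buf.size()) { throw std::runtime_error("read past end of buffer"); }
        std::memcpy(ptr, buf.data() + pos, len);
        pos += len;
    }

    uint32_t read_u32() const override {
        uint32_t val;
        read_raw(&val, sizeof(val));
        return val;
    }

    // Kept read-only for the purpose of this sketch.
    void write_raw(const void *, size_t) const override { throw std::runtime_error("buffer is read-only"); }
    void write_u32(uint32_t) const override { throw std::runtime_error("buffer is read-only"); }

private:
    std::vector<uint8_t> buf;
    mutable size_t pos = 0; // the interface methods are const, so the cursor is mutable
};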

File tree

  src/llama-adapter.cpp
  src/llama-context.cpp
  src/llama-mmap.cpp
  src/llama-mmap.h
  src/llama-model-loader.cpp

5 files changed: +43 -30 lines

src/llama-adapter.cpp

Lines changed: 1 addition & 1 deletion
@@ -347,7 +347,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
 
     // set tensor data
     {
-        llama_file gguf_file(path_lora, "rb");
+        llama_file_disk gguf_file(path_lora, "rb");
         std::vector<uint8_t> read_buf;
         auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
             size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));

src/llama-context.cpp

Lines changed: 4 additions & 4 deletions
@@ -1614,7 +1614,7 @@ size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * sr
 }
 
 bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(filepath, "rb");
+    llama_file_disk file(filepath, "rb");
 
     // sanity checks
     {
@@ -1657,7 +1657,7 @@ bool llama_context::state_load_file(const char * filepath, llama_token * tokens_
 }
 
 bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(filepath, "wb");
+    llama_file_disk file(filepath, "wb");
 
     file.write_u32(LLAMA_SESSION_MAGIC);
     file.write_u32(LLAMA_SESSION_VERSION);
@@ -1674,7 +1674,7 @@ bool llama_context::state_save_file(const char * filepath, const llama_token * t
 }
 
 size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(filepath, "rb");
+    llama_file_disk file(filepath, "rb");
 
     // version checks
    {
@@ -1717,7 +1717,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
 }
 
 size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(filepath, "wb");
+    llama_file_disk file(filepath, "wb");
 
     file.write_u32(LLAMA_STATE_SEQ_MAGIC);
     file.write_u32(LLAMA_STATE_SEQ_VERSION);

src/llama-mmap.cpp

Lines changed: 11 additions & 13 deletions
@@ -54,9 +54,7 @@ static std::string llama_format_win_err(DWORD err) {
 }
 #endif
 
-// llama_file
-
-struct llama_file::impl {
+struct llama_file_disk::impl {
 #if defined(_WIN32)
     HANDLE fp_win32;
     std::string GetErrorMessageWin32(DWORD error_code) const {
@@ -241,13 +239,13 @@ struct llama_file::impl {
     size_t size;
 };
 
-llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
-llama_file::~llama_file() = default;
+llama_file_disk::llama_file_disk(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
+llama_file_disk::~llama_file_disk() = default;
 
-size_t llama_file::tell() const { return pimpl->tell(); }
-size_t llama_file::size() const { return pimpl->size; }
+size_t llama_file_disk::tell() const { return pimpl->tell(); }
+size_t llama_file_disk::size() const { return pimpl->size; }
 
-int llama_file::file_id() const {
+int llama_file_disk::file_id() const {
 #ifdef _WIN32
     return _fileno(pimpl->fp);
 #else
@@ -259,13 +257,13 @@ int llama_file::file_id() const {
 #endif
 }
 
-void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
-void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
+void llama_file_disk::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
+void llama_file_disk::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
 
-uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
+uint32_t llama_file_disk::read_u32() const { return pimpl->read_u32(); }
 
-void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
-void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
+void llama_file_disk::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
+void llama_file_disk::write_u32(uint32_t val) const { pimpl->write_u32(val); }
 
 // llama_mmap

src/llama-mmap.h

Lines changed: 25 additions & 10 deletions
@@ -13,21 +13,36 @@ using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
 using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
 
 struct llama_file {
-    llama_file(const char * fname, const char * mode);
-    ~llama_file();
+    virtual ~llama_file() = default;
 
-    size_t tell() const;
-    size_t size() const;
+    virtual size_t tell() const = 0;
+    virtual size_t size() const = 0;
+    virtual int file_id() const = 0;
+
+    virtual void seek(size_t offset, int whence) const = 0;
+
+    virtual void read_raw(void * ptr, size_t len) const = 0;
+    virtual uint32_t read_u32() const = 0;
+
+    virtual void write_raw(const void * ptr, size_t len) const = 0;
+    virtual void write_u32(uint32_t val) const = 0;
+};
+
+struct llama_file_disk : public llama_file {
+    llama_file_disk(const char * fname, const char * mode);
+    ~llama_file_disk() override;
 
-    int file_id() const; // fileno overload
+    size_t tell() const override;
+    size_t size() const override;
+    int file_id() const override;
 
-    void seek(size_t offset, int whence) const;
+    void seek(size_t offset, int whence) const override;
 
-    void read_raw(void * ptr, size_t len) const;
-    uint32_t read_u32() const;
+    void read_raw(void * ptr, size_t len) const override;
+    uint32_t read_u32() const override;
 
-    void write_raw(const void * ptr, size_t len) const;
-    void write_u32(uint32_t val) const;
+    void write_raw(const void * ptr, size_t len) const override;
+    void write_u32(uint32_t val) const override;
 
 private:
     struct impl;
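
With the interface split out as above, code that only reads or writes through a llama_file no longer needs to know the concrete backing. A small hypothetical usage sketch (the helper peek_magic and the path "model.gguf" are placeholders, not from this commit):

// Hypothetical usage: read the leading 4-byte magic through the abstract
// interface; works with llama_file_disk and any future implementation.
#include <cstdint>
#include <cstdio>   // SEEK_SET, std::printf
#include <memory>

#include "llama-mmap.h" // llama_file / llama_file_disk as declared above

static uint32_t peek_magic(const llama_file & f) {
    f.seek(0, SEEK_SET); // rewind to the start of the file
    return f.read_u32(); // read the magic via the interface
}

int main() {
    // "model.gguf" is a placeholder path.
    std::unique_ptr<llama_file> f = std::make_unique<llama_file_disk>("model.gguf", "rb");
    std::printf("magic: 0x%08x\n", peek_magic(*f));
    return 0;
}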

src/llama-model-loader.cpp

Lines changed: 2 additions & 2 deletions
@@ -500,7 +500,7 @@ llama_model_loader::llama_model_loader(
     get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
     llm_kv = LLM_KV(llm_arch_from_string(arch_name));
 
-    files.emplace_back(new llama_file(fname.c_str(), "rb"));
+    files.emplace_back(new llama_file_disk(fname.c_str(), "rb"));
     contexts.emplace_back(ctx);
 
     // Save tensors data offset of the main file.
@@ -568,7 +568,7 @@ llama_model_loader::llama_model_loader(
             }
         }
 
-        files.emplace_back(new llama_file(fname_split, "rb"));
+        files.emplace_back(new llama_file_disk(fname_split, "rb"));
         contexts.emplace_back(ctx);
 
         // Save tensors data offset info of the shard.
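
For context on ownership at these call sites: the loader keeps each open file behind an owning pointer to the base class, so the concrete type is named only where the file is opened. A condensed sketch of that pattern, assuming the llama_files alias from llama-mmap.h is std::vector<std::unique_ptr<llama_file>>:

// Condensed, assumed sketch of the loader's ownership pattern; the alias
// llama_files is taken to be std::vector<std::unique_ptr<llama_file>>.
#include <memory>
#include <string>
#include <vector>

#include "llama-mmap.h" // llama_file, llama_file_disk

using llama_files = std::vector<std::unique_ptr<llama_file>>;

// Open every shard as a disk-backed file; downstream code (mmap, tensor
// loading) only ever sees the llama_file interface.
static void open_files(llama_files & files, const std::vector<std::string> & paths) {
    for (const auto & p : paths) {
        files.emplace_back(new llama_file_disk(p.c_str(), "rb"));
    }
}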
