|
1 | 1 | #include "llama-mmap.h"
|
2 | 2 |
|
3 | 3 | #include "llama-impl.h"
|
| 4 | +#include "uint8-buff-stream.h" |
4 | 5 |
|
5 | 6 | #include "ggml.h"
|
6 | 7 |
|
|
9 | 10 | #include <stdexcept>
|
10 | 11 | #include <cerrno>
|
11 | 12 | #include <algorithm>
|
| 13 | +#include <map> |
12 | 14 |
|
13 | 15 | #ifdef __has_include
|
14 | 16 | #if __has_include(<unistd.h>)
|
@@ -265,6 +267,77 @@ uint32_t llama_file_disk::read_u32() const { return pimpl->read_u32(); }
|
265 | 267 | void llama_file_disk::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
|
266 | 268 | void llama_file_disk::write_u32(uint32_t val) const { pimpl->write_u32(val); }
|
267 | 269 |
|
| 270 | +template <bool Writable> |
| 271 | +llama_file_buffer<Writable>::llama_file_buffer(std::unique_ptr<std::basic_streambuf<uint8_t>> && streambuf) : |
| 272 | + streambuf(std::move(streambuf)) {} |
| 273 | + |
| 274 | +template <bool Writable> llama_file_buffer<Writable>::~llama_file_buffer() = default; |
| 275 | + |
| 276 | +template <bool Writable> size_t llama_file_buffer<Writable>::tell() const { |
| 277 | + return streambuf->pubseekoff(0, std::ios_base::cur); |
| 278 | +} |
| 279 | + |
| 280 | +template <bool Writable> size_t llama_file_buffer<Writable>::size() const { |
| 281 | + auto current_pos = streambuf->pubseekoff(0, std::ios_base::cur); |
| 282 | + auto end_pos = streambuf->pubseekoff(0, std::ios_base::end); |
| 283 | + streambuf->pubseekpos(current_pos); |
| 284 | + return end_pos; |
| 285 | +} |
| 286 | + |
| 287 | +template <bool Writable> int llama_file_buffer<Writable>::file_id() const { |
| 288 | + return -1; |
| 289 | +} |
| 290 | + |
| 291 | +template <bool Writable> void llama_file_buffer<Writable>::seek(size_t offset, int whence) const { |
| 292 | + static std::map<int, std::ios_base::seekdir> whence_to_dir = { |
| 293 | + { SEEK_SET, std::ios_base::beg }, |
| 294 | + { SEEK_CUR, std::ios_base::cur }, |
| 295 | + { SEEK_END, std::ios_base::end } |
| 296 | + }; |
| 297 | + auto result = streambuf->pubseekoff(offset, whence_to_dir.at(whence)); |
| 298 | + if (result == std::streampos(-1)) { |
| 299 | + throw std::runtime_error("seek failed"); |
| 300 | + } |
| 301 | +} |
| 302 | + |
| 303 | +template <bool Writable> void llama_file_buffer<Writable>::read_raw(void * ptr, size_t len) const { |
| 304 | + auto bytes_read = streambuf->sgetn(static_cast<uint8_t *>(ptr), len); |
| 305 | + if (bytes_read != static_cast<std::streamsize>(len)) { |
| 306 | + throw std::runtime_error("read beyond end of buffer"); |
| 307 | + } |
| 308 | +} |
| 309 | + |
| 310 | +template <bool Writable> uint32_t llama_file_buffer<Writable>::read_u32() const { |
| 311 | + uint32_t val; |
| 312 | + read_raw(&val, sizeof(val)); |
| 313 | + return val; |
| 314 | +} |
| 315 | + |
| 316 | +template <> |
| 317 | +[[noreturn]] void llama_file_buffer<false>::write_raw([[maybe_unused]] const void * ptr, |
| 318 | + [[maybe_unused]] size_t _len) const { |
| 319 | + throw std::runtime_error("buffer is not writable"); |
| 320 | +} |
| 321 | + |
| 322 | +template <> [[noreturn]] void llama_file_buffer<false>::write_u32([[maybe_unused]] uint32_t val) const { |
| 323 | + throw std::runtime_error("buffer is not writable"); |
| 324 | +} |
| 325 | + |
| 326 | +template <> void llama_file_buffer<true>::write_raw(const void * ptr, size_t len) const { |
| 327 | + auto bytes_written = streambuf->sputn(static_cast<const uint8_t *>(ptr), len); |
| 328 | + if (bytes_written != static_cast<std::streamsize>(len)) { |
| 329 | + throw std::runtime_error("write beyond end of buffer"); |
| 330 | + } |
| 331 | +} |
| 332 | + |
| 333 | +template <> void llama_file_buffer<true>::write_u32(uint32_t val) const { |
| 334 | + write_raw(&val, sizeof(val)); |
| 335 | +} |
| 336 | + |
| 337 | +// Explicit instantiations |
| 338 | +template struct llama_file_buffer<false>; |
| 339 | +template struct llama_file_buffer<true>; |
| 340 | + |
268 | 341 | // llama_mmap
|
269 | 342 |
|
270 | 343 | struct llama_mmap::impl {
|
|
0 commit comments