 #include <algorithm>
 #include <cassert>
 #include <cmath>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(

     const uint32_t n_layer_kv = hparams.n_layer_kv();

+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }

-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);

             return ctx;
         }

-        return it->second;
+        return it->second.get();
     };

     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,19 +174,16 @@ llama_kv_cache::llama_kv_cache(
     }

     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }

         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);

         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }

     {
@@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
     }

     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {

 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1302,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;

-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }

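The core idea of the change is visible in the constructor hunks: ctx_map is now keyed with a comparator that orders backend buffer types by their name, so iterating the map (and therefore the order in which KV buffers are allocated and logged) is well-defined instead of depending on where the ggml_backend_buffer_type_t handles happen to live in memory, and each ggml_context_ptr is then moved out of the map into ctxs_bufs next to the buffer allocated from it, so context and buffer are kept alive and released together. The standalone sketch below illustrates the same pattern with stand-in types; buft_impl, buft_t, buft_name and the fake "contexts" and "buffers" are illustrative only and not part of the ggml API.

#include <cstdio>
#include <cstring>
#include <map>
#include <memory>
#include <utility>
#include <vector>

// stand-in for ggml_backend_buffer_type_t: an opaque handle that exposes a name
struct buft_impl { const char * name; };
using buft_t = const buft_impl *;

static const char * buft_name(buft_t buft) { return buft->name; }

// comparator keyed on the name, so the map's iteration order does not depend
// on the addresses of the handles
struct buft_comparator {
    bool operator()(const buft_t & lhs, const buft_t & rhs) const {
        return strcmp(buft_name(lhs), buft_name(rhs)) < 0;
    }
};

int main() {
    static const buft_impl cpu   = { "CPU"   };
    static const buft_impl cuda  = { "CUDA0" };
    static const buft_impl metal = { "Metal" };

    // stand-in for the ctx map: each buffer type maps to an owning pointer
    // (the real code stores ggml_context_ptr)
    std::map<buft_t, std::unique_ptr<int>, buft_comparator> ctx_map;
    ctx_map.emplace(&metal, std::make_unique<int>(3));
    ctx_map.emplace(&cpu,   std::make_unique<int>(1));
    ctx_map.emplace(&cuda,  std::make_unique<int>(2));

    // stand-in for ctxs_bufs: keep each context alive next to the "buffer"
    // allocated from it (here the buffer is just an int)
    std::vector<std::pair<std::unique_ptr<int>, int>> ctxs_bufs;

    for (auto & [buft, ctx] : ctx_map) {
        const int buf = *ctx * 100; // pretend allocation from this context
        printf("%s -> %d\n", buft_name(buft), buf);
        ctxs_bufs.emplace_back(std::move(ctx), buf);
    }
    // prints CPU, CUDA0, Metal in that (name) order every run
    return 0;
}

Because the loop moves each owning pointer out of ctx_map, the map is only a construction-time helper; this mirrors how the real constructor consumes ctx_map and afterwards works exclusively with ctxs_bufs.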