@@ -2805,6 +2805,11 @@ static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
     cache.do_defrag = true;
 }
 
+static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
 //
 // model loading and saving
 //
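
Note: the new helper ties the KV-cache padding to whether flash attention is enabled (256 with FA, 32 without). A minimal standalone sketch of the rounding this implies, assuming GGML_PAD(x, n) rounds x up to the next multiple of n as defined in ggml.h (numbers are illustrative, not taken from this patch):

    #include <cstdint>
    #include <cstdio>

    // assumed to match the round-up-to-multiple behavior of GGML_PAD in ggml.h
    static uint32_t pad_up(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

    int main() {
        printf("%u\n", pad_up(300, 256)); // 512 cells when flash_attn pads to 256
        printf("%u\n", pad_up(300,  32)); // 320 cells when the non-FA path pads to 32
        return 0;
    }
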
@@ -11510,7 +11515,8 @@ static int llama_decode_internal(
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
-                kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
+                const uint32_t pad = llama_kv_cache_get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
                 //kv_self.n = llama_kv_cache_cell_max(kv_self);
             }
         }
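
For concreteness (hypothetical numbers, not taken from this patch): with llama_kv_cache_cell_max(kv_self) == 300 and kv_self.size == 4096, the old code always reserved std::min(4096u, std::max(256u, GGML_PAD(300, 256))) == 512 cells, while the new non-FA path reserves std::max(32u, GGML_PAD(300, 32)) == 320 cells; the FA path keeps the previous 512.
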
@@ -15511,6 +15517,11 @@ struct llama_context * llama_new_context_with_model(
         return nullptr;
     }
 
+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     llama_context * ctx = new llama_context(*model);
 
     const auto & hparams = model->hparams;
@@ -15534,7 +15545,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -15579,11 +15590,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }