@@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
         const int tg = n_tg[i_tg];
         const int pl = n_pl[i_pl];

-        const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
+        const int n_ctx_req = is_pp_shared ? (params.kv_unified ? pp : pl*pp) + pl*tg : pl*(pp + tg);

         if (n_ctx_req > n_kv_max) {
             continue;
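
For reference, the updated requirement reads as follows: without prompt sharing every sequence stores its own prompt and its own generated tokens; with a shared prompt a unified KV cache stores the prompt once, while a split (per-sequence) cache still needs one prompt copy per sequence. A small standalone sketch of that arithmetic (the helper name and the sample numbers are illustrative, not part of the patch):

// Standalone sketch (not from the patch): how many KV cells one pp/tg/pl combination needs.
#include <cstdio>

// pp = prompt tokens, tg = tokens generated per sequence, pl = number of parallel sequences
static int n_ctx_required(int pp, int tg, int pl, bool is_pp_shared, bool kv_unified) {
    if (!is_pp_shared) {
        // every sequence stores its own prompt and its own generated tokens
        return pl*(pp + tg);
    }
    // shared prompt: a unified cache holds the prompt once, while a split (per-sequence)
    // cache ends up with one prompt copy per sequence after the llama_memory_seq_cp()
    // in the second hunk
    return (kv_unified ? pp : pl*pp) + pl*tg;
}

int main() {
    // worked example with pp = 512, tg = 128, pl = 4
    std::printf("shared prompt, unified cache : %d\n", n_ctx_required(512, 128, 4, true,  true));  //   512 + 4*128  = 1024
    std::printf("shared prompt, split cache   : %d\n", n_ctx_required(512, 128, 4, true,  false)); // 4*512 + 4*128  = 2560
    std::printf("no sharing                   : %d\n", n_ctx_required(512, 128, 4, false, true));  // 4*(512 + 128)  = 2560
    return 0;
}
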
@@ -147,13 +147,24 @@ int main(int argc, char ** argv) {
             return 1;
         }

+        const auto t_pp_end = ggml_time_us();
+
         if (is_pp_shared) {
             for (int32_t i = 1; i < pl; ++i) {
                 llama_memory_seq_cp(mem, 0, i, -1, -1);
             }
-        }

-        const auto t_pp_end = ggml_time_us();
+            if (!params.kv_unified) {
+                // run one dummy token to apply the memory copy
+                common_batch_clear(batch);
+                common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true);
+                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                    LOG_ERR("%s: llama_decode() failed\n", __func__);
+                    return 1;
+                }
+                llama_memory_seq_rm(mem, 0, pp, -1);
+            }
+        }

         const auto t_tg_start = ggml_time_us();

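
With this change `t_pp_end` is taken before the sequence copy, so the copy and the extra flush decode fall outside the measured prompt-processing window and finish before `t_tg_start`. The sketch below restates the new branch as a self-contained helper for readability; the wrapper name, its signature, and the explicit `dummy` token parameter are assumptions, not code from the patch (the benchmark itself uses its local `get_token_rand()` and `decode_helper()`).

#include "llama.h"
#include "common.h"

// Illustrative wrapper (not part of the patch): fan the shared prompt out from sequence 0
// to pl sequences and, for a split (non-unified) KV cache, decode one throw-away token so
// the copy is applied, then remove that token again so generation starts from position pp.
static bool share_prompt(llama_context * ctx, llama_memory_t mem, llama_batch & batch,
                         int32_t pp, int32_t pl, bool kv_unified, llama_token dummy) {
    for (int32_t i = 1; i < pl; ++i) {
        llama_memory_seq_cp(mem, 0, i, -1, -1);
    }

    if (!kv_unified) {
        common_batch_clear(batch);
        common_batch_add(batch, dummy, pp, { 0 }, true);
        if (llama_decode(ctx, batch) != 0) {
            return false; // the benchmark logs the failure and returns 1 here
        }
        llama_memory_seq_rm(mem, 0, pp, -1);
    }

    return true;
}
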