
Commit 3e15a14

wqerrewetw and FMayran authored
Qwen vl causal fix (#8)
* simple fix proposal for Qwen2.5 VL's cache causal masking issues. This is just a quick and dirty demonstration.
* replace the map with a vector for performance
* fix compiler warning in llama_ubatch struct construction
* adapting the previous fix to the syntax used by other fields of the ubatch

---------

Co-authored-by: FMayran <[email protected]>
1 parent 6fe7ad4 commit 3e15a14

File tree

src/llama-batch.cpp
src/llama-batch.h
src/llama-kv-cache.cpp
tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h

5 files changed: +108 -89 lines changed


src/llama-batch.cpp

Lines changed: 84 additions & 76 deletions
@@ -224,6 +224,7 @@ bool llama_batch_allocr::init(
         /*.seq_idx    =*/ this->seq_idx.data(),
         /*.output     =*/ batch.logits,
         /*.data       =*/ {},
+        /*.kv_position_of_token=*/ {},
     };

     ubatch_print(ubatch, debug);
@@ -256,36 +257,38 @@ bool llama_batch_allocr::init(
             continue;
         }

-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
-
-        if (p0 >= 0) {
-            bool ok = true;
-
-            if (batch.token) {
-                if (seq_pos_min(s) != p0 + 1) {
-                    ok = false;
-                }
-            } else {
-                assert(batch.embd);
-
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
-                }
-            }
-
-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                    "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                    " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                    " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                    " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                    __func__, s, s, p0, s, seq_pos_min(s));
-
-                return false;
-            }
-        }
+        // @fmayran: these checks don't make sense with models using position encoding such as Qwen VL, because the position stored in the KV cache can jump around (it is not even always increasing).
+        // It is not enough to let them be repeating: within an image embedding, arbitrary jumps are expected.
+        //const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+        //
+        //if (p0 >= 0) {
+        //    bool ok = true;
+        //
+        //    if (batch.token) {
+        //        if (seq_pos_min(s) != p0 + 1) {
+        //            ok = false;
+        //        }
+        //    } else {
+        //        assert(batch.embd);
+        //
+        //        // for embeddings (typically used as vision input), we allow them to have repeating positions
+        //        // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+        //        if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
+        //            ok = false;
+        //        }
+        //    }
+        //
+        //    if (!ok) {
+        //        LLAMA_LOG_ERROR(
+        //            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+        //            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+        //            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+        //            " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+        //            __func__, s, s, p0, s, seq_pos_min(s));
+        //
+        //        return false;
+        //    }
+        //}

         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
             LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
@@ -369,36 +372,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t

     auto udata = std::make_shared<llama_ubatch::data_t>();

-    udata->token     .resize(n_tokens);
-    udata->embd      .clear();
-    udata->pos       .resize(n_tokens);
-    udata->n_seq_id  .resize(n_tokens);
-    udata->seq_id    .resize(n_tokens);
-    udata->seq_id_unq.resize(0);
-    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    udata->output    .resize(n_tokens);
+    udata->token               .resize(n_tokens);
+    udata->embd                .clear();
+    udata->pos                 .resize(n_tokens);
+    udata->n_seq_id            .resize(n_tokens);
+    udata->seq_id              .resize(n_tokens);
+    udata->seq_id_unq          .resize(0);
+    udata->seq_idx             .resize(LLAMA_MAX_SEQ, -1);
+    udata->output              .resize(n_tokens);
+    udata->kv_position_of_token.resize(n_tokens, -1);

     for (uint32_t s = 0; s < n_seqs; ++s) {
         udata->seq_idx[s] = s;
         udata->seq_id_unq.push_back(s);
     }

     llama_ubatch res {
-        /*.b_equal_seqs =*/ true,
-        /*.n_tokens     =*/ n_tokens,
-        /*.n_seq_tokens =*/ n_seq_tokens,
-        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ n_seqs,
-
-        /*.token        =*/ udata->token.data(),
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ udata->pos.data(),
-        /*.n_seq_id     =*/ udata->n_seq_id.data(),
-        /*.seq_id       =*/ udata->seq_id.data(),
-        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
-        /*.seq_idx      =*/ udata->seq_idx.data(),
-        /*.output       =*/ udata->output.data(),
-        /*.data         =*/ std::move(udata),
+        /*.b_equal_seqs        =*/ true,
+        /*.n_tokens            =*/ n_tokens,
+        /*.n_seq_tokens        =*/ n_seq_tokens,
+        /*.n_seqs              =*/ n_seqs,
+        /*.n_seqs_unq          =*/ n_seqs,
+
+        /*.token               =*/ udata->token.data(),
+        /*.embd                =*/ nullptr,
+        /*.pos                 =*/ udata->pos.data(),
+        /*.n_seq_id            =*/ udata->n_seq_id.data(),
+        /*.seq_id              =*/ udata->seq_id.data(),
+        /*.seq_id_unq          =*/ udata->seq_id_unq.data(),
+        /*.seq_idx             =*/ udata->seq_idx.data(),
+        /*.output              =*/ udata->output.data(),
+        /*.kv_position_of_token=*/ udata->kv_position_of_token.data(),
+        /*.data                =*/ std::move(udata),
     };

     return res;
@@ -660,14 +665,15 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  = (int64_t) n_tokens*n_pos_cur;

-    udata->token     .resize(n_tokens);
-    udata->embd      .resize(n_embd_all);
-    udata->pos       .resize(n_pos_all);
-    udata->n_seq_id  .resize(n_tokens);
-    udata->seq_id    .resize(n_tokens);
-    udata->seq_id_unq.resize(0);
-    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    udata->output    .resize(n_tokens);
+    udata->token               .resize(n_tokens);
+    udata->embd                .resize(n_embd_all);
+    udata->pos                 .resize(n_pos_all);
+    udata->n_seq_id            .resize(n_tokens);
+    udata->seq_id              .resize(n_tokens);
+    udata->seq_id_unq          .resize(0);
+    udata->seq_idx             .resize(LLAMA_MAX_SEQ, -1);
+    udata->output              .resize(n_tokens);
+    udata->kv_position_of_token.resize(n_tokens, -1);

     seq_set_t seq_set_unq;

@@ -705,21 +711,23 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     }

     llama_ubatch res {
-        /*.b_equal_seqs =*/ equal_seqs,
-        /*.n_tokens     =*/ n_tokens,
-        /*.n_seq_tokens =*/ n_tokens/n_seqs,
-        /*.n_seqs       =*/ n_seqs,
-        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
-
-        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
-        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
-        /*.pos          =*/ udata->pos.data(),
-        /*.n_seq_id     =*/ udata->n_seq_id.data(),
-        /*.seq_id       =*/ udata->seq_id.data(),
-        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
-        /*.seq_idx      =*/ udata->seq_idx.data(),
-        /*.output       =*/ udata->output.data(),
-        /*.data         =*/ std::move(udata),
+        /*.b_equal_seqs        =*/ equal_seqs,
+        /*.n_tokens            =*/ n_tokens,
+        /*.n_seq_tokens        =*/ n_tokens/n_seqs,
+        /*.n_seqs              =*/ n_seqs,
+        /*.n_seqs_unq          =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token               =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd                =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos                 =*/ udata->pos.data(),
+        /*.n_seq_id            =*/ udata->n_seq_id.data(),
+        /*.seq_id              =*/ udata->seq_id.data(),
+        /*.seq_id_unq          =*/ udata->seq_id_unq.data(),
+        /*.seq_idx             =*/ udata->seq_idx.data(),
+        /*.output              =*/ udata->output.data(),
+        /*.kv_position_of_token=*/ udata->kv_position_of_token.data(),
+        /*.data                =*/ std::move(udata),
+
     };

     if (debug > 0) {

src/llama-batch.h

Lines changed: 11 additions & 9 deletions
@@ -30,15 +30,16 @@ struct llama_ubatch {
     // seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
     // used for extracting sequence pooled embeddings

-    //                             // size               | idx | val
-    llama_token  *  token;         // [n_tokens]         | i   | id, token
-    float        *  embd;          // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;           // [n_tokens]         | i   | pos
-    int32_t      *  n_seq_id;      // [n_tokens]         | i   | -
-    llama_seq_id ** seq_id;        // [n_tokens]         | s   | s0, s1, seq_id
-    llama_seq_id *  seq_id_unq;    // [n_seqs_unq]       | s   | seq_id
-    int32_t      *  seq_idx;       // [LLAMA_MAX_SEQ]    | -   | seq_idx
-    int8_t       *  output;        // [n_tokens]         | i   | -
+    //                                       // size               | idx | val
+    llama_token  *  token;                   // [n_tokens]         | i   | id, token
+    float        *  embd;                    // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;                     // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;                // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;                  // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq;              // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;                 // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;                  // [n_tokens]         | i   | -
+    int32_t      *  kv_position_of_token;    // [n_tokens]         | i   | kv position where the token was inserted

     struct data_t {
         std::vector<llama_token>  token;
@@ -49,6 +50,7 @@ struct llama_ubatch {
         std::vector<llama_seq_id> seq_id_unq;
         std::vector<int32_t>      seq_idx;
         std::vector<int8_t>       output;
+        std::vector<int32_t>      kv_position_of_token; // when pushed to the kv cache, where the token is pushed (used for causal masking)
     };

     // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
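The new field follows the same parallel-array convention as the other per-token members: entry i in kv_position_of_token refers to the same token as entry i in pos or output, and -1 means the token has not been assigned a KV cell yet. A small illustration with a trimmed stand-in struct (not the real llama_ubatch):

// Trimmed stand-in for llama_ubatch, showing how the new per-token array is
// meant to be read; -1 marks "not inserted into the KV cache (yet)".
#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_ubatch {
    uint32_t             n_tokens;
    std::vector<int32_t> pos;                   // [n_tokens]
    std::vector<int32_t> kv_position_of_token;  // [n_tokens], filled by the KV cache
};

int main() {
    toy_ubatch ub;
    ub.n_tokens             = 3;
    ub.pos                  = { 10, 10, 11 };
    ub.kv_position_of_token = {  5,  6, -1 };   // token 2 not placed yet

    for (uint32_t i = 0; i < ub.n_tokens; ++i) {
        if (ub.kv_position_of_token[i] == -1) {
            std::printf("token %u (pos %d): no KV cell assigned\n", i, ub.pos[i]);
        } else {
            std::printf("token %u (pos %d) -> KV cell %d\n", i, ub.pos[i], ub.kv_position_of_token[i]);
        }
    }
    return 0;
}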

src/llama-kv-cache.cpp

Lines changed: 11 additions & 2 deletions
@@ -895,6 +895,7 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
         }

         cells.pos_set(idx, ubatch.pos[i]);
+        ubatch.kv_position_of_token[i] = (int32_t) idx; // set the position in the kv cache as a property for this token (needed for proper causal masking)

         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
             cells.seq_add(idx, ubatch.seq_id[i][s]);
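One detail worth noting: apply_ubatch receives the ubatch by const reference yet still writes kv_position_of_token[i]. This works because the member is a raw pointer, so const applies to the pointer itself, not to the storage it points to (which is owned by the ubatch's data_t). A minimal illustration with a toy struct, not the real llama_ubatch:

// Why a function taking "const ubatch &" can still fill the array: the member
// is a raw pointer, so constness does not propagate to the pointed-to data.
#include <cstdint>
#include <cstdio>

struct toy_ubatch {
    int32_t * kv_position_of_token; // points into storage owned elsewhere
};

static void fill(const toy_ubatch & ub, int32_t i, int32_t idx) {
    ub.kv_position_of_token[i] = idx; // legal: the pointed-to data is not const
}

int main() {
    int32_t storage[2] = { -1, -1 };
    toy_ubatch ub { storage };
    fill(ub, 0, 7);
    std::printf("%d %d\n", storage[0], storage[1]); // prints: 7 -1
    return 0;
}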
@@ -1215,6 +1216,12 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u

     std::fill(data, data + ggml_nelements(dst), -INFINITY);

+    std::vector<int32_t> map_kv_to_batch(n_kv, -1); // for each token in the cache, either -1 or the position in the current ubatch
+    for (uint32_t i = 0; i < n_tokens; ++i) {       // invert the batch -> kv position map into a kv -> batch position map
+        if (ubatch->kv_position_of_token[i] != -1) {
+            map_kv_to_batch[ubatch->kv_position_of_token[i]] = i;
+        }
+    }
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
     // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
@@ -1254,8 +1261,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
             const llama_pos p0 = cells.pos_get(j);

             // mask future tokens
-            if (causal_attn && p0 > p1) {
-                continue;
+            if (causal_attn)
+            {
+                if (map_kv_to_batch[j] != -1 && map_kv_to_batch[j] > (int32_t) i) // if the kv cache token is in the current batch AND its position in the batch is higher than i
+                    continue;
             }

             // apply SWA if any
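To see what the new rule changes, here is a self-contained toy reproduction (illustrative data, not the real llama_kv_cache): it builds the same kv -> batch inversion and compares the old position-based test with the new batch-order test. With repeating or jumping positions, the old rule both masks genuinely past cells whose positions happen to be larger and fails to mask later tokens of the same ubatch whose positions happen to be smaller; the new rule only masks cells filled by a later token of the current ubatch.

// Toy reproduction of the masking change: old rule compares positions,
// new rule compares order within the current ubatch via the kv -> batch map.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // KV cache with 6 cells; cells 0-2 were filled by earlier ubatches,
    // cells 3-5 are being filled by the current ubatch (values illustrative)
    const std::vector<int32_t> cell_pos             = { 10, 13, 13, 12, 9, 11 };

    // current ubatch: 3 tokens, their positions and the cells they landed in
    const std::vector<int32_t> batch_pos            = { 12,  9, 11 };
    const std::vector<int32_t> kv_position_of_token = {  3,  4,  5 };

    const int32_t n_kv     = (int32_t) cell_pos.size();
    const int32_t n_tokens = (int32_t) batch_pos.size();

    // invert batch -> kv into kv -> batch, mirroring set_input_kq_mask
    std::vector<int32_t> map_kv_to_batch(n_kv, -1);
    for (int32_t i = 0; i < n_tokens; ++i) {
        map_kv_to_batch[kv_position_of_token[i]] = i;
    }

    for (int32_t i = 0; i < n_tokens; ++i) {
        for (int32_t j = 0; j < n_kv; ++j) {
            const bool old_masked = cell_pos[j] > batch_pos[i];                          // position-based
            const bool new_masked = map_kv_to_batch[j] != -1 && map_kv_to_batch[j] > i;  // batch-order-based
            if (old_masked != new_masked) {
                std::printf("token %d vs cell %d: old=%s new=%s\n",
                            i, j, old_masked ? "mask" : "attend", new_masked ? "mask" : "attend");
            }
        }
    }
    return 0;
}

Under the new rule, cells written by earlier ubatches are never causally masked; ordering across ubatches is implied by insertion time rather than by position values.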

tools/mtmd/mtmd.cpp

Lines changed: 1 addition & 1 deletion
@@ -1026,7 +1026,7 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {

llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
    if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+        return (::std::max)(image_tokens->nx, image_tokens->ny); // assuming an image, not a video: for M-RoPE, the image now spans max(nx, ny) positions
    }
    return image_tokens->n_tokens();
}
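The practical effect is on position accounting: with M-RoPE an image chunk now advances the sequence position by max(nx, ny) instead of by 1, so whatever follows the image starts that much further along. A rough sketch with made-up numbers (nx/ny standing for the image token grid, as in mtmd_image_tokens):

// Sketch of the position accounting implied by the change (illustrative only):
// the image now advances the sequence position by max(nx, ny) instead of by 1.
#include <algorithm>
#include <cstdio>

int main() {
    const int nx = 16, ny = 12;   // image token grid (e.g. a 16x12 patch layout)
    const int n_past_before = 42; // sequence position before the image

    const int n_pos_old = 1;                  // previous behaviour
    const int n_pos_new = (std::max)(nx, ny); // behaviour after this commit

    std::printf("text resumes at %d (old) vs %d (new)\n",
                n_past_before + n_pos_old, n_past_before + n_pos_new);
    return 0;
}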

tools/mtmd/mtmd.h

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
MTMD_API size_t                    mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char *              mtmd_input_chunk_get_id       (const mtmd_input_chunk * chunk);
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (always max(ntok_x, ntok_y, ntok_t) for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos                 mtmd_input_chunk_get_n_pos    (const mtmd_input_chunk * chunk);

// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
