Skip to content

Commit 2beac32

Browse files
committed
bugfix for mtp
Signed-off-by: zouyida2052 <[email protected]>
1 parent 3a27b15 commit 2beac32

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

vllm_ascend/attention/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def split_decodes_and_prefills(
9797
return num_reqs, 0, num_tokens, 0
9898

9999
first_prefill = is_prefill.int().argmax(dim=-1).item()
100-
assert torch.all(query_lens[first_prefill:] >= decode_threshold)
100+
assert torch.all(query_lens[first_prefill:] > decode_threshold)
101101
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
102102
num_decodes = first_prefill
103103
num_prefills = num_reqs - num_decodes

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,11 @@ def _propose(
527527
input_ids = draft_token_ids_list[-1].int()
528528
positions += 1
529529

530+
if not is_running_torchair:
531+
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[1:batch_size + 1].tolist()
532+
attn_metadata_i.decode.cos = builder.cos_cache[positions].unsqueeze(1).unsqueeze(2)
533+
attn_metadata_i.decode.sin = builder.sin_cache[positions].unsqueeze(1).unsqueeze(2)
534+
530535
# NOTE(woosuk): We should handle the case where the draft model
531536
# generates tokens beyond the max model length. Since it is complex
532537
# to remove such requests from the batch, we keep them in the batch
@@ -560,6 +565,7 @@ def _propose(
560565

561566
if attn_metadata_i.prefill is not None:
562567
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
568+
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist()
563569
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
564570
attn_metadata_i.prefill.input_positions = self.positions[:
565571
num_input_tokens]
@@ -569,6 +575,7 @@ def _propose(
569575
self.runner.model_config.max_model_len)
570576
if attn_metadata_i.decode is not None:
571577
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
578+
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist()
572579
attn_metadata_i.decode.input_positions = self.positions[:
573580
num_input_tokens]
574581
attn_metadata_i.decode.max_seq_lens += 1

0 commit comments

Comments
 (0)