Skip to content

Commit 93f1fa5

Browse files
committed
bugfix for mtp
Signed-off-by: zouyida2052 <[email protected]>
1 parent 3a27b15 commit 93f1fa5

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

vllm_ascend/attention/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def split_decodes_and_prefills(
9797
return num_reqs, 0, num_tokens, 0
9898

9999
first_prefill = is_prefill.int().argmax(dim=-1).item()
100-
assert torch.all(query_lens[first_prefill:] >= decode_threshold)
100+
assert torch.all(query_lens[first_prefill:] > decode_threshold)
101101
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
102102
num_decodes = first_prefill
103103
num_prefills = num_reqs - num_decodes

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,14 @@ def _propose(
527527
input_ids = draft_token_ids_list[-1].int()
528528
positions += 1
529529

530+
if not self.torchair_graph_enabled:
531+
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
532+
1:batch_size + 1].tolist()
533+
attn_metadata_i.decode.cos = builder.cos_cache[
534+
positions].unsqueeze(1).unsqueeze(2)
535+
attn_metadata_i.decode.sin = builder.sin_cache[
536+
positions].unsqueeze(1).unsqueeze(2)
537+
530538
# NOTE(woosuk): We should handle the case where the draft model
531539
# generates tokens beyond the max model length. Since it is complex
532540
# to remove such requests from the batch, we keep them in the batch
@@ -560,6 +568,8 @@ def _propose(
560568

561569
if attn_metadata_i.prefill is not None:
562570
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
571+
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
572+
)
563573
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
564574
attn_metadata_i.prefill.input_positions = self.positions[:
565575
num_input_tokens]
@@ -569,6 +579,8 @@ def _propose(
569579
self.runner.model_config.max_model_len)
570580
if attn_metadata_i.decode is not None:
571581
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
582+
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
583+
)
572584
attn_metadata_i.decode.input_positions = self.positions[:
573585
num_input_tokens]
574586
attn_metadata_i.decode.max_seq_lens += 1

0 commit comments

Comments (0)