diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 519cde0c5a..efc11030d5 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -97,7 +97,7 @@ def split_decodes_and_prefills(
         return num_reqs, 0, num_tokens, 0
 
     first_prefill = is_prefill.int().argmax(dim=-1).item()
-    assert torch.all(query_lens[first_prefill:] >= decode_threshold)
+    assert torch.all(query_lens[first_prefill:] > decode_threshold)
     assert torch.all(query_lens[:first_prefill] <= decode_threshold)
     num_decodes = first_prefill
     num_prefills = num_reqs - num_decodes
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index d0a0d507fa..8b1fe0ba34 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -527,6 +527,14 @@ def _propose(
             input_ids = draft_token_ids_list[-1].int()
             positions += 1
 
+            if not self.torchair_graph_enabled:
+                attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
+                    1:batch_size + 1].tolist()
+                attn_metadata_i.decode.cos = builder.cos_cache[
+                    positions].unsqueeze(1).unsqueeze(2)
+                attn_metadata_i.decode.sin = builder.sin_cache[
+                    positions].unsqueeze(1).unsqueeze(2)
+
             # NOTE(woosuk): We should handle the case where the draft model
             # generates tokens beyond the max model length. Since it is complex
             # to remove such requests from the batch, we keep them in the batch
@@ -560,6 +568,8 @@ def _propose(
 
             if attn_metadata_i.prefill is not None:
                 attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
+                attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
+                )
                 attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
                 attn_metadata_i.prefill.input_positions = self.positions[:
                                                                          num_input_tokens]
@@ -569,6 +579,8 @@ def _propose(
                     self.runner.model_config.max_model_len)
             if attn_metadata_i.decode is not None:
                 attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
+                attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
+                )
                 attn_metadata_i.decode.input_positions = self.positions[:
                                                                         num_input_tokens]
                 attn_metadata_i.decode.max_seq_lens += 1