Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm_ascend/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def split_decodes_and_prefills(
return num_reqs, 0, num_tokens, 0

first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] >= decode_threshold)
assert torch.all(query_lens[first_prefill:] > decode_threshold)
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
Expand Down
12 changes: 12 additions & 0 deletions vllm_ascend/spec_decode/mtp_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,14 @@ def _propose(
input_ids = draft_token_ids_list[-1].int()
positions += 1

if not self.torchair_graph_enabled:
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
1:batch_size + 1].tolist()
attn_metadata_i.decode.cos = builder.cos_cache[
positions].unsqueeze(1).unsqueeze(2)
attn_metadata_i.decode.sin = builder.sin_cache[
positions].unsqueeze(1).unsqueeze(2)

# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
Expand Down Expand Up @@ -560,6 +568,8 @@ def _propose(

if attn_metadata_i.prefill is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
)
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.input_positions = self.positions[:
num_input_tokens]
Expand All @@ -569,6 +579,8 @@ def _propose(
self.runner.model_config.max_model_len)
if attn_metadata_i.decode is not None:
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
)
attn_metadata_i.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1
Expand Down
Loading