Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm_ascend/attention/context_parallel/sfa_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ def _execute_sparse_flash_attention_process(
return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)

def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None:
if attn_output is None:
return None
if attn_output is None or self.pcp_size == 1:
return attn_output
# In graph/piecewise mode, output buffer uses graph bucket token size
# (forward_context.num_tokens), while PCP path may compute only valid
# tokens. Align to the larger one to avoid later write-back mismatch.
Expand Down
5 changes: 3 additions & 2 deletions vllm_ascend/spec_decode/eagle_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def _propose(
- 1
)
num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone()
ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone().to(device="cpu")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Moving ori_seq_len to the CPU with .to(device="cpu") will likely cause a device mismatch error. This tensor is used in self.runner.pcp_manager._get_cp_local_seq_lens, which performs operations involving tensors created on the default GPU device (e.g., via torch.arange). Performing operations between a CPU tensor and a GPU tensor will lead to a runtime error. ori_seq_len should remain on the GPU.

Suggested change
ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone().to(device="cpu")
ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refactor the MLA attention metadata and SFA metadata instead: add a seq_lens_cpu field to them to avoid the frequent H2D synchronization.

mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad

# slot_mapping index base offset:
Expand Down Expand Up @@ -1223,7 +1223,8 @@ def attn_update_stack_num_spec_norm(

if self.pcp_size * self.dcp_size > 1:
if self.vllm_config.model_config.use_mla:
attn_metadata.decode.cp_seq_len = cp_seq_len
if getattr(attn_metadata, "decode", None):
attn_metadata.decode.cp_seq_len = cp_seq_len
else:
attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp

Expand Down
Loading