[Draft]Append padding logic for Attention #3256
Conversation
Signed-off-by: Angazenn <[email protected]>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces padding logic for attention metadata, which is necessary for ACL graph capture with dynamic batch sizes. The changes in vllm_ascend/worker/model_runner_v1.py correctly control when this padding is applied. However, I've found a couple of critical issues in the padding implementation in vllm_ascend/attention/attention_v1.py. Specifically, there's an off-by-one error when padding query_start_loc_cpu, and query_lens is not being padded at all, which will lead to tensor shape mismatches and incorrect attention computation. I've provided a code suggestion to fix these issues.
padded_num_tokens = common_attn_metadata.graph_pad_size - num_actual_tokens
seq_lens = torch.cat([seq_lens, torch.ones(padded_num_tokens, dtype=seq_lens.dtype, device=seq_lens.device)])
block_table_padding = torch.zeros(
    (padded_num_tokens, ) + block_table.shape[1:],
    dtype=block_table.dtype,
    device=block_table.device)
block_table = torch.cat([block_table, block_table_padding],
                        dim=0)
query_start_loc_cpu = torch.cat([query_start_loc_cpu, torch.arange(query_start_loc_cpu[-1] + 1, query_start_loc_cpu[-1] + padded_num_tokens, dtype=query_start_loc_cpu.dtype, device=query_start_loc_cpu.device)])
There are two issues in this padding logic that will cause problems during execution:
- query_start_loc_cpu padding is incorrect: The call to torch.arange has an off-by-one error. The end parameter is exclusive, so torch.arange(start, end) generates end - start elements. Your code generates padded_num_tokens - 1 elements, but you need to pad padded_num_tokens elements to match the padded seq_lens. This will cause a tensor size mismatch (a small check after this list illustrates it). The end argument should be query_start_loc_cpu[-1] + padded_num_tokens + 1.
- query_lens is not padded: You are padding seq_lens, block_table, and query_start_loc_cpu, but query_lens (calculated before this block) is not updated. This will lead to a mismatch in the number of sequences between query_lens and other padded tensors when creating AscendMetadata, causing errors downstream where query_lens is used. query_lens should also be padded with ones for the new sequences.
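A minimal standalone check of the off-by-one, using hypothetical values that are not taken from the runner:
import torch

# Hypothetical values: pretend query_start_loc_cpu[-1] == 5 and 3 tokens of padding are needed.
last = 5
padded_num_tokens = 3

too_short = torch.arange(last + 1, last + padded_num_tokens)      # end is exclusive
print(too_short)  # tensor([6, 7])    -> only padded_num_tokens - 1 new offsets
fixed = torch.arange(last + 1, last + padded_num_tokens + 1)
print(fixed)      # tensor([6, 7, 8]) -> exactly padded_num_tokens new offsets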
I've provided a suggestion to fix both issues.
padded_num_tokens = common_attn_metadata.graph_pad_size - num_actual_tokens
seq_lens = torch.cat([seq_lens, torch.ones(padded_num_tokens, dtype=seq_lens.dtype, device=seq_lens.device)])
query_lens = torch.cat([query_lens, torch.ones(padded_num_tokens, dtype=query_lens.dtype, device=query_lens.device)])
block_table_padding = torch.zeros(
(padded_num_tokens, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat([block_table, block_table_padding],
dim=0)
new_query_locs = torch.arange(
query_start_loc_cpu[-1] + 1,
query_start_loc_cpu[-1] + padded_num_tokens + 1,
dtype=query_start_loc_cpu.dtype,
device=query_start_loc_cpu.device)
query_start_loc_cpu = torch.cat([query_start_loc_cpu, new_query_locs])
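As a sanity check, here is a self-contained sketch that applies the same padding to made-up decode metadata (two real requests and an assumed graph_pad_size of 4; these values are illustrative, not the runner's real state) and asserts that the padded tensors line up:
import torch

# Made-up metadata for 2 decode requests; graph_pad_size of 4 is an assumption.
num_actual_tokens = 2
graph_pad_size = 4
seq_lens = torch.tensor([17, 9], dtype=torch.int32)
query_lens = torch.tensor([1, 1], dtype=torch.int32)
block_table = torch.zeros((2, 8), dtype=torch.int32)
query_start_loc_cpu = torch.tensor([0, 1, 2], dtype=torch.int32)

padded_num_tokens = graph_pad_size - num_actual_tokens
seq_lens = torch.cat([seq_lens, torch.ones(padded_num_tokens, dtype=seq_lens.dtype)])
query_lens = torch.cat([query_lens, torch.ones(padded_num_tokens, dtype=query_lens.dtype)])
block_table = torch.cat([block_table,
                         torch.zeros((padded_num_tokens, ) + block_table.shape[1:],
                                     dtype=block_table.dtype)], dim=0)
start = int(query_start_loc_cpu[-1]) + 1
query_start_loc_cpu = torch.cat([query_start_loc_cpu,
                                 torch.arange(start, start + padded_num_tokens,
                                              dtype=query_start_loc_cpu.dtype)])

# All per-sequence tensors now cover the padded batch size.
assert seq_lens.shape[0] == query_lens.shape[0] == block_table.shape[0] == graph_pad_size
assert query_start_loc_cpu.shape[0] == graph_pad_size + 1  # one extra leading zero entry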
Signed-off-by: Angazenn <[email protected]>
What this PR does / why we need it?
This PR aims to add padding logic to seq_lens and block_tables when running in the full decode scenario. Before this PR, the number of input tokens with padding might exceed the corresponding seq_lens. For example, when running in the full decode scenario:
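The original example is not reproduced here; the numbers below are made up (two real decode requests padded to an ACL graph batch of four) purely to illustrate the mismatch and the intended result:
import torch

# Hypothetical full-decode batch: 2 real requests, graph captured for batch size 4,
# so input_ids carries 2 extra padding tokens.
input_ids = torch.tensor([101, 102, 0, 0])      # 2 real tokens + 2 padding tokens
seq_lens = torch.tensor([17, 9])                # before this PR: only 2 entries
query_start_loc = torch.tensor([0, 1, 2])       # before this PR: covers only real tokens

# After this PR the metadata is padded to stay consistent with input_ids:
seq_lens_padded = torch.tensor([17, 9, 1, 1])
query_start_loc_padded = torch.tensor([0, 1, 2, 3, 4])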
Here, input_ids is padded by 2 tokens while seq_lens/query_start_loc are not. The mismatch between input_ids and seq_lens/query_start_loc might cause some potential bugs. This PR changes it so that seq_lens/query_start_loc are padded as well, as in the illustrative sketch above.
Does this PR introduce any user-facing change?
No.
How was this patch tested?