[Bugfix] Fix model run _npu_flash_attention hang issue #4410
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
The pull request addresses the reported hang in _npu_flash_attention by correcting how the attention mask dtype is handled. The key change refactors retrieval of the chunked_prefill_attn_mask into a dedicated method, get_chunked_prefill_attn_mask, which explicitly converts the mask to torch.bool, the dtype the attention mechanism requires to operate correctly. Beyond fixing the bug, centralizing this dtype conversion improves code clarity and maintainability.
+    def get_chunked_prefill_attn_mask(self):
+        return self.chunked_prefill_attn_mask.to(torch.bool)
+
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
Introducing get_chunked_prefill_attn_mask and removing the conditional logic from get_attn_mask is a significant improvement. The refactoring cleanly separates responsibility for providing the chunked prefill attention mask and explicitly enforces its torch.bool dtype, which directly addresses the reported "wrong attention mask dtype" issue that was causing hangs in _npu_flash_attention. It also keeps get_attn_mask focused on its general purpose.
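For illustration, here is a minimal sketch of the pattern this comment describes, assuming a builder class that owns a precomputed mask. Only the body of get_chunked_prefill_attn_mask mirrors the diff above; the class name, constructor, mask shape, and int8 storage dtype are assumptions for the sketch.

import torch


class AttentionMaskBuilder:
    """Simplified sketch of a mask builder that centralizes dtype handling."""

    def __init__(self, max_seq_len: int = 2048):
        # Hypothetical storage: an upper-triangular causal mask kept in a
        # non-bool dtype (the exact storage dtype is an assumption here).
        self.chunked_prefill_attn_mask = torch.triu(
            torch.ones(max_seq_len, max_seq_len, dtype=torch.int8), diagonal=1)

    def get_chunked_prefill_attn_mask(self) -> torch.Tensor:
        # Single conversion point: every caller receives a torch.bool mask,
        # the dtype the attention kernel expects.
        return self.chunked_prefill_attn_mask.to(torch.bool)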
     elif attn_state == AscendAttentionState.PrefillCacheHit:
-        return self.attn_mask_builder.get_attn_mask(
-            2048, self.dtype, self.device)
+        return self.attn_mask_builder.get_chunked_prefill_attn_mask()
Updating the call site to use the new get_chunked_prefill_attn_mask() method is a correct and consistent application of the refactored logic. It ensures the attention mask used in the PrefillCacheHit state always has the torch.bool dtype, which is essential for preventing the _npu_flash_attention hang described in the PR.
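A minimal sketch of the dispatch this comment describes; the helper name select_attn_mask and the enum members other than PrefillCacheHit are assumptions, and the builder object is only expected to expose the two methods shown in the diffs above.

import torch
from enum import Enum, auto


class AscendAttentionState(Enum):
    # Only the members needed for this sketch; the real enum has more states.
    PrefillNoCache = auto()
    PrefillCacheHit = auto()


def select_attn_mask(attn_state: AscendAttentionState, attn_mask_builder,
                     dtype: torch.dtype, device: torch.device,
                     max_seq_len: int = 2048) -> torch.Tensor:
    # After this PR, the PrefillCacheHit path uses the dedicated accessor,
    # so the mask it returns always has dtype torch.bool.
    if attn_state == AscendAttentionState.PrefillCacheHit:
        return attn_mask_builder.get_chunked_prefill_attn_mask()
    # Other states keep going through the general-purpose entry point.
    return attn_mask_builder.get_attn_mask(max_seq_len, dtype, device)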
Force-pushed from f5f2eac to 36d1b5e.
Fix model run _npu_flash_attention in _forward_prefill_no_cache hang issue; it was caused by wrong attention mask dtype.
Signed-off-by: Ting FU <[email protected]>
Force-pushed from 36d1b5e to 5b7ef34.
What this PR does / why we need it?
Fix model run _npu_flash_attention in _forward_prefill_no_cache hang issue; it was caused by wrong attention mask dtype.
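To make the stated root cause concrete, here is a small hypothetical illustration of the dtype difference; the int8 storage dtype and the mask shape are assumptions, and the snippet shows only the conversion, not the actual kernel call.

import torch

# A causal mask stored in a non-bool dtype versus the same mask converted to
# torch.bool. Per this PR, handing the non-bool variant to _npu_flash_attention
# reportedly led to the hang; the fix is to always convert before the kernel call.
raw_mask = torch.triu(torch.ones(4, 4, dtype=torch.int8), diagonal=1)
bool_mask = raw_mask.to(torch.bool)

print(raw_mask.dtype)   # torch.int8
print(bool_mask.dtype)  # torch.bool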
Does this PR introduce any user-facing change?
No
How was this patch tested?
Yes, tested on Qwen2.5-VL and Qwen2.5-Omni.