Refactor MLA forward_prefill & forward_decode #2342
base: main
Conversation
Signed-off-by: lwq <[email protected]>
Code Review
This pull request refactors the _forward_prefill and _forward_decode methods in AscendMLAImpl to simplify the logic by removing several conditional paths based on attention states. While the intent to simplify is good, the refactoring has introduced critical issues in the function signatures of both methods: incorrect type hints and mismatched arguments will lead to runtime TypeErrors. These need to be fixed to ensure the code runs correctly.
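To make the failure mode concrete, here is a minimal, standalone Python illustration (not the vllm-ascend code; the Impl class, the trimmed-down parameter list, and the argument values are placeholders) of how a call site that is not updated after a signature change raises a TypeError at runtime:

# Toy reproduction: the method now takes block_size, but a caller still passes
# the old cache argument, so Python rejects the call with a TypeError.
class Impl:
    def _forward_decode(self, q_pe, block_size: int):
        return block_size

impl = Impl()
kv_cache = (object(), object())  # stand-in for the (kv_c, k_pe) cache tuple
try:
    impl._forward_decode(q_pe=None, kv_c_and_k_pe_cache=kv_cache)
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'kv_c_and_k_pe_cache'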
self,
q_nope: torch.Tensor,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
block_size: int,
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
The signature of _forward_decode was changed to take block_size: int instead of kv_c_and_k_pe_cache. However, the call sites in the forward method are not updated and still pass kv_cache (a tuple of tensors), which will cause a TypeError. To fix this, you should revert the signature to accept kv_c_and_k_pe_cache and derive block_size inside this function, as it was done previously (block_size = kv_c_and_k_pe_cache[0].shape[1]).
Suggested change:

Current signature:
self,
q_nope: torch.Tensor,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
block_size: int,
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:

Suggested signature:
self,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...],
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
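For reference, a minimal sketch of what the suggested fix implies, assuming the suggested signature above and deriving block_size from the cache inside the method. The class stub, the untyped attn_metadata parameter, and the elided body are assumptions; the real AscendMLAImpl._forward_decode in vllm-ascend contains the Ascend-specific attention kernels, which are not reproduced here.

import torch
from typing import Tuple

class AscendMLAImpl:  # sketch only; the real class lives in vllm-ascend
    def _forward_decode(
        self,
        ql_nope: torch.Tensor,
        q_pe: torch.Tensor,
        k_nope: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: Tuple[torch.Tensor, ...],
        attn_metadata,  # AscendMLAMetadata in the real code
        enable_multistream_mla: bool = False,
    ) -> torch.Tensor:
        # Recover block_size from the cache layout rather than taking it as a
        # parameter, so call sites in forward() that pass kv_cache keep working.
        block_size = kv_c_and_k_pe_cache[0].shape[1]
        # ... Ascend-specific decode attention computation elided ...
        raise NotImplementedError("sketch: attention computation elided")

The key point is that block_size is recomputed locally from the cache, so the existing call sites in forward do not need to change.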
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
This pull request has conflicts; please resolve them before we can evaluate the pull request.
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?