[AMD]: Support MLA with nhead<16 and FP8 KV cache for TP=8 (Kimi K2.5… #21213

ZiguanWang wants to merge 1 commit into sgl-project:main
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request enhances the flexibility and compatibility of the AITER MLA attention mechanism on AMD GPUs. By introducing a head-repeat strategy, it allows the system to handle configurations with fewer attention heads per rank, such as those encountered with model architectures like Kimi K2.5 under tensor parallelism. This enables broader use of the optimized kernels without requiring new kernel development, improving performance and resource utilization for a wider range of deployment scenarios.
Code Review
This pull request adds support for MLA with fewer than 16 heads by using head repetition, which is a great enhancement. The implementation is mostly correct, but I've identified a few areas for improvement. There is significant code duplication in the forward_extend method that should be refactored to improve maintainability. I've also suggested a small refactoring in forward_decode for better code clarity. Additionally, a comment in a test file is now outdated due to the changes and needs to be updated.
```python
q_in = q
o_out = o
if self.head_repeat_factor > 1:
    q_in = q.repeat_interleave(self.head_repeat_factor, dim=1)
    o_out = o.new_empty(
        (o.shape[0], self.num_head_padded, layer.v_head_dim),
        dtype=self.input_dtype,
        device=o.device,
    )
```
This block of logic to prepare q_in and o_out for mla_decode_fwd is part of a larger pattern that is duplicated three times in forward_extend (here, and for is_draft_extend and is_extend_and_draft_extend modes). The full duplicated pattern includes preparing inputs, calling mla_decode_fwd, and contracting the output.
This duplication reduces maintainability, as any future changes to this logic will need to be applied in three different places. Please consider refactoring this logic into a single private helper function to avoid code repetition.
```python
q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
if self.head_repeat_factor > 1:
    q_view = q_view.repeat_interleave(self.head_repeat_factor, dim=1)
    o_padded = o.new_empty(
        (o_view.shape[0], self.num_head_padded, layer.v_head_dim),
        dtype=self.input_dtype,
        device=o_view.device,
    )
else:
    o_padded = o_view
```
The logic for preparing inputs for mla_decode_fwd can be made clearer. Reassigning q_view within the if block and using an if/else to define o_padded makes the data flow harder to follow.
To improve readability, consider using a separate variable for the (potentially repeated) query tensor. The following suggestion introduces q_in. You would then need to update the call to mla_decode_fwd to use q_in instead of q_view.
Suggested change:

```diff
 q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
 o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-if self.head_repeat_factor > 1:
-    q_view = q_view.repeat_interleave(self.head_repeat_factor, dim=1)
-    o_padded = o.new_empty(
-        (o_view.shape[0], self.num_head_padded, layer.v_head_dim),
-        dtype=self.input_dtype,
-        device=o_view.device,
-    )
-else:
-    o_padded = o_view
+q_in = q_view
+o_padded = o_view
+if self.head_repeat_factor > 1:
+    q_in = q_view.repeat_interleave(self.head_repeat_factor, dim=1)
+    o_padded = o.new_empty(
+        (o_view.shape[0], self.num_head_padded, layer.v_head_dim),
+        dtype=self.input_dtype,
+        device=o_view.device,
+    )
```
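For reference, `repeat_interleave` with `dim=1` duplicates each head consecutively, which is what makes a later strided contraction over the padded head axis possible. A minimal illustration (using NumPy's `np.repeat`, which has the same semantics along an axis, so the example runs without torch):

```python
import numpy as np

# Two "heads" of size 3; repeat each head twice along the head axis,
# mirroring torch.repeat_interleave(2, dim=1).
q = np.array([[[1, 1, 1], [2, 2, 2]]])  # shape (1, 2, 3)
q_in = np.repeat(q, 2, axis=1)          # shape (1, 4, 3): h0, h0, h1, h1

# A strided slice over the padded head axis recovers the original heads,
# which is why the padded kernel output can be contracted back afterwards.
q_back = q_in[:, ::2, :]
```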
```diff
 other_args = [
     "--tp",
-    "4",
+    "8",
```
Motivation
Modifications
Accuracy Tests
TP4:
TP8:
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci