
[AMD]: Support MLA with nhead<16 and FP8 KV cache for TP=8 (Kimi K2.5…#21213

Open
ZiguanWang wants to merge 1 commit into sgl-project:main from AMD-AIM:kimi-mxfp4-tp8

Conversation

@ZiguanWang
Contributor

Motivation

Modifications

  1. Support AITER MLA for num_heads < 16 (e.g., TP=8 with Kimi K2.5 giving 8 heads/rank). Uses head-repeat to expand to 16 heads before calling the AITER MLA decode kernel, then contracts the output back. This reuses the existing optimized gqa_ratio=16 ASM kernel without requiring new kernel variants.
  2. Relax head-count assertion to accept num_heads of 4, 8, or any multiple of 16 in [16, 128]. Previously only {16, 128} were accepted.
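The expand/contract step described in item 1 can be sketched as follows. This is a minimal standalone illustration, not the PR's actual code: the `head_repeat_demo` helper and the DeepSeek-style MLA head dims (qk_head_dim=576, v_head_dim=512) are assumptions for the example.

```python
import torch

def head_repeat_demo(num_tokens=4, num_heads=8, target_heads=16):
    """Illustrate expanding <16 heads to 16 for the kernel, then contracting back."""
    repeat = target_heads // num_heads  # e.g. 2 for TP=8 on a 64-head model
    # assumed DeepSeek-style MLA decode shapes: qk_head_dim=576, v_head_dim=512
    q = torch.randn(num_tokens, num_heads, 576)
    # expand: each head is duplicated `repeat` times along the head dimension
    q_in = q.repeat_interleave(repeat, dim=1)  # (tokens, 16, 576)
    # the kernel would write a padded output of shape (tokens, 16, 512);
    # duplicated heads attend to the same KV, so their outputs are identical
    o_padded = torch.randn(num_tokens, target_heads, 512)
    # contract: keep one copy per original head
    o = o_padded[:, ::repeat, :]  # (tokens, 8, 512)
    return tuple(q_in.shape), tuple(o.shape)
```

Because the duplicated heads share the same query and KV cache, discarding the extra copies after the kernel call loses no information; the cost is the redundant compute on the repeated heads, traded for reuse of the existing gqa_ratio=16 ASM kernel.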

Accuracy Tests

TP=4:

  Accuracy: 0.929
  Invalid: 0.000
  Latency: 108.001 s
  Output throughput: 1245.121 token/s

TP=8:

  Accuracy: 0.928
  Invalid: 0.000
  Latency: 71.932 s
  Output throughput: 1858.801 token/s

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and compatibility of the AITER MLA attention mechanism on AMD GPUs. By introducing a head-repeat strategy, it allows the system to efficiently handle configurations with fewer attention heads per rank, such as those encountered with specific model architectures like Kimi K2.5 when using tensor parallelism. This change enables broader utilization of optimized kernels without requiring new kernel development, improving performance and resource utilization for a wider range of deployment scenarios.

Highlights

  • Expanded AITER MLA Support: Implemented support for AITER MLA (Multi-Head Attention) with num_heads less than 16 (specifically 4 or 8 heads per rank), by using a head-repeat mechanism to expand to 16 heads before kernel execution and then contracting the output. This reuses existing optimized kernels.
  • Relaxed Head Count Assertions: Modified the head-count assertion for AITER MLA to accept num_heads of 4, 8, or any multiple of 16 within the range [16, 128], broadening compatibility beyond the previous {16, 128} restriction.
  • Test Configuration Update: Updated the test_kimi_k25_mxfp4.py test to use a tensor parallel size (TP) of 8, aligning with the new head configuration support.
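The relaxed head-count rule described in the second highlight can be expressed as a simple predicate. This is a sketch based on the PR description only; `is_supported_num_heads` is a hypothetical name, not the actual assertion in the code.

```python
def is_supported_num_heads(n: int) -> bool:
    # accepted per the PR: 4, 8, or any multiple of 16 in [16, 128]
    # (previously only 16 and 128 were accepted)
    return n in (4, 8) or (16 <= n <= 128 and n % 16 == 0)
```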



@gemini-code-assist bot left a comment
Code Review

This pull request adds support for MLA with fewer than 16 heads by using head repetition, which is a great enhancement. The implementation is mostly correct, but I've identified a few areas for improvement. There is significant code duplication in the forward_extend method that should be refactored to improve maintainability. I've also suggested a small refactoring in forward_decode for better code clarity. Additionally, a comment in a test file is now outdated due to the changes and needs to be updated.

Comment on lines +2203 to +2211

```python
q_in = q
o_out = o
if self.head_repeat_factor > 1:
    q_in = q.repeat_interleave(self.head_repeat_factor, dim=1)
    o_out = o.new_empty(
        (o.shape[0], self.num_head_padded, layer.v_head_dim),
        dtype=self.input_dtype,
        device=o.device,
    )
```
Severity: medium

This block of logic to prepare q_in and o_out for mla_decode_fwd is part of a larger pattern that is duplicated three times in forward_extend (here, and for is_draft_extend and is_extend_and_draft_extend modes). The full duplicated pattern includes preparing inputs, calling mla_decode_fwd, and contracting the output.

This duplication reduces maintainability, as any future changes to this logic will need to be applied in three different places. Please consider refactoring this logic into a single private helper function to avoid code repetition.

Comment on lines +2504 to +2514

```python
q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
if self.head_repeat_factor > 1:
    q_view = q_view.repeat_interleave(self.head_repeat_factor, dim=1)
    o_padded = o.new_empty(
        (o_view.shape[0], self.num_head_padded, layer.v_head_dim),
        dtype=self.input_dtype,
        device=o_view.device,
    )
else:
    o_padded = o_view
```
Severity: medium

The logic for preparing inputs for mla_decode_fwd can be made clearer. Reassigning q_view within the if block and using an if/else to define o_padded makes the data flow harder to follow.

To improve readability, consider using a separate variable for the (potentially repeated) query tensor. The following suggestion introduces q_in. You would then need to update the call to mla_decode_fwd to use q_in instead of q_view.

Suggested change (before):

```python
q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
if self.head_repeat_factor > 1:
    q_view = q_view.repeat_interleave(self.head_repeat_factor, dim=1)
    o_padded = o.new_empty(
        (o_view.shape[0], self.num_head_padded, layer.v_head_dim),
        dtype=self.input_dtype,
        device=o_view.device,
    )
else:
    o_padded = o_view
```

After:

```python
q_view = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_view = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_in = q_view
o_padded = o_view
if self.head_repeat_factor > 1:
    q_in = q_view.repeat_interleave(self.head_repeat_factor, dim=1)
    o_padded = o.new_empty(
        (o_view.shape[0], self.num_head_padded, layer.v_head_dim),
        dtype=self.input_dtype,
        device=o_view.device,
    )
```

```diff
 other_args = [
     "--tp",
-    "4",
+    "8",
```
Severity: medium

With this change to tp=8, the comment on line 44 (# TP=4 required...) is now outdated and misleading. Please update the comment to reflect that num_heads < 16 is now supported and this test case is specifically for that scenario (e.g., 64 heads / tp=8 = 8 heads per GPU).

