[KDA] Support CuTeDSL KDA decode kernel#21203
Conversation
Summary of Changes (Gemini Code Assist): This pull request introduces support for the CuTeDSL KDA decode kernel, enhancing the performance of Kimi-Linear/Kimi-2.5 and other KDA-architecture models. The new kernel provides a measurable speedup for single-batch inference and is integrated as an optional backend. While currently focused on decode, future work is planned to extend CuTeDSL support to prefill operations and to address layout compatibility with the existing Triton kernels.
Code Review
This pull request introduces support for a CuTeDSL KDA decode kernel, including the kernel implementation, integration into the backend, and a comprehensive benchmark suite for correctness and performance testing. The implementation is well-structured, but there are critical thread-safety issues with the use of global caches for compiled kernels and other resources. Additionally, the benchmark script's fallback timing mechanism could be improved for better accuracy. My review provides suggestions to address these points.
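The review flags thread-safety issues with global caches for compiled kernels. A common fix is a lock-guarded cache with double-checked locking, so concurrent first-time lookups don't race on compilation. This is a minimal sketch with illustrative names, not the PR's actual code:

```python
import threading

# Hypothetical global cache of compiled kernels, guarded by a lock.
_KERNEL_CACHE = {}
_KERNEL_CACHE_LOCK = threading.Lock()

def get_compiled_kernel(key, compile_fn):
    # Fast path: a plain dict read is safe under CPython's GIL.
    kernel = _KERNEL_CACHE.get(key)
    if kernel is not None:
        return kernel
    with _KERNEL_CACHE_LOCK:
        # Re-check inside the lock so only one thread compiles (double-checked locking).
        kernel = _KERNEL_CACHE.get(key)
        if kernel is None:
            kernel = compile_fn()
            _KERNEL_CACHE[key] = kernel
    return kernel
```

The same pattern applies to any other lazily initialized global resource the kernel wrapper holds.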
/tag-and-rerun-ci
if __name__ == "__main__":
    raise SystemExit(main())
Can we match the benchmark script's style with other bench scripts?
BBuf left a comment:
Can we add a correctness test? In the benchmark test, there is a correctness check first.
Force-pushed from 5793727 to 7f6e1ba (Compare)
Thanks for the PR. I'm mainly concerned about this plan item:
I think we have agreed on standardizing the VK layout everywhere as the long-term direction, to avoid complexity. So instead of adding the KV layout back to the KDA Triton kernel, I think we should modify the CuTeDSL KDA kernel to use the VK layout, consistent with the direction Triton is already moving, rather than adding yet another KV kernel that will need migration later.
@BBuf @kaixih After investigating more deeply, I found this is not just a trivial modification to VK: we would need to change the logic of the varlen and non-varlen APIs, normalize, and make q/k/v/a/b/A_log/dt_bias contiguous, which introduces overhead and impacts performance. It's comparable in scope to implementing the GDN decode logic in FlashInfer. I think we should implement a KV-layout prefill for CuTeDSL KDA for now, so this PR can move forward as an independent kernel first, and we can complete the e2e CuTeDSL path later. Now that the prefill has been changed to the VK layout, GDN in CuTeDSL is also not passing e2e; we have the same issue there. I'll fix both together in a follow-up PR. Is this OK with you? Thanks.
If the long-term direction is to standardize on the VK layout, then merging this now may not add much value; we could pause this PR for now instead of merging code that would effectively become dead.
Benchmark Results: CuTeDSL KDA Decode across A100, H100, H200
Command:
Correctness: all passed on all 3 GPUs, with improvement observed in most metrics
@BBuf This code will not be dead; the CuTeDSL VK-layout support is in progress and is expected in the next PR. Can we move this baseline PR forward? It does not impact the ongoing work.
@edwingao28 Thank you very much for your report.
/rerun-failed-ci |
@BBuf I wrote a test script to verify that the KDA prefill Triton kernel is non-deterministic, which proves two things:
Shall we move this PR forward? The results are as follows:
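A determinism check like the one described above can be sketched as follows: run the kernel repeatedly on identical inputs and compare outputs exactly. This is a hypothetical illustration, not the actual script; in practice one would pass `torch.equal` as `equal` and the KDA prefill kernel with CUDA tensors as `fn`/`inputs`:

```python
def is_deterministic(fn, inputs, equal, trials=5):
    # Run once to get a reference output, then re-run on the same inputs.
    ref = fn(*inputs)
    # Any run that differs bitwise from the reference proves non-determinism
    # (on GPU, typically from atomics or floating-point reduction order).
    return all(equal(fn(*inputs), ref) for _ in range(trials))
```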
Force-pushed from 45e143e to 6450a36 (Compare)
Force-pushed from 6450a36 to 2344dfa (Compare)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Motivation
This PR adds support for a CuTeDSL KDA decode kernel for Kimi-Linear/Kimi-2.5 and other models using the KDA architecture.
Benchmarks were conducted on an H800. Accuracy is as expected. batch_size=1 sees about a 1.05x performance uplift.
We have not yet integrated this kernel into the e2e backend because only the CuTeDSL decode KDA kernel exists for now. When integrating with the flags --linear-attn-prefill-backend triton --linear-attn-decode-backend cutedsl, the prefill Triton KDA kernel (which shares its underlying kernel with GDN) has been changed to the VK layout (introduced in #20283), so it doesn't match the CuTeDSL decode kernel, which uses the KV layout.
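The layout mismatch can be illustrated with a small sketch (shapes and names here are assumptions, not the PR's actual cache format): the recurrent state cache can be stored as [heads, K, V] ("KV layout") or [heads, V, K] ("VK layout"), and a kernel expecting one cannot read a cache written in the other without a materialized transpose:

```python
import torch

def kv_to_vk(state_kv: torch.Tensor) -> torch.Tensor:
    # [..., K, V] -> [..., V, K]; .contiguous() materializes the transpose,
    # which is exactly the kind of extra copy/overhead the discussion mentions.
    return state_kv.transpose(-1, -2).contiguous()

state_kv = torch.randn(8, 128, 64)   # hypothetical [heads, K, V] state
state_vk = kv_to_vk(state_kv)        # [heads, V, K] view, materialized
```

Standardizing on one layout everywhere avoids paying this transpose at the prefill/decode boundary.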
We plan to do the following tasks:
Either approach would complete the CuTeDSL KDA story model-wise.
Modifications
GDN and KDA are similar, except that KDA uses K-dimension-wise gating while GDN uses head-wise gating. KDA can therefore learn more complicated forget patterns, at the cost of more parameters (a larger dt_bias dimension) and higher compute complexity (it must handle a K-dimension gate). To be more specific:
The KDA CuTeDSL decode kernel can therefore be derived from the GDN CuTeDSL decode kernel, with some changes to a/dt_bias/g.
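The gating difference above can be sketched in a few lines (shapes are illustrative assumptions, not the kernel's actual tensors): GDN applies one scalar forget gate per head, while KDA applies a gate per head per key channel, which the kernel must broadcast over the V dimension:

```python
import torch

H, K, V = 4, 16, 8
state = torch.randn(H, K, V)  # hypothetical recurrent state per head

# GDN-style head-wise gating: one scalar gate per head.
g_gdn = torch.sigmoid(torch.randn(H))          # shape [H]
state_gdn = g_gdn[:, None, None] * state       # broadcast over K and V

# KDA-style K-dimension-wise gating: one gate per head per key channel,
# i.e. a [H, K] gate (larger dt_bias, more compute), allowing a richer
# forgetting pattern than a single per-head scalar.
g_kda = torch.sigmoid(torch.randn(H, K))       # shape [H, K]
state_kda = g_kda[:, :, None] * state          # broadcast over V
```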
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci