
[KDA] Support CuTeDSL KDA decode kernel#21203

Merged
BBuf merged 3 commits into sgl-project:main from antgroup:cutedsl_kda_decode
Mar 25, 2026

Conversation

@yuan-luo
Collaborator

@yuan-luo yuan-luo commented Mar 23, 2026

Motivation

This PR adds a CuTeDSL KDA decode kernel for Kimi-Linear/Kimi-2.5 and other models that use the KDA architecture.
Benchmarks were run on H800. Accuracy matches the Triton reference, and batch_size=1 gets about a 1.05x performance uplift.

We have not yet integrated this kernel into the e2e backend. Only the CuTeDSL decode KDA kernel exists so far, and when integrating with the flags --linear-attn-prefill-backend triton --linear-attn-decode-backend cutedsl, the Triton prefill KDA kernel (which shares its underlying kernel with GDN) has been moved to the VK layout (introduced in #20283), so it no longer matches the CuTeDSL decode kernel, which uses the KV layout.

We plan to do the following tasks:

  1. Support a CuTeDSL KDA prefill kernel.
  2. Add a KV/VK layout switch to the Triton prefill KDA kernel.

Either approach would make the CuTeDSL KDA path complete end-to-end for these models.
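To make the layout mismatch above concrete, here is a sketch with made-up shapes (not the actual kernel interfaces): the same recurrent state can be stored with the key axis first ("KV layout", indexed as state[k, v]) or the value axis first ("VK layout", state[v, k]). A decode kernel written for one layout reads a state written in the other transposed, and since K == V here the shapes even agree, so the error is silent without a correctness check.

```python
import torch

# Illustrative only: the real state tensors carry extra batch/head dims,
# but the core mismatch is the order of the last two axes.
K, V = 128, 128
state_vk = torch.randn(V, K)  # state as a VK-layout prefill would emit it

# A KV-layout decode kernel indexes the state as state[k, v]; handing it
# the VK buffer directly would read transposed values. An explicit
# conversion is required when mixing the two layouts:
state_kv = state_vk.transpose(0, 1).contiguous()

assert torch.allclose(state_kv[3, 7], state_vk[7, 3])
```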

root@e1448ef40573:/sgl-workspace/sglang_dev2# python ./benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py --cuda-graph
Device: NVIDIA L20Y  (SM 90)
==============================================================================
Correctness: Triton KDA Decode vs CuTe DSL KDA Decode
==============================================================================
  [PASS] layout=dense  B= 1 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=dense  B= 4 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=dense  B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=dense  B=64 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B= 4 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=varlen B=16 H= 8 HV=16 K=128 V=128 pool=  64
  [PASS] layout=varlen B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B= 1 H=16 HV=32 K=128 V=128 pool=  32
  [PASS] layout=varlen B=32 H=16 HV=32 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H=16 HV=16 K=128 V=128 pool= 128

  PAD_SLOT_ID test (indices with -1):
  [PASS] PAD_SLOT_ID=-1 handling (valid outputs/states only)

ALL PASSED.

============================================================================================
Benchmark: Triton vs CuTe DSL KDA Decode  [CUDA Graph replay — production mode]
============================================================================================
  Config: K=128, V=128, pool_size=512, dtype=torch.bfloat16
  layout      B    H   HV    K    V |  triton (us) |  cutedsl (us) |  speedup |  delta (us)
  ----------------------------------------------------------------------------------
   dense      1    8   16  128  128 |       12.7 |       12.1 |    1.05x |      +0.6
   dense      4    8   16  128  128 |       12.6 |       12.6 |    1.01x |      +0.1
   dense     32    8   16  128  128 |       24.4 |       23.9 |    1.02x |      +0.6
   dense     64    8   16  128  128 |       60.2 |       60.7 |    0.99x |      -0.5
  varlen      1    8   16  128  128 |       12.6 |       12.2 |    1.03x |      +0.4
  varlen      4    8   16  128  128 |       12.7 |       12.6 |    1.01x |      +0.1
  varlen      8    8   16  128  128 |       13.6 |       13.6 |    1.00x |      +0.0
  varlen     16    8   16  128  128 |       16.3 |       16.6 |    0.98x |      -0.4
  varlen     32    8   16  128  128 |       24.5 |       24.1 |    1.02x |      +0.4
  varlen     64    8   16  128  128 |       59.2 |       60.6 |    0.98x |      -1.3
  varlen    128    8   16  128  128 |      103.6 |      105.9 |    0.98x |      -2.4
  varlen     32   16   32  128  128 |       59.2 |       60.8 |    0.97x |      -1.6
  varlen     64   16   16  128  128 |       59.8 |       60.8 |    0.98x |      -1.0
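The PAD_SLOT_ID test in the log above checks that cache-slot indices of -1 are skipped. A minimal reference of that masking behavior (the helper name and shapes are illustrative, not the PR's actual API):

```python
import torch

PAD_SLOT_ID = -1  # padded requests carry this cache-slot index

def scatter_states(pool: torch.Tensor, slot_ids: torch.Tensor,
                   new_states: torch.Tensor) -> torch.Tensor:
    """Write per-request states into the state pool, skipping padding.

    pool:       (pool_size, K, V) recurrent-state cache
    slot_ids:   (B,) int64; PAD_SLOT_ID marks a padded (invalid) request
    new_states: (B, K, V) states produced by the decode step
    """
    valid = slot_ids != PAD_SLOT_ID
    # Only valid rows are scattered; padded rows touch no pool memory.
    pool[slot_ids[valid]] = new_states[valid]
    return pool
```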

Modifications

GDN and KDA are similar, except that KDA uses K-dimension-wise gating while GDN uses head-wise gating. KDA can therefore learn more complicated forget patterns, at the cost of more parameters (a larger dt_bias) and higher compute (the K-dimension gate must be handled). More specifically:

  1. a: GDN is per-head with shape (N, 1, HV); KDA is per-head-per-key with shape (N, 1, HV, K).
  2. dt_bias: GDN is (HV,); KDA is (HV, K).
  3. g: GDN is a scalar per head; KDA is a vector over the K dimension: h[k, v] *= exp(g[k]).
  4. beta and the remaining logic are identical: A_log has shape (HV,), b is (H, HV), and beta = sigmoid(b).

The KDA CuTeDSL decode kernel is therefore derived from the GDN CuTeDSL decode kernel, with changes to the handling of a, dt_bias, and g.
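The shape differences above can be made concrete with a single-head reference step of the gated delta rule, using the per-key gate from item 3 (a hedged sketch of the math, not the kernel's actual code; the exact computation of g and beta from a, dt_bias, A_log, and b is left out):

```python
import torch

def kda_decode_step(q, k, v, g, beta, S):
    """One illustrative KDA decode step (gated delta rule), single head.

    Shapes:
      q, k, g : (K,)  - g is the per-key log-gate, the KDA difference
      v       : (V,)
      beta    : scalar in (0, 1), beta = sigmoid(b)
      S       : (K, V) recurrent state
    """
    # KDA: per-key gate scales each state row, h[k, v] *= exp(g[k]).
    # (GDN would apply one scalar gate per head to the whole state.)
    S = S * torch.exp(g)[:, None]
    # Delta rule: correct the state toward v along direction k.
    pred = k @ S                                # (V,) prediction for key k
    S = S + torch.outer(k, beta * (v - pred))   # rank-1 correction
    # Output: read the updated state with the query.
    o = q @ S                                   # (V,)
    return o, S
```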

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for the CuTeDSL KDA decode kernel, enhancing the performance of Kimi-Linear/Kimi-2.5 and other KDA-architecture models. The new kernel provides a measurable speedup for single-batch inferences and is integrated into the system as an optional backend. While currently focused on decode, future work is planned to extend CuTeDSL support to prefill operations and address layout compatibility with existing Triton kernels.

Highlights

  • CuTeDSL KDA Decode Kernel Implementation: Implemented a new CuTeDSL kernel for Fused Sigmoid Gating KDA (Kimi Delta Attention) decode operations, supporting both small and large batch sizes, and variable-length sequences.
  • Performance Improvement: Achieved a performance uplift of approximately 1.05x for batch_size=1 on H800 GPUs when using the new CuTeDSL decode kernel.
  • New Benchmark Script: Added a dedicated benchmark and correctness testing script for the CuTeDSL KDA decode kernel, comparing its performance and accuracy against the existing Triton implementation.
  • KDA Backend Integration: Integrated the new CuTeDSL KDA decode kernel into the KDA backend, allowing it to be selected as an alternative to the Triton kernel.


@yuan-luo yuan-luo requested review from kaixih and yizhang2077 March 23, 2026 09:26
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for a CuTeDSL KDA decode kernel, including the kernel implementation, integration into the backend, and a comprehensive benchmark suite for correctness and performance testing. The implementation is well-structured, but there are critical thread-safety issues with the use of global caches for compiled kernels and other resources. Additionally, the benchmark script's fallback timing mechanism could be improved for better accuracy. My review provides suggestions to address these points.
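The thread-safety concern about global caches of compiled kernels can be sketched as a double-checked, lock-guarded cache (names here are illustrative, not the PR's actual code):

```python
import threading

# Module-level cache of compiled kernels, shared by all threads.
_kernel_cache = {}
_cache_lock = threading.Lock()

def get_compiled_kernel(key, compile_fn):
    """Return the compiled kernel for `key`, compiling at most once.

    Without the lock, two threads requesting the same configuration
    concurrently could both trigger an expensive compilation and race
    on the cache insert.
    """
    kernel = _kernel_cache.get(key)  # lock-free fast path
    if kernel is None:
        with _cache_lock:
            kernel = _kernel_cache.get(key)  # double-checked under lock
            if kernel is None:
                kernel = compile_fn(key)
                _kernel_cache[key] = kernel
    return kernel
```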

@yuan-luo
Collaborator Author

/tag-and-rerun-ci



if __name__ == "__main__":
raise SystemExit(main())
Collaborator


Can we match the benchmark script's style with other bench scripts?

Collaborator

@BBuf BBuf left a comment


Can we add a correctness test?

@yuan-luo
Collaborator Author

Can we add a correctness test?

The benchmark script runs a correctness check first.

@yuan-luo yuan-luo force-pushed the cutedsl_kda_decode branch from 5793727 to 7f6e1ba Compare March 23, 2026 13:35
@kaixih
Collaborator

kaixih commented Mar 23, 2026

Thanks for the PR. I'm mainly concerned about this plan item:

Support triton prefill KDA kernel with KV and VK layout switch off.

I think we have agreed to standardize on the VK layout everywhere in the long term to avoid complexity. So instead of adding KV back to the KDA Triton kernel, I think we should move the CuTe DSL KDA kernel to the VK layout, consistent with the direction Triton is already moving, rather than adding yet another KV kernel that would need migration later.

@yuan-luo
Collaborator Author

yuan-luo commented Mar 24, 2026

Thanks for the PR. I'm mainly concerned about this plan item:

Support triton prefill KDA kernel with KV and VK layout switch off.

I think we have agreed to standardize on the VK layout everywhere in the long term to avoid complexity. So instead of adding KV back to the KDA Triton kernel, I think we should move the CuTe DSL KDA kernel to the VK layout, consistent with the direction Triton is already moving, rather than adding yet another KV kernel that would need migration later.

@BBuf @kaixih After investigating further, I found this is not a trivial change to VK: we would need to change the logic of the varlen and non-varlen APIs and the normalization, and make q/k/v/a/b/A_log/dt_bias contiguous, which introduces overhead and impacts performance. It is comparable in scope to implementing the GDN decode logic in FlashInfer. I think we should keep a KV-layout prefill path for CuTeDSL KDA for now, so this PR can move forward as an independent kernel first, and complete the e2e CuTeDSL integration later.

Now that prefill has been changed to the VK layout, GDN in CuTeDSL is also failing e2e. We have the same issue; I'll fix both in a follow-up PR.

Is this OK with you? Thanks.

@BBuf
Collaborator

BBuf commented Mar 24, 2026

Thanks for the PR. I'm mainly concerned about this plan item:

Support triton prefill KDA kernel with KV and VK layout switch off.

I think we have agreed to standardize on the VK layout everywhere in the long term to avoid complexity. So instead of adding KV back to the KDA Triton kernel, I think we should move the CuTe DSL KDA kernel to the VK layout, consistent with the direction Triton is already moving, rather than adding yet another KV kernel that would need migration later.

@BBuf @kaixih After investigating further, I found this is not a trivial change to VK: we would need to change the logic of the varlen and non-varlen APIs and the normalization, and make q/k/v/a/b/A_log/dt_bias contiguous, which introduces overhead and impacts performance. It is comparable in scope to implementing the GDN decode logic in FlashInfer. I think we should keep a KV-layout prefill path for CuTeDSL KDA for now, so this PR can move forward as an independent kernel first, and complete the e2e CuTeDSL integration later.

Now that prefill has been changed to the VK layout, GDN in CuTeDSL is also failing e2e. We have the same issue; I'll fix both in a follow-up PR.

Is this OK with you? Thanks.

If the long-term direction is to standardize on the VK layout, then merging this now may not add much value; we can probably pause this PR for now instead of merging code that would effectively become dead.

@edwingao28
Contributor

Benchmark Results: CuTeDSL KDA Decode across A100, H100, H200

Command:

python benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py  --cuda-graph

Correctness: all tests passed on all three GPUs, with improvements in most configurations.

cc: @yuan-luo @BBuf @kaixih

Device: NVIDIA A100-SXM4-80GB  (SM 80)
==============================================================================
Correctness: Triton KDA Decode vs CuTe DSL KDA Decode
==============================================================================
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:232: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:252: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=dense  B= 1 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=dense  B= 4 H= 8 HV=16 K=128 V=128 pool=  32
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:702: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:719: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=dense  B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=dense  B=64 H= 8 HV=16 K=128 V=128 pool= 128
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:472: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:492: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=varlen B= 4 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=varlen B=16 H= 8 HV=16 K=128 V=128 pool=  64
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:927: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:944: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=varlen B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B= 1 H=16 HV=32 K=128 V=128 pool=  32
  [PASS] layout=varlen B=32 H=16 HV=32 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H=16 HV=16 K=128 V=128 pool= 128

  PAD_SLOT_ID test (indices with -1):
  [PASS] PAD_SLOT_ID=-1 handling (valid outputs/states only)

ALL PASSED.

============================================================================================
Benchmark: Triton vs CuTe DSL KDA Decode  [CUDA Graph replay — production mode]
============================================================================================
  Config: K=128, V=128, pool_size=512, dtype=torch.bfloat16
  layout      B    H   HV    K    V |  triton (us) |  cutedsl (us) |  speedup |  delta (us)
  ----------------------------------------------------------------------------------
   dense      1    8   16  128  128 |       14.0 |       12.4 |    1.13x |      +1.6
   dense      4    8   16  128  128 |       15.8 |       14.5 |    1.09x |      +1.3
   dense     32    8   16  128  128 |       57.1 |       50.0 |    1.14x |      +7.1
   dense     64    8   16  128  128 |      103.1 |       97.2 |    1.06x |      +5.9
  varlen      1    8   16  128  128 |       13.6 |       12.0 |    1.13x |      +1.6
  varlen      4    8   16  128  128 |       15.8 |       14.5 |    1.09x |      +1.3
  varlen      8    8   16  128  128 |       19.4 |       19.7 |    0.99x |      -0.3
  varlen     16    8   16  128  128 |       29.9 |       28.4 |    1.05x |      +1.5
  varlen     32    8   16  128  128 |       56.6 |       49.9 |    1.14x |      +6.8
  varlen     64    8   16  128  128 |      101.8 |       97.2 |    1.05x |      +4.6
  varlen    128    8   16  128  128 |      181.5 |      178.2 |    1.02x |      +3.4
  varlen     32   16   32  128  128 |      105.2 |       99.5 |    1.06x |      +5.7
  varlen     64   16   16  128  128 |      103.3 |      100.2 |    1.03x |      +3.0
Device: NVIDIA H100 80GB HBM3  (SM 90)
==============================================================================
Correctness: Triton KDA Decode vs CuTe DSL KDA Decode
==============================================================================
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:232: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:252: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=dense  B= 1 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=dense  B= 4 H= 8 HV=16 K=128 V=128 pool=  32
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:702: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:719: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=dense  B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=dense  B=64 H= 8 HV=16 K=128 V=128 pool= 128
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:472: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:492: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=varlen B= 4 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=varlen B=16 H= 8 HV=16 K=128 V=128 pool=  64
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:927: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:944: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=varlen B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B= 1 H=16 HV=32 K=128 V=128 pool=  32
  [PASS] layout=varlen B=32 H=16 HV=32 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H=16 HV=16 K=128 V=128 pool= 128

  PAD_SLOT_ID test (indices with -1):
  [PASS] PAD_SLOT_ID=-1 handling (valid outputs/states only)

ALL PASSED.

============================================================================================
Benchmark: Triton vs CuTe DSL KDA Decode  [CUDA Graph replay — production mode]
============================================================================================
  Config: K=128, V=128, pool_size=512, dtype=torch.bfloat16
  layout      B    H   HV    K    V |  triton (us) |  cutedsl (us) |  speedup |  delta (us)
  ----------------------------------------------------------------------------------
   dense      1    8   16  128  128 |        9.2 |        8.4 |    1.10x |      +0.9
   dense      4    8   16  128  128 |        9.6 |        9.3 |    1.03x |      +0.3
   dense     32    8   16  128  128 |       24.1 |       23.6 |    1.02x |      +0.5
   dense     64    8   16  128  128 |       60.4 |       61.4 |    0.98x |      -1.0
  varlen      1    8   16  128  128 |        9.3 |        8.2 |    1.13x |      +1.1
  varlen      4    8   16  128  128 |        9.8 |        9.2 |    1.06x |      +0.5
  varlen      8    8   16  128  128 |       10.9 |       10.9 |    1.00x |      -0.0
  varlen     16    8   16  128  128 |       14.1 |       15.6 |    0.90x |      -1.5
  varlen     32    8   16  128  128 |       24.1 |       23.5 |    1.02x |      +0.5
  varlen     64    8   16  128  128 |       59.4 |       61.6 |    0.96x |      -2.2
  varlen    128    8   16  128  128 |      103.7 |      107.5 |    0.96x |      -3.8
  varlen     32   16   32  128  128 |       59.0 |       61.2 |    0.96x |      -2.2
  varlen     64   16   16  128  128 |       59.6 |       61.7 |    0.97x |      -2.0
Device: NVIDIA H200  (SM 90)
==============================================================================
Correctness: Triton KDA Decode vs CuTe DSL KDA Decode
==============================================================================
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:232: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:252: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=dense  B= 1 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=dense  B= 4 H= 8 HV=16 K=128 V=128 pool=  32
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:702: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:719: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=dense  B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=dense  B=64 H= 8 HV=16 K=128 V=128 pool= 128
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:472: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:492: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=varlen B= 4 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=varlen B=16 H= 8 HV=16 K=128 V=128 pool=  64
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:927: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:944: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=varlen B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B= 1 H=16 HV=32 K=128 V=128 pool=  32
  [PASS] layout=varlen B=32 H=16 HV=32 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H=16 HV=16 K=128 V=128 pool= 128

  PAD_SLOT_ID test (indices with -1):
  [PASS] PAD_SLOT_ID=-1 handling (valid outputs/states only)

ALL PASSED.

============================================================================================
Benchmark: Triton vs CuTe DSL KDA Decode  [CUDA Graph replay — production mode]
============================================================================================
  Config: K=128, V=128, pool_size=512, dtype=torch.bfloat16
  layout      B    H   HV    K    V |  triton (us) |  cutedsl (us) |  speedup |  delta (us)
  ----------------------------------------------------------------------------------
   dense      1    8   16  128  128 |       14.8 |       14.5 |    1.02x |      +0.3
   dense      4    8   16  128  128 |       15.2 |       14.8 |    1.03x |      +0.4
   dense     32    8   16  128  128 |       22.2 |       22.5 |    0.99x |      -0.3
   dense     64    8   16  128  128 |       51.9 |       47.5 |    1.09x |      +4.4
  varlen      1    8   16  128  128 |       30.0 |       26.7 |    1.13x |      +3.4
  varlen      4    8   16  128  128 |       15.3 |       14.7 |    1.04x |      +0.6
  varlen      8    8   16  128  128 |       16.1 |       13.2 |    1.22x |      +2.9
  varlen     16    8   16  128  128 |       15.0 |       17.0 |    0.88x |      -2.0
  varlen     32    8   16  128  128 |       22.2 |       22.4 |    0.99x |      -0.2
  varlen     64    8   16  128  128 |       50.3 |       47.1 |    1.07x |      +3.1
  varlen    128    8   16  128  128 |       90.3 |       84.4 |    1.07x |      +5.9
  varlen     32   16   32  128  128 |       50.7 |       53.0 |    0.96x |      -2.3
  varlen     64   16   16  128  128 |       50.8 |       47.2 |    1.08x |      +3.6

@yuan-luo
Collaborator Author

Thanks for the PR. I'm mainly concerned about this plan item:

Support triton prefill KDA kernel with KV and VK layout switch off.

I think we have agreed to standardize on the VK layout everywhere in the long term to avoid complexity. So instead of adding KV back to the KDA Triton kernel, I think we should move the CuTe DSL KDA kernel to the VK layout, consistent with the direction Triton is already moving, rather than adding yet another KV kernel that would need migration later.

@BBuf @kaixih After investigating further, I found this is not a trivial change to VK: we would need to change the logic of the varlen and non-varlen APIs and the normalization, and make q/k/v/a/b/A_log/dt_bias contiguous, which introduces overhead and impacts performance. It is comparable in scope to implementing the GDN decode logic in FlashInfer. I think we should keep a KV-layout prefill path for CuTeDSL KDA for now, so this PR can move forward as an independent kernel first, and complete the e2e CuTeDSL integration later.
Now that prefill has been changed to the VK layout, GDN in CuTeDSL is also failing e2e. We have the same issue; I'll fix both in a follow-up PR.
Is this OK with you? Thanks.

If the long-term direction is to standardize on the VK layout, then merging this now may not add much value; we can probably pause this PR for now instead of merging code that would effectively become dead.

@BBuf This code will not be dead; CuTeDSL VK-layout support is in progress and is expected in the next PR. Can we move this baseline PR forward? It does not impact existing functionality.

  varlen    128    8   16  128  128 |      181.5 |      178.2 |    1.02x |      +3.4
  varlen     32   16   32  128  128 |      105.2 |       99.5 |    1.06x |      +5.7
  varlen     64   16   16  128  128 |      103.3 |      100.2 |    1.03x |      +3.0
Device: NVIDIA H100 80GB HBM3  (SM 90)
==============================================================================
Correctness: Triton KDA Decode vs CuTe DSL KDA Decode
==============================================================================
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:232: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:252: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=dense  B= 1 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=dense  B= 4 H= 8 HV=16 K=128 V=128 pool=  32
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:702: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:719: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=dense  B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=dense  B=64 H= 8 HV=16 K=128 V=128 pool= 128
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:472: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:492: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=varlen B= 4 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=varlen B=16 H= 8 HV=16 K=128 V=128 pool=  64
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:927: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:944: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=varlen B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B= 1 H=16 HV=32 K=128 V=128 pool=  32
  [PASS] layout=varlen B=32 H=16 HV=32 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H=16 HV=16 K=128 V=128 pool= 128

  PAD_SLOT_ID test (indices with -1):
  [PASS] PAD_SLOT_ID=-1 handling (valid outputs/states only)

ALL PASSED.

============================================================================================
Benchmark: Triton vs CuTe DSL KDA Decode  [CUDA Graph replay — production mode]
============================================================================================
  Config: K=128, V=128, pool_size=512, dtype=torch.bfloat16
  layout      B    H   HV    K    V |  triton (us) |  cutedsl (us) |  speedup |  delta (us)
  ----------------------------------------------------------------------------------
   dense      1    8   16  128  128 |        9.2 |        8.4 |    1.10x |      +0.9
   dense      4    8   16  128  128 |        9.6 |        9.3 |    1.03x |      +0.3
   dense     32    8   16  128  128 |       24.1 |       23.6 |    1.02x |      +0.5
   dense     64    8   16  128  128 |       60.4 |       61.4 |    0.98x |      -1.0
  varlen      1    8   16  128  128 |        9.3 |        8.2 |    1.13x |      +1.1
  varlen      4    8   16  128  128 |        9.8 |        9.2 |    1.06x |      +0.5
  varlen      8    8   16  128  128 |       10.9 |       10.9 |    1.00x |      -0.0
  varlen     16    8   16  128  128 |       14.1 |       15.6 |    0.90x |      -1.5
  varlen     32    8   16  128  128 |       24.1 |       23.5 |    1.02x |      +0.5
  varlen     64    8   16  128  128 |       59.4 |       61.6 |    0.96x |      -2.2
  varlen    128    8   16  128  128 |      103.7 |      107.5 |    0.96x |      -3.8
  varlen     32   16   32  128  128 |       59.0 |       61.2 |    0.96x |      -2.2
  varlen     64   16   16  128  128 |       59.6 |       61.7 |    0.97x |      -2.0
Device: NVIDIA H200  (SM 90)
==============================================================================
Correctness: Triton KDA Decode vs CuTe DSL KDA Decode
==============================================================================
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:232: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:252: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=dense  B= 1 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=dense  B= 4 H= 8 HV=16 K=128 V=128 pool=  32
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:702: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:719: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=dense  B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=dense  B=64 H= 8 HV=16 K=128 V=128 pool= 128
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:472: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:492: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS_SMALL, unroll=16):
  [PASS] layout=varlen B= 4 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] layout=varlen B=16 H= 8 HV=16 K=128 V=128 pool=  64
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:927: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
/workspace/sglang_eagle/python/sglang/jit_kernel/cutedsl_kda.py:944: DeprecationWarning: range_dynamic is deprecated and will be removed in the future, please remove it.
  for k_iter in cutlass.range_dynamic(NUM_K_ITERS, unroll=8):
  [PASS] layout=varlen B=32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] layout=varlen B= 1 H=16 HV=32 K=128 V=128 pool=  32
  [PASS] layout=varlen B=32 H=16 HV=32 K=128 V=128 pool= 128
  [PASS] layout=varlen B=64 H=16 HV=16 K=128 V=128 pool= 128

  PAD_SLOT_ID test (indices with -1):
  [PASS] PAD_SLOT_ID=-1 handling (valid outputs/states only)

ALL PASSED.

============================================================================================
Benchmark: Triton vs CuTe DSL KDA Decode  [CUDA Graph replay — production mode]
============================================================================================
  Config: K=128, V=128, pool_size=512, dtype=torch.bfloat16
  layout      B    H   HV    K    V |  triton (us) |  cutedsl (us) |  speedup |  delta (us)
  ----------------------------------------------------------------------------------
   dense      1    8   16  128  128 |       14.8 |       14.5 |    1.02x |      +0.3
   dense      4    8   16  128  128 |       15.2 |       14.8 |    1.03x |      +0.4
   dense     32    8   16  128  128 |       22.2 |       22.5 |    0.99x |      -0.3
   dense     64    8   16  128  128 |       51.9 |       47.5 |    1.09x |      +4.4
  varlen      1    8   16  128  128 |       30.0 |       26.7 |    1.13x |      +3.4
  varlen      4    8   16  128  128 |       15.3 |       14.7 |    1.04x |      +0.6
  varlen      8    8   16  128  128 |       16.1 |       13.2 |    1.22x |      +2.9
  varlen     16    8   16  128  128 |       15.0 |       17.0 |    0.88x |      -2.0
  varlen     32    8   16  128  128 |       22.2 |       22.4 |    0.99x |      -0.2
  varlen     64    8   16  128  128 |       50.3 |       47.1 |    1.07x |      +3.1
  varlen    128    8   16  128  128 |       90.3 |       84.4 |    1.07x |      +5.9
  varlen     32   16   32  128  128 |       50.7 |       53.0 |    0.96x |      -2.3
  varlen     64   16   16  128  128 |       50.8 |       47.2 |    1.08x |      +3.6

@edwingao28 Thank you very much for your report.

@yuan-luo
Copy link
Collaborator Author

/rerun-failed-ci

@yuan-luo
Copy link
Collaborator Author

yuan-luo commented Mar 24, 2026

@BBuf I wrote a test script to verify that the KDA prefill Triton kernel is non-deterministic, which proves two things:

  1. The CuTeDSL decode kernel is correct: Case 2 and Case 4 passed.
  2. The output of the KDA prefill Triton kernel can't be combined with the CuTeDSL KDA decode kernel, because the prefill itself is non-deterministic. We need to implement a CuTeDSL KDA prefill kernel as the next step.

Shall we move forward with this PR?

import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "python"))

import torch
from sglang.jit_kernel.cutedsl_kda import cutedsl_fused_sigmoid_gating_kda_update
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
    fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.fla.kda import chunk_kda


def make_inputs(B, H, HV, K, V, pool_size, layout, device="cuda", dtype=torch.bfloat16, seed=42):
    torch.manual_seed(seed)
    if layout == "varlen":
        q = torch.randn(1, B, H, K, device=device, dtype=dtype)
        k = torch.randn(1, B, H, K, device=device, dtype=dtype)
        v = torch.randn(1, B, HV, V, device=device, dtype=dtype)
        a = torch.randn(B, HV, K, device=device, dtype=dtype)
        b = torch.randn(B, HV, device=device, dtype=dtype)
        prefill_g = torch.randn(1, B, HV, K, device=device, dtype=dtype)
        prefill_beta = torch.sigmoid(torch.randn(1, B, HV, device=device, dtype=dtype))
        cu = torch.arange(B + 1, device=device, dtype=torch.int32)
    else:
        q = torch.randn(B, 1, H, K, device=device, dtype=dtype)
        k = torch.randn(B, 1, H, K, device=device, dtype=dtype)
        v = torch.randn(B, 1, HV, V, device=device, dtype=dtype)
        a = torch.randn(B, 1, HV, K, device=device, dtype=dtype)
        b = torch.randn(B, 1, HV, device=device, dtype=dtype)
        prefill_g = torch.randn(B, 1, HV, K, device=device, dtype=dtype)
        prefill_beta = torch.sigmoid(torch.randn(B, 1, HV, device=device, dtype=dtype))
        cu = torch.arange(B + 1, device=device, dtype=torch.int32)

    A_log = torch.randn(HV, device=device, dtype=torch.float32)
    dt_bias = torch.randn(HV, K, device=device, dtype=dtype)
    ssm_states = torch.randn(pool_size, HV, V, K, device=device, dtype=torch.float32) * 0.1
    cache_indices = torch.arange(B, device=device, dtype=torch.int32)

    return dict(q=q, k=k, v=v, a=a, b=b, prefill_g=prefill_g, prefill_beta=prefill_beta,
                A_log=A_log, dt_bias=dt_bias, ssm_states=ssm_states, cache_indices=cache_indices,
                cu_seqlens=cu, layout=layout)


def run_prefill_only(inp):
    """Run ONLY prefill, return the post-prefill state."""
    ssm_states = inp["ssm_states"].clone()
    v_clone = inp["v"].clone()
    g_clone = inp["prefill_g"].clone()
    beta_clone = inp["prefill_beta"].clone()
    q_clone = inp["q"].clone()
    k_clone = inp["k"].clone()
    
    _ = chunk_kda(
        q=q_clone, k=k_clone, v=v_clone,
        g=g_clone, beta=beta_clone,
        initial_state=ssm_states,
        initial_state_indices=inp["cache_indices"],
        use_qk_l2norm_in_kernel=True,
        cu_seqlens=inp["cu_seqlens"] if inp["layout"] == "varlen" else None,
    )
    return ssm_states, v_clone


def main():
    device = "cuda"
    dtype = torch.bfloat16
    
    configs = [
        ("dense",  1, 8, 16, 128, 128, 32),
        ("dense",  4, 8, 16, 128, 128, 32),
        ("varlen", 4, 8, 16, 128, 128, 32),
    ]
    
    for layout, B, H, HV, K, V, pool_size in configs:
        print(f"\n{'='*80}")
        print(f"Config: {layout} B={B} H={H} HV={HV}")
        print(f"{'='*80}")
        
        # ========================================================
        # TEST 1: Is Triton prefill deterministic?
        # ========================================================
        print("\n--- TEST 1: Triton prefill determinism ---")
        inp1 = make_inputs(B, H, HV, K, V, pool_size, layout, seed=42)
        inp2 = make_inputs(B, H, HV, K, V, pool_size, layout, seed=42)
        
        # Verify inputs are identical
        input_diff = (inp1["ssm_states"] - inp2["ssm_states"]).abs().max().item()
        print(f"  Input state diff (should be 0): {input_diff}")
        
        state1, v1 = run_prefill_only(inp1)
        state2, v2 = run_prefill_only(inp2)
        torch.cuda.synchronize()
        
        state_diff = (state1 - state2).abs().max().item()
        v_diff = (v1 - v2).abs().max().item()
        print(f"  Post-prefill state diff (should be 0 if deterministic): {state_diff}")
        print(f"  Post-prefill v diff (should be 0 if deterministic): {v_diff}")
        print(f"  Post-prefill state magnitude: max={state1.abs().max().item():.4f}, mean={state1.abs().mean().item():.4f}")
        print(f"  Post-prefill v magnitude: max={v1.abs().max().item():.4f}, mean={v1.abs().mean().item():.4f}")
        
        if state_diff > 1e-6:
            print(f"  *** PREFILL IS NON-DETERMINISTIC! diff={state_diff} ***")
            print(f"  This means two calls to chunk_kda produce different states.")
            print(f"  The prefill->decode comparison is invalid.")
            continue
        
        # ========================================================
        # TEST 2: From SAME post-prefill state, compare decode
        # ========================================================
        print("\n--- TEST 2: Decode from same post-prefill state ---")
        
        # Use state1 as the ground truth post-prefill state
        post_prefill_state = state1
        post_prefill_v = v1
        
        inp = make_inputs(B, H, HV, K, V, pool_size, layout, seed=42)
        
        # Triton decode
        st_triton = post_prefill_state.clone()
        o_triton = fused_sigmoid_gating_delta_rule_update(
            A_log=inp["A_log"], dt_bias=inp["dt_bias"],
            q=inp["q"], k=inp["k"], v=post_prefill_v.clone(),
            a=inp["a"], b=inp["b"],
            initial_state_source=st_triton,
            initial_state_indices=inp["cache_indices"],
            cu_seqlens=inp["cu_seqlens"] if layout == "varlen" else None,
            use_qk_l2norm_in_kernel=True,
            softplus_beta=1.0, softplus_threshold=20.0, is_kda=True,
        )
        torch.cuda.synchronize()
        
        # CuTe decode
        st_cute = post_prefill_state.clone()
        o_cute = cutedsl_fused_sigmoid_gating_kda_update(
            A_log=inp["A_log"], dt_bias=inp["dt_bias"],
            q=inp["q"], k=inp["k"], v=post_prefill_v.clone(),
            a=inp["a"], b=inp["b"],
            initial_state_source=st_cute,
            initial_state_indices=inp["cache_indices"],
            cu_seqlens=inp["cu_seqlens"] if layout == "varlen" else None,
            use_qk_l2norm_in_kernel=True,
            softplus_beta=1.0, softplus_threshold=20.0,
        )
        torch.cuda.synchronize()
        
        valid = inp["cache_indices"][inp["cache_indices"] >= 0]
        o_diff = (o_triton.float() - o_cute.float()).abs().max().item()
        s_diff = (st_triton[valid].float() - st_cute[valid].float()).abs().max().item()
        
        print(f"  Output diff (Triton vs CuTe): {o_diff:.6e}")
        print(f"  State diff (Triton vs CuTe): {s_diff:.6e}")
        print(f"  Output magnitude: {o_triton.float().abs().max().item():.4f}")
        print(f"  State magnitude: {st_triton[valid].float().abs().max().item():.4f}")
        
        if o_diff < 3e-2 and s_diff < 3e-2:
            print(f"  *** TEST 2 PASS: Decode is correct from post-prefill state ***")
        else:
            print(f"  *** TEST 2 FAIL: Decode diverges from post-prefill state ***")
        
        # ========================================================
        # TEST 3: Full chain comparison (the original benchmark)
        # ========================================================
        print("\n--- TEST 3: Full prefill->decode chain ---")
        
        inp_a = make_inputs(B, H, HV, K, V, pool_size, layout, seed=42)
        inp_b = make_inputs(B, H, HV, K, V, pool_size, layout, seed=42)
        
        # Path A: Triton prefill → Triton decode
        st_a = inp_a["ssm_states"].clone()
        v_a = inp_a["v"].clone()
        g_a = inp_a["prefill_g"].clone()
        beta_a = inp_a["prefill_beta"].clone()
        q_a = inp_a["q"].clone()
        k_a = inp_a["k"].clone()
        _ = chunk_kda(q=q_a, k=k_a, v=v_a, g=g_a, beta=beta_a,
                      initial_state=st_a, initial_state_indices=inp_a["cache_indices"],
                      use_qk_l2norm_in_kernel=True,
                      cu_seqlens=inp_a["cu_seqlens"] if layout == "varlen" else None)
        torch.cuda.synchronize()
        state_after_prefill_a = st_a.clone()  # snapshot
        
        o_a = fused_sigmoid_gating_delta_rule_update(
            A_log=inp_a["A_log"], dt_bias=inp_a["dt_bias"],
            q=inp_a["q"], k=inp_a["k"], v=v_a,
            a=inp_a["a"], b=inp_a["b"],
            initial_state_source=st_a,
            initial_state_indices=inp_a["cache_indices"],
            cu_seqlens=inp_a["cu_seqlens"] if layout == "varlen" else None,
            use_qk_l2norm_in_kernel=True,
            softplus_beta=1.0, softplus_threshold=20.0, is_kda=True,
        )
        torch.cuda.synchronize()
        
        # Path B: Triton prefill → CuTe decode
        st_b = inp_b["ssm_states"].clone()
        v_b = inp_b["v"].clone()
        g_b = inp_b["prefill_g"].clone()
        beta_b = inp_b["prefill_beta"].clone()
        q_b = inp_b["q"].clone()
        k_b = inp_b["k"].clone()
        _ = chunk_kda(q=q_b, k=k_b, v=v_b, g=g_b, beta=beta_b,
                      initial_state=st_b, initial_state_indices=inp_b["cache_indices"],
                      use_qk_l2norm_in_kernel=True,
                      cu_seqlens=inp_b["cu_seqlens"] if layout == "varlen" else None)
        torch.cuda.synchronize()
        state_after_prefill_b = st_b.clone()  # snapshot
        
        o_b = cutedsl_fused_sigmoid_gating_kda_update(
            A_log=inp_b["A_log"], dt_bias=inp_b["dt_bias"],
            q=inp_b["q"], k=inp_b["k"], v=v_b,
            a=inp_b["a"], b=inp_b["b"],
            initial_state_source=st_b,
            initial_state_indices=inp_b["cache_indices"],
            cu_seqlens=inp_b["cu_seqlens"] if layout == "varlen" else None,
            use_qk_l2norm_in_kernel=True,
            softplus_beta=1.0, softplus_threshold=20.0,
        )
        torch.cuda.synchronize()
        
        valid_a = inp_a["cache_indices"][inp_a["cache_indices"] >= 0]
        valid_b = inp_b["cache_indices"][inp_b["cache_indices"] >= 0]
        
        # Compare post-prefill states
        prefill_state_diff = (state_after_prefill_a[valid_a] - state_after_prefill_b[valid_b]).abs().max().item()
        print(f"  Post-prefill state diff (A vs B, should be 0): {prefill_state_diff:.6e}")
        
        # Compare post-prefill v
        v_diff_chain = (v_a - v_b).abs().max().item()
        print(f"  Post-prefill v diff (A vs B, should be 0): {v_diff_chain:.6e}")
        
        # Compare decode inputs (q, k, a, b)
        q_diff = (inp_a["q"] - inp_b["q"]).abs().max().item()
        k_diff = (inp_a["k"] - inp_b["k"]).abs().max().item()
        a_diff = (inp_a["a"] - inp_b["a"]).abs().max().item()
        b_diff = (inp_a["b"] - inp_b["b"]).abs().max().item()
        print(f"  Decode input diffs: q={q_diff:.6e}, k={k_diff:.6e}, a={a_diff:.6e}, b={b_diff:.6e}")
        
        # Compare final results
        o_diff_chain = (o_a.float() - o_b.float()).abs().max().item()
        s_diff_chain = (st_a[valid_a].float() - st_b[valid_b].float()).abs().max().item()
        print(f"  Final output diff: {o_diff_chain:.6e}")
        print(f"  Final state diff: {s_diff_chain:.6e}")
        
        if prefill_state_diff > 1e-6:
            print(f"\n  *** ROOT CAUSE: Prefill produces different states! ***")
            print(f"  *** This means chunk_kda is non-deterministic ***")
        elif o_diff_chain > 3e-2 or s_diff_chain > 3e-2:
            print(f"\n  *** ROOT CAUSE: Decode diverges despite identical inputs ***")
            print(f"  *** This contradicts decode-only test results ***")
            print(f"  *** Something changes between decode-only and chain context ***")
        else:
            print(f"\n  *** CHAIN PASSES when properly isolated ***")


if __name__ == "__main__":
    main()

The results are as follows:

================================================================================
Config: dense B=1 H=8 HV=16
================================================================================

--- TEST 1: Triton prefill determinism ---
  Input state diff (should be 0): 0.0
  Post-prefill state diff (should be 0 if deterministic): 0.10302734375
  Post-prefill v diff (should be 0 if deterministic): 0.0
  Post-prefill state magnitude: max=10029334636355693967291541585133568.0000, mean=1552264832509825059696194617344.0000
  Post-prefill v magnitude: max=3.0625, mean=0.4082
  *** PREFILL IS NON-DETERMINISTIC! diff=0.10302734375 ***
  This means two calls to chunk_kda produce different states.
  The prefill->decode comparison is invalid.

================================================================================
Config: dense B=4 H=8 HV=16
================================================================================

--- TEST 1: Triton prefill determinism ---
  Input state diff (should be 0): 0.0
  Post-prefill state diff (should be 0 if deterministic): 1232.0
  Post-prefill v diff (should be 0 if deterministic): 0.0
  Post-prefill state magnitude: max=11462096727263650248333190383009792.0000, mean=1911517389315790663998811144192.0000
  Post-prefill v magnitude: max=3.9062, mean=0.4102
  *** PREFILL IS NON-DETERMINISTIC! diff=1232.0 ***
  This means two calls to chunk_kda produce different states.
  The prefill->decode comparison is invalid.

================================================================================
Config: varlen B=4 H=8 HV=16
================================================================================

--- TEST 1: Triton prefill determinism ---
  Input state diff (should be 0): 0.0
  Post-prefill state diff (should be 0 if deterministic): 0.0
  Post-prefill v diff (should be 0 if deterministic): 0.0
  Post-prefill state magnitude: max=11462096727263650248333190383009792.0000, mean=1911517389315790663998811144192.0000
  Post-prefill v magnitude: max=3.9062, mean=0.4102

--- TEST 2: Decode from same post-prefill state ---
cutedsl_kda loaded from: /usr/local/lib/python3.12/dist-packages/sglang/jit_kernel/cutedsl_kda.py
initial_state_source: shape=torch.Size([32, 16, 128, 128]), stride=(262144, 16384, 128, 1), contiguous=True, data_ptr=140301760987136
h0_source after contiguous: shape=torch.Size([32, 16, 128, 128]), stride=(262144, 16384, 128, 1), contiguous=True, data_ptr=140301760987136
  Output diff (Triton vs CuTe): 6.103516e-05
  State diff (Triton vs CuTe): 3.249593e+27
  Output magnitude: 29631332780334862259985437425664.0000
  State magnitude: 2523500691774475826195464871477248.0000
  *** TEST 2 FAIL: Decode diverges from post-prefill state ***

--- TEST 3: Full prefill->decode chain ---
cutedsl_kda loaded from: /usr/local/lib/python3.12/dist-packages/sglang/jit_kernel/cutedsl_kda.py
initial_state_source: shape=torch.Size([32, 16, 128, 128]), stride=(262144, 16384, 128, 1), contiguous=True, data_ptr=140301928759296
h0_source after contiguous: shape=torch.Size([32, 16, 128, 128]), stride=(262144, 16384, 128, 1), contiguous=True, data_ptr=140301928759296
  Post-prefill state diff (A vs B, should be 0): 3.453125e+00
  Post-prefill v diff (A vs B, should be 0): 0.000000e+00
  Decode input diffs: q=0.000000e+00, k=0.000000e+00, a=0.000000e+00, b=0.000000e+00
  Final output diff: 3.125000e-02
  Final state diff: 3.249593e+27

  *** ROOT CAUSE: Prefill produces different states! ***
  *** This means chunk_kda is non-deterministic ***
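For context on why a kernel can return different states from identical inputs: floating-point addition is not associative, so any parallel reduction whose accumulation order varies between runs (e.g. atomic adds from concurrently scheduled blocks) is non-deterministic at the bit level. This is a general illustration, independent of chunk_kda's internals:

```python
# Floating-point addition is not associative: summing the same values in a
# different order can change the result. A kernel that accumulates partial
# sums in a run-dependent order (e.g. via atomicAdd) can therefore produce
# different outputs from identical inputs, and the error compounds when a
# large recurrent state is rebuilt every run, as in the prefill above.
a = (0.1 + 0.2) + 0.3  # left-to-right accumulation
b = 0.1 + (0.2 + 0.3)  # same values, different grouping
print(a, b, a == b)    # the two sums differ in the last bits

assert a != b
```

This is why TEST 1 compares two back-to-back prefill runs on bitwise-identical inputs: any nonzero state diff isolates the non-determinism to the prefill kernel itself rather than to the decode comparison.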

@yuan-luo yuan-luo force-pushed the cutedsl_kda_decode branch from 45e143e to 6450a36 Compare March 24, 2026 15:50
@yuan-luo yuan-luo force-pushed the cutedsl_kda_decode branch from 6450a36 to 2344dfa Compare March 25, 2026 00:17
@BBuf BBuf merged commit f273ba1 into sgl-project:main Mar 25, 2026
51 of 66 checks passed
@yuan-luo yuan-luo deleted the cutedsl_kda_decode branch March 25, 2026 02:27
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
johnnycxm pushed a commit to johnnycxm/sglang that referenced this pull request Mar 25, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>