[Perf][Kernel] Improve topKperRow for large context decode path - DeepSeek-V3.2 sparse attention #34265
Conversation
Signed-off-by: LopezCastroRoberto <[email protected]>
Code Review
This pull request replaces the custom large_context_topk kernel with flashinfer.top_k_ragged_transform for top-k operations in the large context decode path. The changes primarily update sparse_attn_indexer.py to use the FlashInfer function and pass a new offsets_buffer, with corresponding API-compatibility changes in the ROCm path. The tests are updated to validate the new implementation. My review found a critical issue in the test file: a new test function shadows an existing one because they share the same name, and it also misuses a pytest parameter. I've provided a suggestion to fix this.
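For readers unfamiliar with the ragged layout involved, a minimal sketch of what an offsets buffer of this kind usually encodes is shown below; the exact semantics of offsets_buffer in this PR (and the cu_seqlens-style convention assumed here) are an assumption, not taken from the diff.

import torch

# Hypothetical illustration only (not the PR's actual code): for a ragged decode
# batch where each request contributes a different number of candidate logits,
# the offsets buffer is assumed to be a cumulative-length (cu_seqlens-style)
# tensor marking where each row starts inside the flattened logits.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

offsets_buffer = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
offsets_buffer[1:] = torch.cumsum(seq_lens, dim=0)
# offsets_buffer -> tensor([0, 5, 8, 15]); row i spans offsets_buffer[i]:offsets_buffer[i+1]

flat_logits = torch.randn(int(seq_lens.sum()))  # all rows' logits concatenated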
tests/kernels/test_top_k_per_row.py
Outdated
def test_deepseek_hybrid_topk(clean_logits: bool, top_k: int) -> None:
    torch.set_default_device("cuda:0")
    ...
    top_k = 2048
This new test function test_deepseek_hybrid_topk has the same name as an existing test function at line 286. In Python, this will cause the new function to overwrite the old one, and the original test for torch.ops._C.large_context_topk will no longer be executed. Please rename this new test function to avoid this conflict, for example to test_deepseek_hybrid_topk_flashinfer.
Additionally, the top_k parameter from pytest.mark.parametrize is immediately overwritten on line 401. This makes the parameterization ineffective. Please remove the hardcoded value to allow the test to run with different top_k values as intended.
Suggested change:
- def test_deepseek_hybrid_topk(clean_logits: bool, top_k: int) -> None:
-     torch.set_default_device("cuda:0")
-     top_k = 2048
+ def test_deepseek_hybrid_topk_flashinfer(clean_logits: bool, top_k: int) -> None:
+     torch.set_default_device("cuda:0")
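A minimal sketch of how the two review points combine, assuming illustrative parametrize values (not taken from the PR):

import pytest
import torch


@pytest.mark.parametrize("clean_logits", [True, False])
@pytest.mark.parametrize("top_k", [1024, 2048])  # illustrative values, not from the PR
def test_deepseek_hybrid_topk_flashinfer(clean_logits: bool, top_k: int) -> None:
    torch.set_default_device("cuda:0")
    # top_k is no longer hardcoded, so each parametrized value is actually exercised.
    ...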
Force-pushed from bbba437 to 2d74e0f
Signed-off-by: LopezCastroRoberto <[email protected]>
Summary
This PR integrates FlashInfer's radix-based top-k kernel as an alternative implementation for the large context top-k operation in the sparse attention indexer, specifically for DeepSeek-V3.2 models.
Kernel adapted from: flashinfer-ai/flashinfer#2215
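As a rough reference for the operation the kernel accelerates (a naive sketch, not the FlashInfer radix kernel or vLLM's code; names and the -1 padding are assumptions for illustration), top-k per row over ragged logits can be written in plain PyTorch as:

import torch


def topk_per_row_reference(
    flat_logits: torch.Tensor,  # all rows' logits concatenated, shape [total_len]
    offsets: torch.Tensor,      # row boundaries, shape [num_rows + 1]
    top_k: int,
) -> torch.Tensor:
    """Naive per-row top-k over a ragged layout: returns row-local indices,
    padded with -1 when a row has fewer than top_k entries."""
    num_rows = offsets.numel() - 1
    out = torch.full((num_rows, top_k), -1, dtype=torch.long, device=flat_logits.device)
    for i in range(num_rows):
        row = flat_logits[offsets[i]:offsets[i + 1]]
        k = min(top_k, row.numel())
        out[i, :k] = torch.topk(row, k).indices
    return out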
Motivation - Microbenchmark results (NVIDIA B200)
E2E results (NVIDIA B200)
MAIN:
PR:
In this example, the PR improves throughput over MAIN by ~10%.
Accuracy
python tests/evals/gsm8k/gsm8k_eval.py
MAIN:
PR: