[Perf][Kernel] Improve topKperRow for large context decode path - DeepSeek-V3.2 sparse attention #34265

Draft

LopezCastroRoberto wants to merge 2 commits into vllm-project:main from LopezCastroRoberto:perf/topKperRow-FI

Conversation

@LopezCastroRoberto (Contributor) commented Feb 10, 2026

Summary

This PR integrates FlashInfer's radix-based top-k kernel as an alternative implementation for the large context top-k operation in the sparse attention indexer, specifically for DeepSeek-V3.2 models.

Kernel adapted from: flashinfer-ai/flashinfer#2215
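
For context, the operation being accelerated is a per-row top-k over the indexer logits of each decode token. Below is a minimal reference sketch of the dispatch idea, assuming a generic fast_topk callable and an arbitrary context-length threshold (neither is the actual vLLM code; torch.topk is used only as the reference path):

import torch

LARGE_CONTEXT_THRESHOLD = 8192  # assumption: where a "large context" path would kick in

def topk_per_row(logits: torch.Tensor, top_k: int, context_len: int,
                 fast_topk=None) -> torch.Tensor:
    """Return the top-k token indices per row of the indexer logits."""
    k = min(top_k, logits.shape[-1])
    if fast_topk is not None and context_len > LARGE_CONTEXT_THRESHOLD:
        # Radix-select style kernels (e.g. the FlashInfer kernel referenced above)
        # avoid a full sort over very long rows.
        return fast_topk(logits, k)
    # Reference path: torch.topk computes the same selection, just more slowly
    # for very long rows.
    return torch.topk(logits, k=k, dim=-1).indices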

Motivation - Microbenchmark results (NVIDIA B200)

[Figure: topKperRow kernel comparison microbenchmark]

E2E results (NVIDIA B200)

vllm serve nvidia/DeepSeek-V3.2-NVFP4 -tp 4
vllm bench serve --backend vllm --model nvidia/DeepSeek-V3.2-NVFP4 --input-len 128000 --output-len 4096 --num-prompts 1

MAIN:

============ Serving Benchmark Result ============
Successful requests:                     1         
Failed requests:                         0         
Benchmark duration (s):                  59.15     
Total input tokens:                      128000    
Total generated tokens:                  4096      
Request throughput (req/s):              0.02      
Output token throughput (tok/s):         69.24     
Peak output token throughput (tok/s):    71.00     
Peak concurrent requests:                1.00      
Total token throughput (tok/s):          2233.14   
---------------Time to First Token----------------
Mean TTFT (ms):                          717.80    
Median TTFT (ms):                        717.80    
P99 TTFT (ms):                           717.80    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.27     
Median TPOT (ms):                        14.27     
P99 TPOT (ms):                           14.27     
---------------Inter-token Latency----------------
Mean ITL (ms):                           14.27     
Median ITL (ms):                         14.27     
P99 ITL (ms):                            14.57     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     1         
Failed requests:                         0         
Benchmark duration (s):                  53.62     
Total input tokens:                      128000    
Total generated tokens:                  4096      
Request throughput (req/s):              0.02      
Output token throughput (tok/s):         76.39     
Peak output token throughput (tok/s):    80.00     
Peak concurrent requests:                1.00      
Total token throughput (tok/s):          2463.46   
---------------Time to First Token----------------
Mean TTFT (ms):                          732.15    
Median TTFT (ms):                        732.15    
P99 TTFT (ms):                           732.15    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          12.92     
Median TPOT (ms):                        12.92     
P99 TPOT (ms):                           12.92     
---------------Inter-token Latency----------------
Mean ITL (ms):                           12.92     
Median ITL (ms):                         12.72     
P99 ITL (ms):                            13.81     
==================================================

In this example, the PR improves output token throughput over MAIN by ~10% (69.24 → 76.39 tok/s) and reduces mean TPOT from 14.27 ms to 12.92 ms.

Accuracy

python tests/evals/gsm8k/gsm8k_eval.py

MAIN:

Results:
Accuracy: 0.926
Invalid responses: 0.000
Total latency: 54.086 s
Questions per second: 24.387
Total output tokens: 121416
Output tokens per second: 2244.889

PR:

Results:
Accuracy: 0.930
Invalid responses: 0.000
Total latency: 45.613 s
Questions per second: 28.917
Total output tokens: 120617
Output tokens per second: 2644.330

Signed-off-by: LopezCastroRoberto <[email protected]>
@LopezCastroRoberto marked this pull request as draft February 10, 2026
@LopezCastroRoberto changed the title from "Add FlashInfer top-k support to large context decode path" to "[Perf] Add FlashInfer top-k support to large context decode path" Feb 10, 2026
@mergify bot added the rocm (Related to AMD ROCm) and v1 labels Feb 10, 2026
@github-project-automation bot moved this to Todo in AMD Feb 10, 2026
@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request replaces the custom large_context_topk kernel with flashinfer.top_k_ragged_transform for handling top-k operations in the large context decode path. The changes primarily involve updating sparse_attn_indexer.py to use the FlashInfer function and passing a new offsets_buffer. Corresponding changes are made for API compatibility in the ROCm path. The tests are also updated to validate the new implementation. My review found a critical issue in the test file where a new test function shadows an existing one due to having the same name, and also misuses a pytest parameter. I've provided a suggestion to fix this.

Comment on lines 398 to 401
def test_deepseek_hybrid_topk(clean_logits: bool, top_k: int) -> None:
    torch.set_default_device("cuda:0")

    top_k = 2048

critical

This new test function test_deepseek_hybrid_topk has the same name as an existing test function at line 286. In Python, this will cause the new function to overwrite the old one, and the original test for torch.ops._C.large_context_topk will no longer be executed. Please rename this new test function to avoid this conflict, for example to test_deepseek_hybrid_topk_flashinfer.

Additionally, the top_k parameter from pytest.mark.parametrize is immediately overwritten on line 401. This makes the parameterization ineffective. Please remove the hardcoded value to allow the test to run with different top_k values as intended.

Suggested change
-def test_deepseek_hybrid_topk(clean_logits: bool, top_k: int) -> None:
-    torch.set_default_device("cuda:0")
-    top_k = 2048
+def test_deepseek_hybrid_topk_flashinfer(clean_logits: bool, top_k: int) -> None:
+    torch.set_default_device("cuda:0")
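
For illustration, a runnable skeleton of the renamed, parametrized test (tensor shapes, parameter values, and the torch.topk stand-in for the implementation under test are assumptions, not the PR's actual test code):

import pytest
import torch

@pytest.mark.parametrize("clean_logits", [True, False])
@pytest.mark.parametrize("top_k", [1024, 2048])
def test_deepseek_hybrid_topk_flashinfer(clean_logits: bool, top_k: int) -> None:
    torch.set_default_device("cuda:0")
    # Each parametrized top_k value is now exercised instead of being overwritten.
    logits = torch.randn(4, 16384)
    if clean_logits:
        logits = torch.nan_to_num(logits)
    ref = torch.topk(logits, k=top_k, dim=-1).indices
    # Swap the line below for the FlashInfer-backed top-k path being tested.
    out = torch.topk(logits, k=top_k, dim=-1).indices
    assert torch.equal(torch.sort(out, dim=-1).values, torch.sort(ref, dim=-1).values)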

@LopezCastroRoberto changed the title from "[Perf] Add FlashInfer top-k support to large context decode path" to "[Perf] Add FlashInfer top-k support to large context decode path - DeepSeek-V3.2 sparse attention" Feb 10, 2026
@mergify bot added the deepseek (Related to DeepSeek models) label Feb 10, 2026
Signed-off-by: LopezCastroRoberto <[email protected]>
@mergify bot added the nvidia label Feb 12, 2026
@LopezCastroRoberto changed the title from "[Perf] Add FlashInfer top-k support to large context decode path - DeepSeek-V3.2 sparse attention" to "[Perf][Kernel] Improve topKperRow routine for large context decode path - DeepSeek-V3.2 sparse attention" Feb 12, 2026
@LopezCastroRoberto changed the title from "[Perf][Kernel] Improve topKperRow routine for large context decode path - DeepSeek-V3.2 sparse attention" to "[Perf][Kernel] Improve topKperRow for large context decode path - DeepSeek-V3.2 sparse attention" Feb 12, 2026

Labels

deepseek (Related to DeepSeek models) · nvidia · rocm (Related to AMD ROCm) · v1
