
[Whisper] Enable CUDA graph support for encoder-decoder models#21190

Open
JustinTong0323 wants to merge 5 commits into sgl-project:main from JustinTong0323:feat/whisper-cudagraph

Conversation

@JustinTong0323
Collaborator

Motivation

Previously, Whisper used a custom bmm + mask cross-attention implementation with a Python-side _encoder_cache dict, which was incompatible with CUDA graph capture/replay. This PR enables CUDA graph for Whisper, achieving 36% throughput improvement with identical accuracy.

Related: #21161

Modifications

  • Cross-attention → RadixAttention: Replaced manual BMM cross-attention with native RadixAttention path. During prefill, encoder KV is projected and saved to the KV pool; during decode, k=None, v=None triggers cached KV read from the pool.
  • Removed _encoder_cache dict: Encoder outputs now flow through the native encoder-decoder cache managed by scheduler/attention backends.
  • Auto-select flashinfer backend: Encoder-decoder models require flashinfer for cross-attention support (trtllm_mha/fa3 lack encoder_out_cache_loc handling).
  • Auto-disable radix cache: Encoder token padding conflicts with prefix caching, causing assert len(req.prefix_indices) == 0 failures.
  • Fixed CUDA graph capture: Set encoder_len_fill_value to max_source_positions (1500 for Whisper) so cross-attention kernels are included in the captured graph. Previously it was 0, which caused cross-attention to be skipped during capture.
  • Fixed FlashInfer decode cross-attention planning: update_cross_attention was passing the decoder seq_lens_cpu to the cross-attention wrapper, causing global_override_indptr_cpu to override the KV length from 1500 to ~5. It now correctly uses encoder_lens for cross-attention.
  • Position embedding OOB protection: Clamp position_ids to max_target_positions - 1 to prevent warmup from exceeding the 448-entry position embedding table.
  • Added encoder_out_cache_loc to trtllm_mha backend: For forward compatibility.
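Two of the changes above can be illustrated with a minimal, self-contained sketch. This is toy code, not SGLang's actual API: `CrossAttentionKVPool` and `clamp_position_ids` are hypothetical names chosen for illustration. The idea is that prefill writes the projected encoder KV into the pool, decode passes `k=None, v=None` to signal a cached read, and the position-id clamp keeps warmup from indexing past Whisper's 448-entry decoder position-embedding table.

```python
MAX_TARGET_POSITIONS = 448  # Whisper decoder position-embedding table size


def clamp_position_ids(position_ids, max_target_positions=MAX_TARGET_POSITIONS):
    """Clamp decoder positions so warmup can't index past the embedding table."""
    return [min(p, max_target_positions - 1) for p in position_ids]


class CrossAttentionKVPool:
    """Toy KV pool: prefill stores encoder K/V, decode reads them back.

    Mirrors the convention in this PR where k=None, v=None during decode
    means "reuse the encoder KV already saved in the pool".
    """

    def __init__(self):
        self._k = None
        self._v = None

    def forward(self, q, k=None, v=None):
        if k is not None and v is not None:
            # Prefill: save the projected encoder KV once.
            self._k, self._v = k, v
        # Decode (k is None, v is None): read the cached encoder KV.
        return q, self._k, self._v
```

Because the decode path touches no Python-side dict and takes the same branch every step, it is the kind of fixed-shape computation CUDA graph capture can record.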

Accuracy Tests

# launch server (no manual flags needed - auto-configured)
sglang serve --model-path openai/whisper-large-v3

# benchmark
python benchmark/asr/bench_sglang.py \
    --base-url http://127.0.0.1:30000 \
    --model openai/whisper-large-v3 \
    --api-type transcription \
    --language en \
    --concurrency 1
| Config | WER | Avg Latency | Throughput |
|---|---|---|---|
| CUDA graph (ours) | 12.77% | 0.297s | 3.26 req/s |
| No CUDA graph (baseline) | 12.77% | 0.406s | 2.40 req/s |

WER is identical between CUDA graph and non-CUDA-graph modes (12.7690% on the earnings22 dataset, 511 samples).

Benchmarking

| Metric | No CUDA Graph | CUDA Graph | Improvement |
|---|---|---|---|
| WER | 12.77% | 12.77% | identical |
| Avg Latency | 0.406s | 0.297s | -27% |
| Median Latency | 0.376s | 0.267s | -29% |
| P95 Latency | 0.837s | 0.601s | -28% |
| Throughput | 2.40 req/s | 3.26 req/s | +36% |
| Token Throughput | 44.71 tok/s | 60.72 tok/s | +36% |

Dataset: D4nt3/esb-datasets-earnings22-validation-tiny-filtered (511 samples), concurrency=1, model=openai/whisper-large-v3.
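For reference, WER in these tables is the standard word error rate: word-level edit distance divided by reference length. The sketch below is a minimal implementation of that metric, not the benchmark script's code (which may also apply text normalization before scoring).

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word error rate: word-level Levenshtein distance / reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    # d[i][j] = edit distance between ref[:i] and hyp[:j]
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,        # deletion
                d[i][j - 1] + 1,        # insertion
                d[i - 1][j - 1] + cost,  # substitution
            )
    return d[len(ref)][len(hyp)] / max(len(ref), 1)
```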



@github-actions github-actions bot added the blackwell SM100/SM120 label Mar 23, 2026
@JustinTong0323 JustinTong0323 force-pushed the feat/whisper-cudagraph branch from f8ac20a to 78375d3 Compare March 23, 2026 06:42
Replace manual BMM cross-attention with RadixAttention to enable CUDA
graph capture/replay for the Whisper decode path. The encoder KV cache
is now stored in the standard KV pool via the attention backend's
encoder_out_cache_loc mechanism.

Key changes:
- Cross-attention uses RadixAttention with k=None,v=None during decode
  to read cached encoder KV from the pool
- pad_input_ids prepends dummy encoder tokens and sets num_image_tokens
  so prepare_encoder_info_extend allocates encoder KV cache locations
- Auto-select flashinfer backend for encoder-decoder models
- Auto-disable radix cache to avoid prefix matching conflicts
- Set encoder_len_fill_value to actual encoder length during CUDA graph
  capture so cross-attention kernels are properly recorded
- Fix cross-attention seq_lens_cpu in FlashInfer decode updater: use
  encoder_lens instead of decoder seq_lens to prevent
  global_override_indptr_cpu from overriding the correct KV length
- Add encoder_out_cache_loc support in trtllm_mha backend
- Clamp decoder position_ids to max_target_positions

Benchmark (earnings22, 511 samples, concurrency=1):
  WER: 12.77% (identical with/without CUDA graph)
  Throughput: 3.26 req/s (+36% vs 2.40 without CUDA graph)
  Avg latency: 0.297s (-27% vs 0.406s)
@JustinTong0323 JustinTong0323 force-pushed the feat/whisper-cudagraph branch from 78375d3 to 3f2c689 Compare March 23, 2026 06:43
@JustinTong0323 JustinTong0323 requested a review from HaiShaw as a code owner March 23, 2026 06:43
@yuan-luo
Collaborator

Please fix lint.

@JustinTong0323
Collaborator Author

/tag-and-rerun-ci

@JustinTong0323
Collaborator Author

/tag-and-rerun-ci

@JustinTong0323
Collaborator Author

/tag-and-rerun-ci

@JustinTong0323
Collaborator Author

JustinTong0323 commented Mar 24, 2026

SGLang vs vLLM: Whisper Serving Benchmark Report

Environment

| Item | Detail |
|---|---|
| GPU | NVIDIA B200 (183 GB), single-GPU serving (CUDA_VISIBLE_DEVICES=0) |
| Model | openai/whisper-large-v3 (1.55B params) |
| Dataset | D4nt3/esb-datasets-earnings22-validation-tiny-filtered (511 samples, earnings call transcriptions, <30s each) |
| SGLang | 0.5.9 — branch feat/whisper-cudagraph, CUDA graph enabled (52 batch sizes) |
| vLLM | 0.18.0 — FA4, CUDA graph enabled (decode only, 51 batch sizes), nvidia-cutlass-dsl==4.4.2 |
| Torch | SGLang: 2.9.1 / vLLM: 2.10.0 |

Launch Commands

# SGLang
CUDA_VISIBLE_DEVICES=0 python -m sglang.launch_server \
  --model-path openai/whisper-large-v3 --port 30000

# vLLM
CUDA_VISIBLE_DEVICES=0 vllm serve openai/whisper-large-v3 --port 30000

Benchmark Command

python benchmark/asr/bench_sglang.py \
  --base-url http://localhost:30000 \
  --model openai/whisper-large-v3 \
  --api-type transcription \
  --concurrency 64           # or 1 for single-request test

SGLang benchmarks include --language en; vLLM benchmarks omit it due to a vLLM bug where language=en produces garbled output. SGLang defaults to en when language is not specified, so the comparison remains equivalent.


1. Accuracy Tests (SGLang w/ CUDA Graph)

Test audio: LibriSpeech sample (Narsil/asr_dummy/1.flac)

| Test | Result |
|---|---|
| Basic transcription | PASS — correct full output |
| Consistency (3 sequential identical requests) | PASS — all 3 outputs identical |
| Quality (keyword check: stew, dinner, turnips, carrots, potatoes, mutton) | PASS — 6/6 keywords found |
| CUDA graph capture | 52 batch sizes captured, disable_cuda_graph=False |

Transcription output:

" he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce"


2. Benchmark: High Concurrency (64 concurrent requests, 511 samples)

| Metric | SGLang (CUDA graph) | vLLM 0.18.0 | Delta |
|---|---|---|---|
| WER | 12.78% | 51.22% | 4.0x better |
| Avg Latency | 0.516s | 1.206s | 2.3x faster |
| Median Latency | 0.469s | 0.606s | 1.3x faster |
| P95 Latency | 1.103s | 7.451s | 6.8x faster |
| Throughput | 48.82 req/s | 22.56 req/s | 2.2x higher |
| Token Throughput | 910.07 tok/s | 493.29 tok/s | 1.8x higher |
| Total Time | 10.47s | 22.65s | 2.2x faster |

Sample Predictions Comparison (High Concurrency)

Sample 2 — SGLang complete, vLLM truncated:

REF:    so within fiscal year 2021 say 120 a 100 depending on what the micro will do
        and next year it is not necessarily payable in q one is we will look at what
        the cash flows for 2022 look like
SGLang: so within fiscal year 2021 say $120000 $100000 depending on what the macro
        will do and next year it is not necessarily payable in q one is we will look
        at what the cash flows for 2022 look like                                      ✓
vLLM:   so within fiscal year 2021 say $120000 $100000 depending on what the macro
        will do and next year it is not necessarily payable in q one is     ← TRUNCATED

Sample 7 — vLLM cross-attention corruption (repetition):

REF:    essentially a transformer that was allocated to a future project we we have now retrofitted
SGLang: essentially a transformer that was allocated to a future project
        we now have retrofitted                                                         ✓
vLLM:   essentially a transformer that was allocated to a future project we now
        this this this this this this this this this this this this this this
        this this this this                                        ← REPETITION/GARBAGE

Sample 9 — vLLM cross-language contamination:

REF:    i think you mentioned generation was impacted by a transformer upgrade
SGLang: i think you mentioned a generation was impacted by a transformer outfit        ✓
vLLM:   i think you mentioned a generation was impacted by a transformer outfit
        i i i i i i i i i i i i i i i i抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽抽
                                                   ← REPETITION + CROSS-LANGUAGE GARBAGE

3. Benchmark: Single Request (concurrency=1, 50 samples)

| Metric | SGLang (CUDA graph) | vLLM 0.18.0 | Delta |
|---|---|---|---|
| WER | 13.23% | 14.84% | 1.1x better |
| Avg Latency | 0.097s | 0.136s | 1.4x faster |
| Median Latency | 0.083s | 0.126s | 1.5x faster |
| P95 Latency | 0.206s | 0.217s | 1.1x |
| Throughput | 8.00 req/s | 6.08 req/s | 1.3x higher |
| Token Throughput | 173.05 tok/s | 133.58 tok/s | 1.3x higher |

- Accept `timestamp_granularities[]` and `response_format=verbose_json`
  in the `/v1/audio/transcriptions` endpoint
- Switch decoder prompt from `<|notimestamps|>` to `<|0.00|>` when
  timestamps are requested so the model emits timestamp tokens
- Parse timestamp tokens from output_ids into segments with start/end
  times in the serving layer
- Add TranscriptionSegment and TranscriptionVerboseResponse protocol
  models matching the OpenAI API spec
- Backward compatible: default behavior (json/text) unchanged
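The timestamp parsing step above (splitting decoded output containing Whisper `<|0.00|>`-style timestamp tokens into segments with start/end times) can be sketched as follows. This is a hedged illustration, not the PR's actual serving-layer code; `parse_segments` and the dict segment shape are hypothetical names for this sketch.

```python
import re

# Matches Whisper timestamp tokens such as <|0.00|> or <|12.34|>.
_TS = re.compile(r"<\|(\d+\.\d+)\|>")


def parse_segments(decoded: str):
    """Split decoded text with <|t.tt|> tokens into {start, end, text} segments."""
    parts = _TS.split(decoded)  # alternates: text, timestamp, text, timestamp, ...
    segments, start, buf = [], None, []
    for i, part in enumerate(parts):
        if i % 2 == 1:  # a timestamp token
            t = float(part)
            if start is None:
                start, buf = t, []  # first timestamp opens the first segment
            else:
                text = "".join(buf).strip()
                if text:  # skip empty spans between back-to-back timestamps
                    segments.append({"start": start, "end": t, "text": text})
                start, buf = t, []
        else:
            buf.append(part)
    return segments
```

Back-to-back timestamps (e.g. `<|2.40|><|2.40|>`) mark a segment boundary with no text between them, which the empty-text check above skips.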
@JustinTong0323 JustinTong0323 force-pushed the feat/whisper-cudagraph branch from 2c7297d to 3c53805 Compare March 25, 2026 04:19
