[Whisper] Enable CUDA graph support for encoder-decoder models #21190
JustinTong0323 wants to merge 5 commits into sgl-project:main from
Conversation
Force-pushed from f8ac20a to 78375d3
Replace manual BMM cross-attention with RadixAttention to enable CUDA graph capture/replay for the Whisper decode path. The encoder KV cache is now stored in the standard KV pool via the attention backend's `encoder_out_cache_loc` mechanism.

Key changes:
- Cross-attention uses RadixAttention with `k=None, v=None` during decode to read cached encoder KV from the pool
- `pad_input_ids` prepends dummy encoder tokens and sets `num_image_tokens` so `prepare_encoder_info_extend` allocates encoder KV cache locations
- Auto-select flashinfer backend for encoder-decoder models
- Auto-disable radix cache to avoid prefix matching conflicts
- Set `encoder_len_fill_value` to the actual encoder length during CUDA graph capture so cross-attention kernels are properly recorded
- Fix cross-attention `seq_lens_cpu` in the FlashInfer decode updater: use `encoder_lens` instead of decoder `seq_lens` to prevent `global_override_indptr_cpu` from overriding the correct KV length
- Add `encoder_out_cache_loc` support in the trtllm_mha backend
- Clamp decoder `position_ids` to `max_target_positions`

Benchmark (earnings22, 511 samples, concurrency=1):
- WER: 12.77% (identical with/without CUDA graph)
- Throughput: 3.26 req/s (+36% vs 2.40 without CUDA graph)
- Avg latency: 0.297s (-27% vs 0.406s)
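The decode-side convention above (pass `k=None, v=None` to read cached encoder KV from the pool) can be sketched in miniature. This is a hypothetical illustration, not SGLang's actual RadixAttention API; `CrossAttentionSketch` and its `kv_pool` are stand-in names:

```python
import numpy as np

class CrossAttentionSketch:
    """Hypothetical sketch: cross-attention that writes encoder KV to a
    pool once at prefill, then reads it back on every decode step. With
    the KV read coming from a fixed pool location, the decode step has a
    stable memory access pattern, which is what CUDA graph replay needs."""

    def __init__(self):
        self.kv_pool = {}  # stand-in for the shared encoder KV pool

    def forward(self, q, k=None, v=None, cache_loc=0):
        if k is not None:
            # Prefill: store the projected encoder KV at its cache location.
            self.kv_pool[cache_loc] = (k, v)
        # Decode (k is None, v is None): read cached encoder KV from the pool.
        k_c, v_c = self.kv_pool[cache_loc]
        scores = q @ k_c.T / np.sqrt(q.shape[-1])
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        return weights @ v_c
```

Because the encoder output never changes during decoding, the prefill-time write and every subsequent decode-time read see the same KV, so outputs are identical step over step.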
Force-pushed from 78375d3 to 3f2c689
Please fix lint.
/tag-and-rerun-ci
SGLang vs vLLM: Whisper Serving Benchmark Report

Environment
Launch Commands

# SGLang
CUDA_VISIBLE_DEVICES=0 python -m sglang.launch_server \
    --model-path openai/whisper-large-v3 --port 30000

# vLLM
CUDA_VISIBLE_DEVICES=0 vllm serve openai/whisper-large-v3 --port 30000

Benchmark Command

python benchmark/asr/bench_sglang.py \
    --base-url http://localhost:30000 \
    --model openai/whisper-large-v3 \
    --api-type transcription \
    --concurrency 64  # or 1 for single-request test
1. Accuracy Tests (SGLang w/ CUDA Graph)

Test audio: LibriSpeech sample (
Transcription output:
2. Benchmark: High Concurrency (64 concurrent requests, 511 samples)
Sample Predictions Comparison (High Concurrency)

Sample 2 — SGLang complete, vLLM truncated:
Sample 7 — vLLM cross-attention corruption (repetition):
Sample 9 — vLLM cross-language contamination:

3. Benchmark: Single Request (concurrency=1, 50 samples)
- Accept `timestamp_granularities[]` and `response_format=verbose_json` in the `/v1/audio/transcriptions` endpoint
- Switch decoder prompt from `<|notimestamps|>` to `<|0.00|>` when timestamps are requested so the model emits timestamp tokens
- Parse timestamp tokens from output_ids into segments with start/end times in the serving layer
- Add TranscriptionSegment and TranscriptionVerboseResponse protocol models matching the OpenAI API spec
- Backward compatible: default behavior (json/text) unchanged
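The timestamp-parsing step can be illustrated with a small sketch. Whisper emits tokens like `<|0.00|>` and `<|2.40|>` that delimit each segment; the regex and the `parse_segments` helper below are illustrative stand-ins, not the PR's actual serving-layer code:

```python
import re

# Matches decoded Whisper timestamp tokens such as <|0.00|> or <|12.34|>.
TS_RE = re.compile(r"<\|(\d+\.\d{2})\|>")

def parse_segments(decoded_text):
    """Split decoded text like '<|0.00|> Hello.<|2.40|><|2.40|> Bye.<|3.10|>'
    into segment dicts with start/end times, roughly the shape the OpenAI
    verbose_json response expects."""
    segments = []
    # re.split with a capture group keeps the timestamps: the list
    # alternates between timestamp strings and the text between them.
    parts = TS_RE.split(decoded_text)
    i = 1
    while i + 1 < len(parts):
        start = float(parts[i])
        text = parts[i + 1].strip()
        end = float(parts[i + 2]) if i + 2 < len(parts) else start
        if text:  # skip the empty gap between back-to-back timestamps
            segments.append({"start": start, "end": end, "text": text})
        i += 2
    return segments
```

Back-to-back timestamp tokens (end of one segment, start of the next) produce an empty text span, which the sketch skips.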
Force-pushed from 2c7297d to 3c53805
Motivation
Previously, Whisper used a custom `bmm + mask` cross-attention implementation with a Python-side `_encoder_cache` dict, which was incompatible with CUDA graph capture/replay. This PR enables CUDA graph for Whisper, achieving a 36% throughput improvement with identical accuracy.

Related: #21161
Modifications
- Cross-attention moved to the `RadixAttention` path. During prefill, encoder KV is projected and saved to the KV pool; during decode, `k=None, v=None` triggers a cached KV read from the pool.
- Removed the `_encoder_cache` dict: encoder outputs now flow through the native encoder-decoder cache managed by the scheduler/attention backends.
- Auto-select `flashinfer` for cross-attention support (trtllm_mha/fa3 lack `encoder_out_cache_loc` handling).
- Auto-disable radix cache to avoid `assert len(req.prefix_indices) == 0` failures.
- Set `encoder_len_fill_value` to `max_source_positions` (1500 for Whisper) so cross-attention kernels are included in the captured graph. Previously it was 0, causing cross-attention to be skipped during capture.
- Fixed `update_cross_attention`, which was passing the decoder `seq_lens_cpu` to the cross-attention wrapper, causing `global_override_indptr_cpu` to override the KV length from 1500 to ~5. It now correctly uses `encoder_lens` for cross-attention.
- Clamp `position_ids` to `max_target_positions - 1` to prevent warmup from exceeding the 448-entry position embedding table.
- Add `encoder_out_cache_loc` to the trtllm_mha backend for forward compatibility.

Accuracy Tests
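The position-id clamp in the modifications above can be shown as a minimal sketch; the function name is illustrative, and the 448-entry limit comes from Whisper's `max_target_positions`:

```python
import numpy as np

MAX_TARGET_POSITIONS = 448  # Whisper decoder position-embedding table size

def clamp_position_ids(position_ids):
    """Clamp decoder position ids so CUDA-graph warmup, which may probe
    sequence positions beyond any real transcription length, never
    indexes past the end of the position-embedding table."""
    return np.minimum(position_ids, MAX_TARGET_POSITIONS - 1)
```

Without the clamp, a warmup batch padded to a length above 448 would index out of bounds in the embedding lookup.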
WER is identical between CUDA graph and non-CUDA-graph modes (12.7690% on earnings22 dataset, 511 samples).
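For reference, WER here is the standard word-level edit distance normalized by the reference length. A minimal sketch of that definition (not the benchmark's exact scoring code):

```python
def word_error_rate(reference, hypothesis):
    """WER = (substitutions + deletions + insertions) / reference words,
    computed via word-level Levenshtein distance."""
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = edits to turn the first i ref words into the first j hyp words
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,        # deletion
                dp[i][j - 1] + 1,        # insertion
                dp[i - 1][j - 1] + cost, # substitution (or match)
            )
    return dp[len(ref)][len(hyp)] / max(len(ref), 1)
```

An identical WER down to four decimal places across the two modes is a strong signal the CUDA-graph path produces bit-compatible decoding decisions.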
Benchmarking
Dataset: `D4nt3/esb-datasets-earnings22-validation-tiny-filtered` (511 samples), concurrency=1, model=openai/whisper-large-v3.

Checklist