Skip to content

Commit f8ac20a

Browse files
[Whisper] Enable CUDA graph support for encoder-decoder models
Replace manual BMM cross-attention with RadixAttention to enable CUDA graph capture/replay for the Whisper decode path. The encoder KV cache is now stored in the standard KV pool via the attention backend's encoder_out_cache_loc mechanism.

Key changes:
- Cross-attention uses RadixAttention with k=None, v=None during decode to read cached encoder KV from the pool
- pad_input_ids prepends dummy encoder tokens and sets num_image_tokens so prepare_encoder_info_extend allocates encoder KV cache locations
- Auto-select flashinfer backend for encoder-decoder models
- Auto-disable radix cache to avoid prefix matching conflicts
- Set encoder_len_fill_value to the actual encoder length during CUDA graph capture so cross-attention kernels are properly recorded
- Fix cross-attention seq_lens_cpu in the FlashInfer decode updater: use encoder_lens instead of decoder seq_lens to prevent global_override_indptr_cpu from overriding the correct KV length
- Add encoder_out_cache_loc support in the trtllm_mha backend
- Clamp decoder position_ids to max_target_positions

Benchmark (earnings22, 511 samples, concurrency=1):
- WER: 12.77% (identical with/without CUDA graph)
- Throughput: 3.26 req/s (+36% vs 2.40 without CUDA graph)
- Avg latency: 0.297s (-27% vs 0.406s)
1 parent 32a85ef commit f8ac20a

File tree

7 files changed

+233
-103
lines changed

7 files changed

+233
-103
lines changed

python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,16 +1048,19 @@ def update_cross_attention(
10481048
fixed_split_size: Optional[int] = None,
10491049
disable_split_kv: Optional[bool] = None,
10501050
):
1051+
# Cache encoder_lens on CPU to avoid GPU→CPU transfer per call
1052+
encoder_lens_cpu = encoder_lens.cpu() if encoder_lens is not None else None
10511053
for wrapper_id in range(2):
10521054
if wrapper_id == 0:
1053-
# Normal attention
10541055
paged_kernel_lens = seq_lens
10551056
kv_start_idx = encoder_lens
1057+
kv_lens_cpu = seq_lens_cpu
10561058
else:
1057-
# Cross attention
1059+
# Cross-attention: attend to encoder tokens only
10581060
paged_kernel_lens = encoder_lens
10591061
kv_start_idx = torch.zeros_like(encoder_lens)
10601062
seq_lens_sum = encoder_lens.sum().item()
1063+
kv_lens_cpu = encoder_lens_cpu
10611064

10621065
self.call_begin_forward(
10631066
decode_wrappers[wrapper_id],
@@ -1067,7 +1070,7 @@ def update_cross_attention(
10671070
self.kv_indptr[wrapper_id],
10681071
kv_start_idx,
10691072
spec_info,
1070-
seq_lens_cpu=seq_lens_cpu,
1073+
seq_lens_cpu=kv_lens_cpu,
10711074
)
10721075

10731076
def call_begin_forward(

python/sglang/srt/layers/attention/trtllm_mha_backend.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,11 @@ def forward_decode(
703703
**kwargs,
704704
) -> torch.Tensor:
705705
"""Run forward for decode using TRTLLM MHA kernel."""
706-
cache_loc = forward_batch.out_cache_loc
706+
cache_loc = (
707+
forward_batch.out_cache_loc
708+
if not layer.is_cross_attention
709+
else forward_batch.encoder_out_cache_loc
710+
)
707711

708712
use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k)
709713

@@ -788,7 +792,11 @@ def forward_extend(
788792
save_kv_cache=True,
789793
**kwargs,
790794
):
791-
cache_loc = forward_batch.out_cache_loc
795+
cache_loc = (
796+
forward_batch.out_cache_loc
797+
if not layer.is_cross_attention
798+
else forward_batch.encoder_out_cache_loc
799+
)
792800

793801
use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k)
794802

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,12 @@ def __init__(self, model_runner: ModelRunner):
590590
else self.dllm_config.block_size
591591
)
592592

593-
self.encoder_len_fill_value = 0
593+
# Non-zero encoder length ensures cross-attention kernels are captured in the graph.
594+
self.encoder_len_fill_value = (
595+
getattr(model_runner.model_config.hf_config, "max_source_positions", 0)
596+
if self.is_encoder_decoder
597+
else 0
598+
)
594599

595600
if self.enable_torch_compile:
596601
set_torch_compile_config()

python/sglang/srt/model_executor/model_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2044,7 +2044,11 @@ def _dummy_run(self, batch_size: int, run_ctx=None):
20442044
is_encoder_decoder=self.model_config.is_encoder_decoder,
20452045
require_mlp_tp_gather=require_mlp_tp_gather_,
20462046
seq_len_fill_value=seq_len_fill_value,
2047-
encoder_len_fill_value=0,
2047+
encoder_len_fill_value=(
2048+
getattr(self.model_config.hf_config, "max_source_positions", 0)
2049+
if self.model_config.is_encoder_decoder
2050+
else 0
2051+
),
20482052
num_tokens_per_bs=num_tokens_per_bs,
20492053
cache_loc_dtype=torch.int64,
20502054
enable_mamba_track=False,

python/sglang/srt/models/whisper.py

Lines changed: 32 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -94,70 +94,16 @@ def forward(
9494
"""Input shape: Batch x Time x Channel"""
9595

9696
if self.is_cross_attention:
97+
# Cross-attention: KV cached during prefill, read from pool during decode.
9798
q, _ = self.q_proj(hidden_states)
99+
q = q * self.scaling
98100
if cross_hidden_states is not None:
99101
kv, _ = self.kv_proj(cross_hidden_states)
100102
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
101103
else:
102-
k = torch.zeros_like(q)
103-
v = torch.zeros_like(q)
104-
105-
q = q * self.scaling
106-
num_heads = self.attn.tp_q_head_num
107-
head_dim = self.attn.head_dim
108-
109-
q = q.view(-1, num_heads, head_dim)
110-
k = k.view(-1, num_heads, head_dim)
111-
v = v.view(-1, num_heads, head_dim)
112-
113-
q_len = q.shape[0]
114-
kv_len = k.shape[0]
115-
116-
q = q.transpose(0, 1)
117-
k = k.transpose(0, 1)
118-
v = v.transpose(0, 1)
119-
120-
attn_weights = torch.bmm(q, k.transpose(1, 2))
121-
122-
# Apply block-diagonal mask for batched cross-attention
123-
batch_size = forward_batch.batch_size if forward_batch else 1
124-
if batch_size > 1 and kv_len > 0:
125-
encoder_len_per_request = kv_len // batch_size
126-
if encoder_len_per_request * batch_size == kv_len:
127-
is_decode = forward_batch.forward_mode.is_decode()
128-
if is_decode:
129-
mask = torch.zeros(
130-
(q_len, kv_len), device=q.device, dtype=torch.bool
131-
)
132-
for i in range(batch_size):
133-
enc_start = i * encoder_len_per_request
134-
enc_end = (i + 1) * encoder_len_per_request
135-
mask[i, enc_start:enc_end] = True
136-
attn_weights = attn_weights.masked_fill(
137-
~mask.unsqueeze(0), float("-inf")
138-
)
139-
else:
140-
seq_lens = forward_batch.seq_lens
141-
if seq_lens is not None and len(seq_lens) == batch_size:
142-
seq_lens_list = seq_lens.tolist()
143-
mask = torch.zeros(
144-
(q_len, kv_len), device=q.device, dtype=torch.bool
145-
)
146-
q_start = 0
147-
for i, dec_len in enumerate(seq_lens_list):
148-
enc_start = i * encoder_len_per_request
149-
enc_end = (i + 1) * encoder_len_per_request
150-
q_end = q_start + dec_len
151-
mask[q_start:q_end, enc_start:enc_end] = True
152-
q_start = q_end
153-
attn_weights = attn_weights.masked_fill(
154-
~mask.unsqueeze(0), float("-inf")
155-
)
156-
157-
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
158-
attn_output = torch.bmm(attn_weights, v)
159-
attn_output = attn_output.transpose(0, 1)
160-
attn_output = attn_output.reshape(q_len, num_heads * head_dim)
104+
k = None
105+
v = None
106+
attn_output = self.attn(q, k, v, forward_batch)
161107
else:
162108
qkv, _ = self.qkv_proj(hidden_states)
163109
q, k, v = qkv.chunk(chunks=3, dim=-1)
@@ -394,6 +340,7 @@ def forward(
394340
position_ids=None,
395341
):
396342
inputs_embeds = self.embed_tokens(input_ids)
343+
position_ids = position_ids.clamp(max=self.max_target_positions - 1)
397344
positions = self.embed_positions(position_ids)
398345
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
399346

@@ -420,7 +367,6 @@ def __init__(
420367
)
421368
self.logits_processor = LogitsProcessor(config)
422369
self.config = config
423-
self._encoder_cache = {}
424370

425371
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
426372
stacked_params_mapping = [
@@ -468,8 +414,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
468414
weight_loader = getattr(param, "weight_loader", default_weight_loader)
469415
weight_loader(param, loaded_weight)
470416

471-
def pad_input_ids(self, input_ids: List[int], _mm_inputs: MultimodalInputs):
472-
return input_ids
417+
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
418+
# Prepend dummy encoder tokens so that prepare_encoder_info_extend
419+
# correctly allocates encoder KV cache locations in the KV pool.
420+
# These dummy tokens are stripped before the model forward receives input_ids.
421+
encoder_len = self.config.max_source_positions
422+
mm_inputs.num_image_tokens = encoder_len
423+
pad_ids = [0] * encoder_len
424+
return pad_ids + input_ids
473425

474426
def forward(
475427
self,
@@ -479,29 +431,22 @@ def forward(
479431
**kwargs: Any,
480432
) -> LogitsProcessorOutput:
481433
dtype = self.encoder.conv1.weight.dtype
482-
is_decode = forward_batch.forward_mode.is_decode()
483-
484-
if is_decode:
485-
encoder_outputs = None
486-
if forward_batch.req_pool_indices is not None:
487-
req_indices = forward_batch.req_pool_indices.tolist()
488-
encoder_list = []
489-
for req_idx in req_indices:
490-
if req_idx in self._encoder_cache:
491-
encoder_list.append(self._encoder_cache[req_idx])
492-
if encoder_list:
493-
encoder_outputs = torch.cat(encoder_list, dim=0)
494-
else:
495-
encoder_list = []
434+
435+
# Run encoder for requests that haven't cached encoder output yet.
436+
# During decode or when encoder is already cached, encoder_hidden_states
437+
# is None and cross-attention reads KV from the pool via RadixAttention.
438+
encoder_hidden_states = None
439+
if not forward_batch.forward_mode.is_decode():
496440
mm_inputs_list = forward_batch.mm_inputs if forward_batch.mm_inputs else []
497-
req_indices = (
498-
forward_batch.req_pool_indices.tolist()
499-
if forward_batch.req_pool_indices is not None
500-
else []
441+
encoder_cached_list = (
442+
forward_batch.encoder_cached if forward_batch.encoder_cached else []
501443
)
502444

503-
for req_idx, mm_input in zip(req_indices, mm_inputs_list):
504-
if mm_input is None or not mm_input.mm_items:
445+
encoder_list = []
446+
for i, (mm_input, cached) in enumerate(
447+
zip(mm_inputs_list, encoder_cached_list)
448+
):
449+
if cached or mm_input is None or not mm_input.mm_items:
505450
continue
506451

507452
features = mm_input.mm_items[0].feature
@@ -513,21 +458,17 @@ def forward(
513458
features.device, non_blocking=True
514459
)
515460

516-
req_encoder_outputs = self.encoder(
461+
req_encoder_output = self.encoder(
517462
features.to(dtype), encoder_position_ids, forward_batch
518463
)
519-
req_encoder_outputs = req_encoder_outputs.squeeze(0)
520-
521-
self._encoder_cache[req_idx] = req_encoder_outputs
522-
encoder_list.append(req_encoder_outputs)
464+
req_encoder_output = req_encoder_output.squeeze(0)
465+
encoder_list.append(req_encoder_output)
523466

524467
if encoder_list:
525-
encoder_outputs = torch.cat(encoder_list, dim=0)
526-
else:
527-
encoder_outputs = None
468+
encoder_hidden_states = torch.cat(encoder_list, dim=0)
528469

529470
decoder_outputs = self.decoder(
530-
input_ids, encoder_outputs, forward_batch, positions
471+
input_ids, encoder_hidden_states, forward_batch, positions
531472
)
532473

533474
logits = self.logits_processor(

python/sglang/srt/server_args.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,6 +2181,10 @@ def _get_default_attn_backend(self, use_mla_backend: bool, model_config):
21812181
2.2 We will use Flashinfer backend on blackwell.
21822182
2.3 Otherwise, we will use triton backend.
21832183
"""
2184+
# Encoder-decoder models (e.g., Whisper) require flashinfer for cross-attention support
2185+
if model_config.is_encoder_decoder:
2186+
return "flashinfer"
2187+
21842188
if not use_mla_backend:
21852189
# MHA architecture
21862190
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(self):
@@ -2256,12 +2260,13 @@ def _handle_attention_backend_compatibility(self):
22562260
self.speculative_algorithm is None
22572261
), "Speculative decoding is currently not supported with Flex Attention backend"
22582262

2259-
# Encoder-decoder models (e.g., Whisper)
2260-
if model_config.is_encoder_decoder:
2261-
logger.warning(
2262-
"Cuda graph is disabled for encoder-decoder models (e.g., Whisper)"
2263+
# Encoder-decoder models (e.g., Whisper) require radix cache disabled
2264+
# because encoder token padding conflicts with prefix caching.
2265+
if model_config.is_encoder_decoder and not self.disable_radix_cache:
2266+
logger.info(
2267+
"Radix cache is disabled for encoder-decoder models (e.g., Whisper)"
22632268
)
2264-
self.disable_cuda_graph = True
2269+
self.disable_radix_cache = True
22652270

22662271
# Major NVIDIA platforms backends
22672272
if (

0 commit comments

Comments (0)