[Refactor][Perf] Qwen3-TTS: re-prefill Code Predictor with torch.compile + enable Code2Wav decoder CUDA Graph#1617
Conversation
… with standalone attention layers

- Modified `DEFAULT_CAPTURE_SIZES` to include an additional value.
- Introduced a new standalone multi-head attention class for the code predictor, replacing the previous implementation.
- Improved the `Qwen3TTSCode2Wav` class by bypassing the HF wrapper for direct decoder calls, optimizing performance and reducing overhead.

Signed-off-by: Sy03 <1370724210@qq.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 8ed0b2a8d0
```python
self.qkv_proj = QKVParallelLinear(
    hidden_size=self.hidden_size,
    head_size=self.head_dim,
    total_num_heads=self.num_heads,
    total_num_kv_heads=self.num_kv_heads,
```
**Pass quantization config into predictor linear layers**
When vllm_config.quant_config is enabled (e.g., AWQ/GPTQ/FP8 serving), this refactor instantiates the predictor's parallel linear layers without any quant_config, and the corresponding load_weights path no longer applies quantization scale remapping, so quantized weights/scales for the code predictor can be skipped or mismatched. The previous implementation routed through Qwen3DecoderLayer(..., quant_config=...), so this is a regression that can break startup or produce incorrect residual-code sampling in quantized deployments.
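A minimal illustration of the suggested fix, using stand-in types rather than vLLM's real `QKVParallelLinear` (the actual change would thread `vllm_config.quant_config` into the predictor's parallel linear constructors; all names below are illustrative):

```python
from dataclasses import dataclass
from typing import Optional

# Stand-ins for vLLM types; only the quant_config threading is the point here.
@dataclass
class QuantConfig:
    method: str  # e.g. "awq", "gptq", "fp8"

class QKVLinearStub:
    def __init__(self, hidden_size: int, head_size: int, total_num_heads: int,
                 total_num_kv_heads: int,
                 quant_config: Optional[QuantConfig] = None):
        self.out_features = (total_num_heads + 2 * total_num_kv_heads) * head_size
        # When quant_config is None, quantized checkpoints lose their
        # scale remapping -- the regression described above.
        self.quant_config = quant_config

def build_predictor_qkv(hidden_size, head_dim, num_heads, num_kv_heads,
                        vllm_quant_config):
    # Fix: forward the engine-level quant_config instead of dropping it.
    return QKVLinearStub(hidden_size, head_dim, num_heads, num_kv_heads,
                         quant_config=vllm_quant_config)
```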
Summary
WIP. Some issues to be re-checked.
Related issue: #938
Optimize Qwen3-TTS inference latency through two complementary improvements:
Code Predictor: re-prefill + torch.compile
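A plain-PyTorch sketch of the stateless re-prefill step described in this section (function and argument names are illustrative, not the PR's actual `_CodePredictorAttention` code):

```python
import torch
import torch.nn.functional as F

def re_prefill_attn(hidden, qkv: torch.nn.Linear, o_proj: torch.nn.Linear,
                    num_heads: int, head_dim: int) -> torch.Tensor:
    """One AR step: re-run causal SDPA over the whole short sequence
    (2~16 tokens) instead of maintaining a KV cache, then keep only the
    last position's output."""
    seq = hidden.shape[0]
    q, k, v = qkv(hidden).chunk(3, dim=-1)            # [seq, H*D] each
    shape = (seq, num_heads, head_dim)
    q, k, v = (t.view(shape).transpose(0, 1) for t in (q, k, v))  # [H, seq, D]
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out = out.transpose(0, 1).reshape(seq, num_heads * head_dim)
    return o_proj(out)[-1]                            # next-token hidden state

# The PR additionally wraps the step with
# torch.compile(mode="reduce-overhead", dynamic=True) for graph capture.
```

Because every step recomputes from scratch, there is no cache state to manage, which is what makes the step a clean target for CUDA Graph capture.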
- Replace `_LocalPredictorKVCache` with a stateless re-prefill approach: each AR step re-computes the full growing sequence (2~16 tokens) through standalone SDPA attention, eliminating KV cache management.
- Apply `torch.compile(mode="reduce-overhead", dynamic=True)` for full CUDA Graph capture and kernel fusion.

Code2Wav: enable decoder CUDA Graph in multi-stage mode
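For the decoder path, the capture/replay pattern that a CUDA Graph wrapper relies on looks roughly like this (a sketch only, not the actual `CUDAGraphDecoderWrapper`; it falls back to eager execution without a GPU):

```python
import torch

def make_graphed(fn, example: torch.Tensor):
    """Capture fn in a CUDA Graph with fixed I/O buffers; replay on call.
    Sketch only -- returns fn unchanged when CUDA is unavailable."""
    if not (torch.cuda.is_available() and example.is_cuda):
        return fn
    static_in = example.clone()
    # Warm up on a side stream so lazy initialization stays out of the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            fn(static_in)
    torch.cuda.current_stream().wait_stream(s)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = fn(static_in)
    def replay(x: torch.Tensor) -> torch.Tensor:
        static_in.copy_(x)   # graphs replay fixed addresses, so copy inputs in
        graph.replay()
        return static_out
    return replay
```

Capture requires fixed shapes, which is why a list like `DEFAULT_CAPTURE_SIZES` matters: one graph is recorded per supported input size.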
- Bypass the HF wrapper (`tok.decode()`) and call `decoder.chunked_decode()` directly on GPU, enabling the `CUDAGraphDecoderWrapper` to capture and replay decoder graphs.
- This works independently of `enforce_eager` (which only controls vLLM engine-level paged attention graphs).

Changes
- `qwen3_tts_code_predictor_vllm.py`: remove `_LocalPredictorKVCache`, the custom sampling op, and `forward_context` dependencies; add `_CodePredictorAttention` with standalone SDPA + native GQA; add `torch.compile` integration.
- `qwen3_tts_code2wav.py`: bypass the `tok.decode()` wrapper and call `decoder.chunked_decode()` directly on GPU; decoder graph capture no longer depends on `enforce_eager`.
- `cuda_graph_decoder_wrapper.py`: extend `DEFAULT_CAPTURE_SIZES` for streaming mode (25 context + 1 chunk).

Test Plan
- Unit test:
- E2E test:
- Benchmark (using scripts from PR #1573):
Test Result
Benchmark (RTX 5090, 50 prompts, streaming PCM, enforce_eager=true, same machine):
0.6B-CustomVoice:
1.7B-CustomVoice:
vs PR #1591 (dense KV + SDPA, same machine, same config):
cc @linyueqian @gcanlin