
[Refactor][Perf] Qwen3-TTS: re-prefill Code Predictor with torch.compile + enable Code2Wav decoder CUDA Graph#1617

Draft
Sy0307 wants to merge 1 commit into vllm-project:main from Sy0307:dev/tts_decoder_prof

Conversation


@Sy0307 Sy0307 commented Mar 2, 2026

Summary

Work in progress; some issues still need to be re-checked.

Related issue: #938

Optimize Qwen3-TTS inference latency through two complementary improvements:

  1. Code Predictor: re-prefill + torch.compile

    • Replace _LocalPredictorKVCache with a stateless re-prefill approach: each AR step re-computes the full growing sequence (2–16 tokens) through standalone SDPA attention, eliminating KV-cache management.
    • Combined with torch.compile(mode="reduce-overhead", dynamic=True) for full CUDA Graph capture and kernel fusion.
  2. Code2Wav: enable decoder CUDA Graph in multi-stage mode

    • Bypass the HF tokenizer wrapper (tok.decode()) and call decoder.chunked_decode() directly on GPU, enabling the CUDAGraphDecoderWrapper to capture and replay decoder graphs.
    • Add bucket size 26 to match streaming input length (25 context + 1 chunk).
    • Decouple Code2Wav's internal CUDA Graph from enforce_eager (which only controls vLLM engine-level paged attention graphs).
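The re-prefill idea can be sketched outside vLLM with plain PyTorch (dimensions, weight names, and the loop are illustrative, not the PR's actual code): each AR step runs causal SDPA over the entire growing sequence, so there is no KV cache to carry between steps, and the whole step becomes a self-contained graph that `torch.compile(mode="reduce-overhead", dynamic=True)` can capture.

```python
import torch
import torch.nn.functional as F

def re_prefill_attention(tokens_emb, wq, wk, wv, num_heads, num_kv_heads):
    """One AR step: recompute causal attention over the full growing
    sequence (no KV cache), with native GQA via KV-head repetition."""
    seq, hidden = tokens_emb.shape
    head_dim = hidden // num_heads
    # Project and split into heads: (num_heads, seq, head_dim)
    q = (tokens_emb @ wq).view(seq, num_heads, head_dim).transpose(0, 1)
    k = (tokens_emb @ wk).view(seq, num_kv_heads, head_dim).transpose(0, 1)
    v = (tokens_emb @ wv).view(seq, num_kv_heads, head_dim).transpose(0, 1)
    # GQA: repeat each KV head to cover its group of query heads
    rep = num_heads // num_kv_heads
    k = k.repeat_interleave(rep, dim=0)
    v = v.repeat_interleave(rep, dim=0)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(0, 1).reshape(seq, hidden)

# Toy AR loop: every step re-prefills the whole (still tiny) sequence.
torch.manual_seed(0)
emb = torch.randn(2, 64)                       # 2 conditioning tokens
wq = torch.randn(64, 64) * 0.1                 # 4 query heads x 16 dims
wk = torch.randn(64, 32) * 0.1                 # 2 KV heads x 16 dims
wv = torch.randn(64, 32) * 0.1
for _ in range(3):
    h = re_prefill_attention(emb, wq, wk, wv, num_heads=4, num_kv_heads=2)
    emb = torch.cat([emb, h[-1:]], dim=0)      # append "next token" state
```

Because the sequence never exceeds ~16 tokens, recomputing it each step costs little, while removing the cache removes all the bookkeeping that blocked full-graph capture.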

Changes

  • qwen3_tts_code_predictor_vllm.py:
    • Remove _LocalPredictorKVCache, custom sampling op, and forward_context dependencies.
    • Implement _CodePredictorAttention with standalone SDPA + native GQA.
    • Add pre-allocated projection/position buffers and torch.compile integration.
  • qwen3_tts_code2wav.py:
    • Bypass HF tok.decode() wrapper, call decoder.chunked_decode() directly on GPU.
    • Enable decoder CUDA Graph independently of enforce_eager.
  • cuda_graph_decoder_wrapper.py:
    • Add bucket size 26 to DEFAULT_CAPTURE_SIZES for streaming mode (25 context + 1 chunk).
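The bucket change is easiest to see as the capture-size lookup a CUDA-graph wrapper typically performs (a simplified stand-in, not CUDAGraphDecoderWrapper's actual code): the wrapper replays the smallest captured graph whose padded size fits the input, so without a 26 bucket the streaming case (25 context + 1 chunk) would pad up to 32 on every chunk.

```python
import bisect

# Hypothetical capture sizes; the real DEFAULT_CAPTURE_SIZES lives in
# cuda_graph_decoder_wrapper.py. 26 is the streaming bucket (25 + 1).
CAPTURE_SIZES = [1, 2, 4, 8, 16, 26, 32, 64]

def pick_bucket(seq_len, sizes=CAPTURE_SIZES):
    """Return the smallest captured size that fits seq_len, or None to
    fall back to eager execution when no captured graph is big enough."""
    i = bisect.bisect_left(sizes, seq_len)
    return sizes[i] if i < len(sizes) else None
```

With the 26 bucket, `pick_bucket(26)` hits an exact-fit graph; without it, the same input would round up to 32 and waste six rows of padding per decoded chunk.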

Test Plan

Unit test:

```shell
pytest tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py -v
```

E2E test:

```shell
pytest tests/e2e/online_serving/test_qwen3_tts.py -v
```

Benchmark (using scripts from PR #1573):

```shell
# Start server (same config as PR #1591)
CUDA_VISIBLE_DEVICES=0 python -m vllm_omni.entrypoints.cli.main serve \
    "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" \
    --omni --host 127.0.0.1 --port 8000 \
    --stage-configs-path benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml \
    --trust-remote-code

# Run benchmark (50 prompts, streaming PCM)
python benchmarks/qwen3-tts/vllm_omni/bench_tts_serve.py \
    --port 8000 --num-prompts 50 --max-concurrency 1 4 10 \
    --result-dir benchmarks/qwen3-tts/results/
```

Test Result

  • Unit test: 31/31 passed
  • E2E test: 6/6 passed (English, Chinese, multi-voice, binary response, API endpoints)

Benchmark (RTX 5090, 50 prompts, streaming PCM, enforce_eager=true, same machine):
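For reading the tables: assuming the usual definitions (the benchmark script's exact formulas are not shown here), RTF is wall-clock generation time divided by produced audio duration, and Audio(s/s) is approximately its inverse, e.g. 1/0.168 ≈ 5.95 against the reported 5.97.

```python
def rtf(gen_seconds, audio_seconds):
    """Real-time factor: wall-clock seconds spent per second of audio.
    RTF < 1 means generation is faster than real time."""
    return gen_seconds / audio_seconds

def audio_throughput(gen_seconds, audio_seconds):
    """Seconds of audio produced per wall-clock second (inverse of RTF)."""
    return audio_seconds / gen_seconds
```

Small discrepancies between the two columns would come from per-request averaging rather than a single aggregate ratio.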

0.6B-CustomVoice:

| Conc | TTFP (ms) | E2E (ms) | RTF | Audio (s/s) | Req/s |
|------|-----------|----------|-------|------|------|
| 1    | 354.6     | 950.2    | 0.168 | 5.97 | 1.05 |
| 4    | 3014.7    | 3598.9   | 0.652 | 6.06 | 1.08 |
| 10   | 8147.3    | 8769.3   | 1.545 | 6.05 | 1.04 |

1.7B-CustomVoice:

| Conc | TTFP (ms) | E2E (ms) | RTF | Audio (s/s) | Req/s |
|------|-----------|----------|-------|------|------|
| 1    | 364.8     | 1001.2   | 0.173 | 5.79 | 1.00 |
| 4    | 3210.7    | 3854.5   | 0.680 | 5.89 | 1.01 |
| 10   | 8655.0    | 9328.8   | 1.638 | 5.89 | 0.98 |

vs PR #1591 (dense KV + SDPA, same machine, same config):

| Model | Conc | PR #1591 RTF | This PR RTF | Improvement |
|-------|------|--------------|-------------|-------------|
| 0.6B  | 1    | 0.231 | 0.168 | -27.3% |
| 0.6B  | 4    | 0.893 | 0.652 | -27.0% |
| 0.6B  | 10   | 2.112 | 1.545 | -26.8% |
| 1.7B  | 1    | 0.239 | 0.173 | -27.6% |
| 1.7B  | 4    | 0.943 | 0.680 | -27.9% |
| 1.7B  | 10   | 2.241 | 1.638 | -26.9% |

cc @linyueqian @gcanlin

… with standalone attention layers

- Modified DEFAULT_CAPTURE_SIZES to include bucket size 26 for streaming inputs.
- Introduced a new standalone multi-head attention class for the code predictor, replacing the previous implementation.
- Improved the Qwen3TTSCode2Wav class by bypassing the HF wrapper for direct decoder calls, optimizing performance and reducing overhead.

Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 Sy0307 requested a review from hsliuustc0106 as a code owner March 2, 2026 19:59

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8ed0b2a8d0


Comment on lines +66 to +70:

```python
self.qkv_proj = QKVParallelLinear(
    hidden_size=self.hidden_size,
    head_size=self.head_dim,
    total_num_heads=self.num_heads,
    total_num_kv_heads=self.num_kv_heads,
```


P1: Pass quantization config into predictor linear layers

When vllm_config.quant_config is enabled (e.g., AWQ/GPTQ/FP8 serving), this refactor instantiates the predictor's parallel linear layers without any quant_config, and the corresponding load_weights path no longer applies quantization scale remapping, so quantized weights/scales for the code predictor can be skipped or mismatched. The previous implementation routed through Qwen3DecoderLayer(..., quant_config=...), so this is a regression that can break startup or produce incorrect residual-code sampling in quantized deployments.
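The fix the review asks for is mostly plumbing. A toy sketch of the pattern (every class here is a stand-in, not vLLM's actual types): thread `quant_config` from the engine config down into each predictor linear layer instead of constructing them without it.

```python
class QuantConfig:
    """Stand-in for a vLLM quantization config (e.g. AWQ/GPTQ/FP8)."""
    def __init__(self, method):
        self.method = method

class Linear:
    """Stand-in linear layer; scale loading/remapping keys off quant_config."""
    def __init__(self, in_dim, out_dim, quant_config=None):
        self.in_dim, self.out_dim = in_dim, out_dim
        self.quant_config = quant_config

class PredictorAttention:
    def __init__(self, hidden, quant_config=None):
        # Pass the config through rather than dropping it: leaving it None
        # is exactly the regression the review describes.
        self.qkv_proj = Linear(hidden, 3 * hidden, quant_config=quant_config)
        self.o_proj = Linear(hidden, hidden, quant_config=quant_config)

def build_predictor(vllm_quant_config):
    return PredictorAttention(hidden=1024, quant_config=vllm_quant_config)
```

With this wiring, a quantized deployment sees the same config on every predictor projection, so weight/scale loading stays consistent with the rest of the model.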


@Sy0307 Sy0307 marked this pull request as draft March 2, 2026 20:16