[Refactor][Perf] Replace _LocalPredictorKVCache with SDPA #1591
gcanlin wants to merge 2 commits into vllm-project:main
Conversation
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 05ffc34ffd
vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
hsliuustc0106 left a comment:
PR #1591 Review: [Refactor][Perf] Replace _LocalPredictorKVCache with SDPA
Overview
Refactors the Qwen3-TTS code predictor from using vLLM's v2 paged attention backend to a simpler SDPA-based approach with dense KV buffers.
Design Rationale ✅
The PR correctly identifies that paged attention is overkill for this use case:
- Fixed sequence length (~32 tokens)
- Private KV cache (no scheduler involvement)
- Small batch sizes
- Sequential decoding (no concurrent sequences with different lifetimes)
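Under these constraints, a dense preallocated buffer plus a single SDPA call is enough. The following is a minimal sketch of that pattern for a single decode step; the names (`DenseKVCache`, `attend`) are illustrative and not the PR's actual implementation:

```python
import torch
import torch.nn.functional as F


class DenseKVCache:
    """Preallocated [B, H, S, D] buffers; fine when S is small and fixed."""

    def __init__(self, batch, heads, max_seq, head_dim, dtype, device):
        shape = (batch, heads, max_seq, head_dim)
        self.k = torch.zeros(shape, dtype=dtype, device=device)
        self.v = torch.zeros(shape, dtype=dtype, device=device)
        self.pos = 0  # next write position in the sequence dimension

    def append(self, k_new, v_new):
        # k_new / v_new: [B, H, T, D]; write into the dense buffer in place.
        t = k_new.shape[2]
        self.k[:, :, self.pos:self.pos + t] = k_new
        self.v[:, :, self.pos:self.pos + t] = v_new
        self.pos += t
        # Return views over the filled prefix for attention.
        return self.k[:, :, :self.pos], self.v[:, :, :self.pos]


def attend(q, cache, k_new, v_new):
    # Single decode step: q is [B, H, 1, D], so attending over the whole
    # cached prefix needs no causal mask.
    k, v = cache.append(k_new, v_new)
    return F.scaled_dot_product_attention(q, k, v)
```

No block tables, no slot mapping, no scheduler hooks: the whole cache lives in two contiguous tensors owned by the layer.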
Performance Results ✅
| Metric | Before | After | Improvement |
|---|---|---|---|
| TTFP (conc=1) | 2502ms | 2042ms | -18% |
| E2E (conc=1) | 7012ms | 5650ms | -19% |
| RTF (conc=1) | 2.477 | 1.996 | -19% |
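The improvement column can be sanity-checked with a few lines of arithmetic against the before/after values:

```python
# Recompute the relative changes reported in the table above.
before = {"ttfp_ms": 2502, "e2e_ms": 7012, "rtf": 2.477}
after = {"ttfp_ms": 2042, "e2e_ms": 5650, "rtf": 1.996}

for metric in before:
    change = (after[metric] - before[metric]) / before[metric] * 100
    print(f"{metric}: {change:+.1f}%")
# → ttfp_ms: -18.4%
# → e2e_ms: -19.4%
# → rtf: -19.4%
```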
Important Issues: 3 found
1. Hardcoded dtype in KV Cache Allocation
The dtype should be derived from the model config or input dtype, not hardcoded to bfloat16.
2. No Unit Tests for New Classes
No unit tests for CodePredictorAttention and CodePredictorDecoderLayer.
3. Batch Size Reallocation Edge Case
When batch size increases, old cache data is discarded. Should document the assumption that reallocation only happens during prefill.
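As a starting point, here is a hedged sketch of the kind of unit tests that could cover issues 1 and 3; `allocate_kv` is a hypothetical stand-in for the PR's allocator, not its actual `_allocate_kv_caches` signature:

```python
import torch


def allocate_kv(batch, heads, max_seq, head_dim, dtype, device="cpu"):
    # Stand-in allocator: one dense K and one dense V buffer.
    shape = (batch, heads, max_seq, head_dim)
    return (torch.zeros(shape, dtype=dtype, device=device),
            torch.zeros(shape, dtype=dtype, device=device))


def test_dtype_follows_input():
    # Issue 1: the cache dtype should track the input dtype,
    # not a hardcoded bfloat16.
    x = torch.randn(2, 7, 64, dtype=torch.float16)
    k, v = allocate_kv(2, 4, 32, 16, dtype=x.dtype)
    assert k.dtype == v.dtype == torch.float16


def test_grow_on_larger_batch():
    # Issue 3: a larger batch forces reallocation; old contents are
    # discarded, which is only safe at prefill time.
    k, _ = allocate_kv(2, 4, 32, 16, dtype=torch.float32)
    if k.shape[0] < 5:
        k, _ = allocate_kv(5, 4, 32, 16, dtype=torch.float32)
    assert k.shape[0] == 5
```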
Strengths
- ✅ Significant simplification (~100 lines removed)
- ✅ Eliminates v2 dependency
- ✅ Removes NPU special-casing
- ✅ 18-20% performance improvement
- ✅ Weight compatibility preserved
Recommendation
After addressing the hardcoded dtype issue, this PR is ready for merge. Unit tests could be a follow-up.
```python
batch_size,
self._num_kv_heads,
self._max_seq_len,
self._head_dim,
```
Hardcoded dtype

Consider deriving the dtype from the config or the input instead of hardcoding bfloat16:

```python
dtype = inputs_embeds.dtype  # or self.config.torch_dtype
k = torch.zeros(..., dtype=dtype, device=device)
v = torch.zeros(..., dtype=dtype, device=device)
```

This ensures consistency if the model runs in a different precision.
```python
last_idx = torch.arange(qlen - 1, bsz * qlen, step=qlen, device=out.device, dtype=torch.long)
last_h = out.index_select(0, last_idx)
last_h = out[:, -1, :]  # [B, hidden]
logits = self.lm_head[0](last_h)
```
Batch size reallocation

When bsz increases, the old cache is discarded and reallocated. This works because reallocation should only happen during prefill (not mid-generation).

Consider adding an assertion or comment documenting this assumption:

```python
# Reallocation only expected during prefill; cache is reset between generations.
if self._kv_caches is None or self._kv_caches[0][0].shape[0] < bsz:
    self._kv_caches = self._allocate_kv_caches(bsz, device)
```
LGTM, good work! I think it's worthwhile to refactor. After testing #1591, I found that it essentially eliminates the bubbles and redundant Python calls, achieving approximately 20% performance improvement.
Though I wonder if this approach might perform better when combined with torch.compile? I'm still not sure, but this PR can be completed and merged first; afterwards I'll try to write an implementation matching the description above to compare the performance, and we can discuss more at that point.
Purpose
The code predictor doesn't need paged attention, because paged attention solves three problems the code predictor doesn't have:
- dynamic, unbounded sequence lengths (here the sequence length is fixed at roughly 32 tokens)
- a scheduler-managed, shared KV cache (here the cache is private to the predictor)
- many concurrent sequences with different lifetimes (here decoding is sequential with small batches)
Test Plan
Test Result
It looks like the RTF values have some problems...
Before:
After:
Essential Elements of an Effective PR Description Checklist
- Update `supported_models.md` and `examples` for a new model. Please run `mkdocs serve` to sync the documentation editions to `./docs`.