[Refactor][Perf] Replace _LocalPredictorKVCache with SDPA#1591

Open
gcanlin wants to merge 2 commits into vllm-project:main from gcanlin:code-kv
Conversation

@gcanlin
Contributor

@gcanlin gcanlin commented Mar 2, 2026

Purpose

This PR solves two problems:

  • The code predictor (a small 5-layer sub-model for residual codebook prediction) was borrowing model runner v2's attn_utils (vllm.v1.worker.gpu.attn_utils) to manage its own KV cache, even though vllm-omni uses model runner v1. The current implementation is hacky.
  • It forced us to add an NPU-specific code branch, because the v2 backends need different metadata formats.

The code predictor doesn't need paged attention because paged attention solves three problems it doesn't have:

  • Variable-length sequences — Paged attention manages memory efficiently when sequences have unpredictable lengths (hundreds to thousands of tokens). The code predictor always runs exactly num_code_groups tokens (~32), known at init time.
  • Memory fragmentation across concurrent sequences — Paged attention avoids fragmentation when the engine juggles thousands of sequences with different lifetimes. The code predictor runs a small batch, allocates at startup, and reuses the same buffer every frame.
  • Shared KV cache pool managed by the scheduler — The main LLM's KV cache is allocated/freed by vLLM's scheduler across requests. The code predictor's KV cache is private — it's allocated once, used for ~30 decode steps, then zeroed. No scheduler involvement.
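Given those constraints, the dense-buffer design can be sketched as a minimal, illustrative PyTorch module (class and attribute names here are hypothetical, not the PR's actual code): preallocate `[batch, heads, max_seq_len, head_dim]` K/V buffers once, write each step's keys/values in place, and run torch SDPA over the filled prefix.

```python
# Hedged sketch of a dense, preallocated KV cache driven by SDPA,
# for a decoder whose max sequence length is known at init time.
# Names (DenseKVAttention, k_cache, pos, ...) are illustrative only.
import torch
import torch.nn.functional as F


class DenseKVAttention(torch.nn.Module):
    def __init__(self, num_heads: int, head_dim: int, max_seq_len: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.k_cache = None  # allocated lazily once the batch size is known
        self.v_cache = None
        self.pos = 0         # how many positions are filled

    def _maybe_alloc(self, bsz: int, dtype, device):
        # Reallocate only if the batch grows (expected during prefill only).
        if self.k_cache is None or self.k_cache.shape[0] < bsz:
            shape = (bsz, self.num_heads, self.max_seq_len, self.head_dim)
            self.k_cache = torch.zeros(shape, dtype=dtype, device=device)
            self.v_cache = torch.zeros(shape, dtype=dtype, device=device)

    def forward(self, q, k, v):
        # q, k, v: [bsz, num_heads, qlen, head_dim]
        bsz, _, qlen, _ = q.shape
        self._maybe_alloc(bsz, q.dtype, q.device)
        # Write new keys/values into the dense buffer in place.
        self.k_cache[:bsz, :, self.pos:self.pos + qlen] = k
        self.v_cache[:bsz, :, self.pos:self.pos + qlen] = v
        self.pos += qlen
        keys = self.k_cache[:bsz, :, :self.pos]
        vals = self.v_cache[:bsz, :, :self.pos]
        # Causal mask is only needed during prefill; a single decode
        # token may attend to the whole cached prefix.
        return F.scaled_dot_product_attention(q, keys, vals, is_causal=(qlen > 1))

    def reset(self):
        # Between frames, just rewind the write position and reuse the buffer.
        self.pos = 0
```

No block tables, no slot mappings, no scheduler interaction: the per-step cost is one in-place slice write plus one SDPA call over a contiguous buffer.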

Test Plan

CUDA_VISIBLE_DEVICES=0 python -m vllm_omni.entrypoints.cli.main serve \
    "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" \
    --omni --host 127.0.0.1 --port 8000 \
    --stage-configs-path benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml \
    --trust-remote-code
python benchmarks/qwen3-tts/vllm_omni/bench_tts_serve.py \
    --port 8000 \
    --num-prompts 50 \
    --max-concurrency 1 4 10 \
    --save-audio \
    --result-dir benchmarks/qwen3-tts/results/

Test Result

It looks like the RTF values still have some problems...

Before:

======================================================================
 Summary: async_chunk
======================================================================
  Conc   TTFP(ms)     E2E(ms)      RTF      Audio(s/s)   Req/s
  ------ ------------ ------------ -------- ------------ --------
  1      2502.4       7012.5       2.477    0.40         0.1
  4      23747.9      27808.0      9.832    0.41         0.1
  10     70739.9      75256.3      23.191   0.41         0.1
======================================================================

After:

======================================================================
 Summary: async_chunk
======================================================================
  Conc   TTFP(ms)     E2E(ms)      RTF      Audio(s/s)   Req/s
  ------ ------------ ------------ -------- ------------ --------
  1      2042.6       5650.0       1.996    0.50         0.2
  4      19093.7      23453.3      8.049    0.51         0.2
  10     51751.7      55537.6      17.790   0.51         0.2
======================================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your code doesn't require additional test scripts. For test file guidelines, please check the test style doc.
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin marked this pull request as ready for review March 2, 2026 06:00
@gcanlin gcanlin requested a review from hsliuustc0106 as a code owner March 2, 2026 06:00

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 05ffc34ffd

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin
Contributor Author

gcanlin commented Mar 2, 2026

cc @Sy0307 @linyueqian

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

PR #1591 Review: [Refactor][Perf] Replace _LocalPredictorKVCache with SDPA

Overview

Refactors the Qwen3-TTS code predictor from using vLLM's v2 paged attention backend to a simpler SDPA-based approach with dense KV buffers.

Design Rationale ✅

The PR correctly identifies that paged attention is overkill for this use case:

  • Fixed sequence length (~32 tokens)
  • Private KV cache (no scheduler involvement)
  • Small batch sizes
  • Sequential decoding (no concurrent sequences with different lifetimes)

Performance Results ✅

| Metric        | Before  | After   | Improvement |
|---------------|---------|---------|-------------|
| TTFP (conc=1) | 2502 ms | 2042 ms | -18%        |
| E2E (conc=1)  | 7012 ms | 5650 ms | -19%        |
| RTF (conc=1)  | 2.477   | 1.996   | -19%        |

Important Issues: 3 found

1. Hardcoded dtype in KV Cache Allocation

The dtype should be derived from the model config or input dtype, not hardcoded to bfloat16.

2. No Unit Tests for New Classes

No unit tests for CodePredictorAttention and CodePredictorDecoderLayer.

3. Batch Size Reallocation Edge Case

When batch size increases, old cache data is discarded. Should document the assumption that reallocation only happens during prefill.


Strengths

  • ✅ Significant simplification (~100 lines removed)
  • ✅ Eliminates v2 dependency
  • ✅ Removes NPU special-casing
  • ✅ 18-20% performance improvement
  • ✅ Weight compatibility preserved

Recommendation

After addressing the hardcoded dtype issue, this PR is ready for merge. Unit tests could be a follow-up.

batch_size,
self._num_kv_heads,
self._max_seq_len,
self._head_dim,
Collaborator

Hardcoded dtype

Consider deriving dtype from config or input instead of hardcoding bfloat16:

dtype = inputs_embeds.dtype  # or self.config.torch_dtype
k = torch.zeros(..., dtype=dtype, device=device)
v = torch.zeros(..., dtype=dtype, device=device)

This ensures consistency if the model runs in a different precision.

last_idx = torch.arange(qlen - 1, bsz * qlen, step=qlen, device=out.device, dtype=torch.long)
last_h = out.index_select(0, last_idx)
last_h = out[:, -1, :] # [B, hidden]
logits = self.lm_head[0](last_h)
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Batch size reallocation

When bsz increases, the old cache is discarded and reallocated. This is only safe because reallocation happens during prefill, never mid-generation.

Consider adding an assertion or comment documenting this assumption:

# Reallocation only expected during prefill; cache is reset between generations.
if self._kv_caches is None or self._kv_caches[0][0].shape[0] < bsz:
    self._kv_caches = self._allocate_kv_caches(bsz, device)

@Sy0307
Contributor

Sy0307 commented Mar 2, 2026

LGTM. Good work! I think it's worthwhile to refactor build_attn_metadata/set_forward_context here to eliminate the time overhead. I previously had other ideas to reduce the bubbles here — for example, instead of doing incremental KV-cache decoding at each AR step, treating the "entire prefix generated so far (growing sequence)" as a prefill and running it from scratch each time.

However, after testing #1591, I found that it essentially eliminates the bubbles and redundant Python calls, achieving approximately 20% performance improvement.

@Sy0307
Contributor

Sy0307 commented Mar 2, 2026

for example, instead of doing incremental KV-cache decoding at each AR step, treating the "entire prefix generated so far (growing sequence)" as a prefill and running it from scratch each time.

Though I wonder if this approach might perform better when combined with torch.compile? I'm still not sure, but this PR can be completed and merged first, and afterwards I'll try to write an implementation matching the description above to compare the performance — we can have more discussion at that point.
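For concreteness, the "full prefill each step" alternative described above can be sketched roughly like this (`model` and the greedy sampling are stand-ins, not vllm-omni APIs): the growing prefix is re-run from scratch at every AR step, trading quadratic recomputation for zero per-step cache bookkeeping — which is what might pair well with torch.compile, since each step is a single static-pattern prefill call.

```python
# Hedged sketch: autoregressive generation that re-runs the whole
# prefix as a prefill at every step, instead of incremental KV decoding.
# `model` is any callable mapping token ids [B, t] -> logits [B, t, V].
import torch


def ar_full_recompute(model, first_token: torch.Tensor, num_steps: int) -> torch.Tensor:
    tokens = [first_token]                       # growing prefix, each [B]
    for _ in range(num_steps):
        seq = torch.stack(tokens, dim=1)         # [B, t] full prefix so far
        logits = model(seq)                      # prefill from scratch each step
        tokens.append(logits[:, -1].argmax(-1))  # greedy next token (stand-in sampler)
    return torch.stack(tokens, dim=1)            # [B, num_steps + 1]
```

With ~32 tokens per frame the recomputation is tiny, so whether this beats the KV-cache path likely comes down to kernel launch and Python overhead, as the comment suggests.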
