Conversation

Collaborator

@sarckk sarckk commented Aug 11, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR adds an option to enable prefill optimization for Gemma3n model with --kv-sharing-fast-prefill.

Background

In You Only Cache Once (YOCO, https://arxiv.org/abs/2405.05254), self-decoder layers generate the KV caches while cross-decoder layers use cross-attention to reuse the shared KV cache. Because only the self-decoder layers generate KV caches, the cross-decoder layers do not need to run prefill. Below is a figure from the YOCO paper illustrating the prefill optimization:

[Figure: prefill optimization from the YOCO paper]
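For intuition, here is a toy sketch of the KV-sharing idea (not vLLM code; names and shapes are illustrative): a self-decoder layer writes its KV into a shared cache, while a cross-decoder layer only reads from it.

import torch
import torch.nn.functional as F

# Toy illustration of YOCO-style KV sharing.
num_heads, num_tokens, head_dim = 2, 8, 16
shared_kv_cache: dict[str, torch.Tensor] = {}

def self_decoder_attn(q, k, v):
    # Self-decoder layers populate the shared KV cache during prefill.
    shared_kv_cache["k"], shared_kv_cache["v"] = k, v
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def cross_decoder_attn(q):
    # Cross-decoder layers reuse the shared cache and never write their own KV,
    # so they have no prefill work of their own beyond the final query tokens.
    k, v = shared_kv_cache["k"], shared_kv_cache["v"]
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, num_heads, num_tokens, head_dim)
_ = self_decoder_attn(q, k, v)
_ = cross_decoder_attn(q[..., -1:, :])   # only the last query token is needed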

Design

In vLLM V1, the scheduler does not distinguish between prefill and decode. Instead, tokens from requests doing prefill and decode are batched together (see the vLLM blog for an illustration).

When we skip the prefill tokens in the cross-decoder layers, the batch size is therefore reduced during the model forward pass for those layers:

Without optimization enabled (baseline)

[Figure: token batch without fast prefill]

With optimization enabled (--kv-sharing-fast-prefill)

[Figure: token batch with fast prefill]
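To make the batch-size reduction concrete, here is a minimal sketch (assuming a flattened token batch with per-request query_start_loc offsets, as in vLLM V1; the helper name is illustrative) of which token positions survive for the cross-decoder layers:

import torch

def kept_token_indices(query_start_loc: torch.Tensor) -> torch.Tensor:
    # query_start_loc has shape [num_requests + 1]; request i owns tokens
    # query_start_loc[i] : query_start_loc[i + 1] in the flattened batch.
    # Only the last scheduled token of each request needs logits, so decode
    # requests keep their single token and prefill requests keep just one.
    return query_start_loc[1:] - 1

# Example: request 0 prefills 5 tokens, request 1 prefills 3 tokens,
# request 2 decodes a single token.
query_start_loc = torch.tensor([0, 5, 8, 9])
print(kept_token_indices(query_start_loc))  # tensor([4, 7, 8])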

With this change, we can no longer compile the top-level model, for two reasons:

  1. torch.compile in vLLM assumes the batch size stays the same within a single model forward. The traced graph is specialized on the batch size, which leads to silent incorrectness if the batch size changes within the model forward pass.
  2. CUDA graphs are shape-specialized, so replaying a graph captured at one batch size with a reduced batch size would produce incorrect results.

Solution: we split the layers into self-decoder and cross-decoder layers and compile + CUDA-graph capture them separately. For Gemma3n-E2B, which has 30 layers, the first 20 layers (self-decoder) and the remaining 10 layers (cross-decoder) are grouped into independently compiled and CUDA-graph-captured modules.
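A minimal sketch of the split (illustrative only; it uses plain Linear layers rather than the actual Gemma3n decoder layers, and the stock torch.compile API rather than vLLM's compilation machinery):

import torch
import torch.nn as nn

class DecoderStack(nn.Module):
    """A compilable wrapper around a contiguous slice of decoder layers."""
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states

hidden_size = 64
layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(30))

# Layers 0-19 play the role of the self-decoder, layers 20-29 the cross-decoder.
# Compiling them as separate modules means the batch-size change between them
# never happens inside a single traced graph.
self_decoder = torch.compile(DecoderStack(layers[:20]), dynamic=True)
cross_decoder = torch.compile(DecoderStack(layers[20:]), dynamic=True)

x = torch.randn(12, hidden_size)        # 12 tokens in the flattened batch
h = self_decoder(x)                     # full batch through the self-decoder
kept = torch.tensor([3, 7, 11])         # e.g. decode / last-prefill token indices
out = cross_decoder(h[kept])            # reduced batch through the cross-decoder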

Other changes required in this PR:

  • Build an attention metadata builder subclass for eligible layers so it can call make_kv_sharing_fast_prefill_common_attn_metadata to create attention metadata that excludes all prefill tokens. This requires passing logits_indices to CommonAttentionMetadata.
  • Create an attention metadata subclass for eligible layers that is an instance of KVSharingFastPrefillAttentionMetadata. It carries two additional fields (logits_indices_padded and num_logits_indices), which are needed for indexing into the hidden states in the model implementation so the shapes match what the new attention metadata expects.
  • Changes to the Gemma3n model implementation:
    • Change the hidden_states shape from [altup_num_inputs, num_tokens, hidden_size] to [num_tokens, hidden_size, altup_num_inputs] so that num_tokens (the batch size) is on dim 0. We cannot keep num_tokens on dim 1 because slicing along dim 1 would (a) cause torch.compile tensor stride assertions to fail, and (b) fixing that by calling contiguous() on the slice would copy memory and therefore violate the CUDA graph static-address constraint.
    • If the --kv-sharing-fast-prefill flag is passed, we take a separate self.fast_prefill_forward() path that uses logits_indices_padded to index into the subset of tokens for the cross-decoder layers (i.e. with a reduced batch size), then merges the result back into the self-decoder output to produce the final output (see the sketch after this list).
    • If the --kv-sharing-fast-prefill flag is passed, we compile the self-decoder and cross-decoder submodules separately and pre-allocate static buffers for CUDA graph replay. If it is not passed (the default), we still compile the top-level Gemma3nTextModel.
    • Attention group changes. After this PR, the attention groups for Gemma3n-E2B look like this:
- attn_groups[0] (non-sliding window layers)
  - attn_groups[0][0]: 4, 9, 14, 19 
  - attn_groups[0][1]: 24, 29

- attn_groups[1]
  - attn_groups[1][0] layers: 0, 1, 2, 3

- attn_groups[2]
  - attn_groups[2][0] layers 5, 6, 7, 8

- attn_groups[3]
  - attn_groups[3][0] layers: 10, 11, 12, 13

- attn_groups[4] (sliding window layers)
  - attn_groups[4][0] layers: 15, 16, 17, 18
  - attn_groups[4][1] layers: 20, 21, 22, 23, 25, 26, 27, 28

Compared to trunk, the only difference is that there are extra groups within attn_groups[0] and attn_groups[4] for layers that need a separate attention metadata builder for the fast-prefill path. Previously it looked like this:

- attn_groups[0] (non-sliding window layers)
  - attn_groups[0][0]: 4, 9, 14, 19, 24, 29 

# attn_groups[1], attn_groups[2] and attn_groups[3] same

- attn_groups[4] (sliding window layers)
  - attn_groups[4][0] layers: 15, 16, 17, 18, 20, 21, 22, 23, 25, 26, 27, 28
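As referenced in the fast-prefill bullet above, here is a minimal sketch (illustrative names; the real implementation lives in the Gemma3n model code and uses the logits_indices_padded / num_logits_indices metadata) of gathering the kept tokens along dim 0, running the cross-decoder on the reduced batch, and merging the result back into the self-decoder output:

import torch

def fast_prefill_merge(
    self_decoder_out: torch.Tensor,       # [num_tokens, hidden_size, ...]
    cross_decoder_fn,                      # runs cross-decoder layers on a reduced batch
    logits_indices_padded: torch.Tensor,   # [padded_num_logits]; padding repeats valid indices
    num_logits_indices: int,
) -> torch.Tensor:
    # Gather along dim 0 (num_tokens must be dim 0 so the slice keeps
    # contiguous strides and stays CUDA-graph friendly).
    reduced_in = self_decoder_out[logits_indices_padded]
    reduced_out = cross_decoder_fn(reduced_in)

    # Merge: start from the self-decoder output and overwrite only the
    # positions that actually went through the cross-decoder.
    merged = self_decoder_out.clone()
    valid = logits_indices_padded[:num_logits_indices]
    merged[valid] = reduced_out[:num_logits_indices]
    return merged

# Tiny usage example with a no-op cross-decoder stand-in.
hidden = torch.randn(9, 8)
out = fast_prefill_merge(hidden, lambda t: t * 2.0,
                         torch.tensor([4, 7, 8, 8]), num_logits_indices=3)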

Important

When prompt_logprobs is enabled, the fast prefill optimization can no longer be used: by skipping all but the last prefill token of each request, the logits for the prompt tokens are no longer valid. For example, multiple-choice question (MCQ) evals use prompt_logprobs to get the logprobs of continuation tokens (e.g. lm-evaluation-harness), so using --kv-sharing-fast-prefill would yield inaccurate results. To prevent this, completion requests with prompt logprobs are rejected (HTTP 400 Bad Request) for online serving. For offline serving (i.e. with the LLM class), an assertion error is raised.
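A minimal sketch of the kind of guard this implies (hypothetical function and field names, not the actual vLLM validation code):

def validate_prompt_logprobs(prompt_logprobs, fast_prefill_enabled: bool) -> None:
    # With fast prefill, logits are only computed for the last prefill token of
    # each request, so per-prompt-token logprobs cannot be produced correctly.
    if fast_prefill_enabled and prompt_logprobs is not None:
        raise ValueError(
            "prompt_logprobs is not compatible with --kv-sharing-fast-prefill; "
            "online serving rejects such requests with HTTP 400."
        )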

Follow ups

Test Plan

Evals

PORT=8000
vllm serve google/gemma-3n-E2B-it --disable-log-requests
lm_eval --model local-completions --tasks $TASKNAME \
    --model_args model=google/gemma-3n-E2B-it,base_url=http://127.0.0.1:$PORT/v1/completions,num_concurrent=200,tokenized_requests=False --batch_size auto --apply_chat_template --fewshot_as_multiturn

Ran gsm8k, mmlu, and mmlu_pro.

Unit tests

pytest tests/v1/worker/test_gpu_model_runner.py -k "test_init_kv_cache"
pytest tests/v1/e2e/test_kv_sharing_fast_prefill.py::test_kv_sharing_fast_prefill

Performance

VLLM_DISABLE_COMPILE_CACHE=1 python -m vllm.entrypoints.openai.api_server --model google/gemma-3n-E2B-it --disable-log-requests -tp 1 --port 8000 --no-enable-prefix-caching --max-num-seqs 128 --max-model-len=32768 --max-num-batched-tokens=8192 --kv-sharing-fast-prefill

Performed a sweep over max-concurrency and random-input-len with:
num_reqs = 256
max-num-batched-tokens = 8192
max-num-seqs = 128

python benchmarks/benchmark_serving.py --backend vllm --ignore-eos --port 8000 \
    --model google/gemma-3n-E2B-it --dataset-name random --max-concurrency 8 \
    --request-rate inf --num-prompts $num_reqs --random-input-len 8192 --random-output-len 150

Test Result

Evals

Evals are on par with the baseline.

Run                      gsm_8k.5-shot.strict-match   mmlu_pro.5-shot.custom_extract
This PR (fast prefill)   0.5466                       0.3444
This PR (full prefill)   0.5413                       0.3439
Base                     0.5474                       0.3426

With the hybrid memory allocator also enabled: gsm8k 0.5451, mmlu_pro 0.3424.

Unit tests

Unit tests all pass

Performance

Mean TTFT and TPOT (ms):

[Figure: mean TTFT and TPOT comparison]

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant performance optimization for Gemma3n models by enabling a fast prefill path, inspired by the YOCO paper. The implementation is well-thought-out, involving a refactoring of the Gemma3n model into self-decoder and cross-decoder modules to work with torch.compile and dynamic batch sizes. The changes to attention metadata and the use of a conditional compilation decorator are clean solutions. Overall, the changes are robust and the associated refactoring of KV cache sharing logic improves the codebase. However, there is a critical gap in testing that should be addressed.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@facebook-github-bot

@sarckk has imported this pull request. If you are a Meta employee, you can view this in D80013417.


mergify bot commented Aug 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 14, 2025
Collaborator

@heheda12345 heheda12345 left a comment

Thanks for your contribution. I took a look at the model runner changes. Will check gemma3n.py later.

@sarckk sarckk requested a review from yewentao256 as a code owner August 19, 2025 19:57
@sarckk sarckk force-pushed the gemma3n-fast-prefill-rebased branch from a822572 to bcf331a Compare August 19, 2025 20:02
@mergify mergify bot removed the needs-rebase label Aug 19, 2025
Collaborator

@heheda12345 heheda12345 left a comment

Almost LGTM. But maybe we need more discussion about the custom attention abstraction @LucasWilkinson

        embed_scale: torch.Tensor,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
Collaborator

Will there be any hidden problem when the decoder_layers are registered under both Gemma3nTextModel.layers and Gemma3nTextModel.self_decoder.decoder_layers in nn.Module? A cleaner solution would be to only register it in Gemma3nSelfDecoder (but that needs an update to the weight loader; it can be done in a follow-up PR after the model structure is finalized).

Collaborator Author

Yes, I mostly did this for simplicity, as I couldn't think of a case where it would be problematic (though there could be one). I do want to split this out into a separate PR if possible.
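For context, a toy illustration (hypothetical, simplified module names) of the double registration being discussed, where the same layers show up under two prefixes in the state dict:

import torch.nn as nn

shared_layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))

class SelfDecoder(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.decoder_layers = layers        # registered here...

class TextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = shared_layers         # ...and also here
        self.self_decoder = SelfDecoder(shared_layers)

model = TextModel()
keys = model.state_dict().keys()
# The same weight tensors appear under both prefixes, e.g.
# 'layers.0.weight' and 'self_decoder.decoder_layers.0.weight'.
print(sorted(k for k in keys if k.endswith("0.weight")))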


mergify bot commented Aug 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 25, 2025
@sarckk sarckk force-pushed the gemma3n-fast-prefill-rebased branch from 24ded7e to b6ee234 Compare August 25, 2025 20:15
@mergify mergify bot removed the needs-rebase label Aug 25, 2025
@sarckk
Collaborator Author

sarckk commented Aug 25, 2025

@heheda12345 addressed comments, ready for review

Collaborator

@heheda12345 heheda12345 left a comment

LGTM! Thanks for your patience for iterating on this optimization.

Some follow-ups:

  1. Abstract out changes in the model to make future model integration easier
  2. Register each weight only once in nn.Module.

@heheda12345 heheda12345 enabled auto-merge (squash) August 26, 2025 16:56
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 26, 2025
@facebook-github-bot

@sarckk has imported this pull request. If you are a Meta employee, you can view this in D80013417.


mergify bot commented Aug 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 27, 2025
sarckk added 6 commits August 27, 2025 14:15
auto-merge was automatically disabled August 27, 2025 21:15

Head branch was pushed to by a user without write access

@sarckk sarckk force-pushed the gemma3n-fast-prefill-rebased branch from d8dd6ca to e265fbd Compare August 27, 2025 21:15
@mergify mergify bot removed the needs-rebase label Aug 27, 2025
@heheda12345 heheda12345 enabled auto-merge (squash) August 28, 2025 17:18
@simon-mo simon-mo disabled auto-merge August 28, 2025 21:54
@simon-mo simon-mo merged commit cb293f6 into vllm-project:main Aug 28, 2025
39 of 43 checks passed
@DarkLight1337
Member

DarkLight1337 commented Aug 29, 2025

I think the multi-modal processor test failure is indeed related to this PR. It was not failing on the previous commits on main before this one.

@sarckk
Collaborator Author

sarckk commented Aug 29, 2025

@DarkLight1337 based on the commit history it does look like they started failing after my PR, but the failing tests are:

  • models/multimodal/processing/test_tensor_schema.py::test_model_tensor_schema[H2OVLChatModel-h2oai/h2ovl-mississippi-800m]
  • models/multimodal/processing/test_tensor_schema.py::test_model_tensor_schema[KimiVLForConditionalGeneration-moonshotai/Kimi-VL-A3B-Instruct]

and they were passing on my machine. Are you able to reproduce the issue on your dev machine?

EDIT: my bad, it seems the tests only fail if they run after the Gemma3n test case. I was only running the two broken test cases and they passed locally. After including the Gemma3n test case, I can reproduce the issue. It's probably due to the weight loading issue; reverting the gemma3n changes in this PR fixes it. #23897 should unbreak the CI while I investigate.

@sarckk sarckk mentioned this pull request Aug 29, 2025
5 tasks
@zou3519
Collaborator

zou3519 commented Aug 29, 2025

Btw, this has a bad interaction with PyTorch 2.8.0: #20358

I verified that after reverting this the failing test in CI (https://buildkite.com/vllm/ci/builds/28843/steps/canvas?jid=0198f5cd-d810-42bf-8560-c5ef36e6898c) passes

@pratapyash

I had an issue working with LoRA adapters with the compilation level set to PIECEWISE; some changes here seem relevant to the issues I'm facing: #23970

@sarckk
Collaborator Author

sarckk commented Aug 30, 2025

I had an issue working with LoRA adapters with the compilation level set to PIECEWISE; some changes here seem relevant to the issues I'm facing: #23970

@pratapyash could you check if you are still getting the issue after #23897?

@sarckk
Collaborator Author

sarckk commented Aug 30, 2025

For those facing problems after this PR, please try rebasing on #23897, which reverts the gemma3n model changes.
