
[V1] Enable prefill optimization for Gemma3n #22628

Open · wants to merge 2 commits into main

Conversation

@sarckk sarckk commented Aug 11, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR adds an option to enable the prefill optimization for the Gemma3n model via the --kv-sharing-fast-prefill flag.

Background

In You Only Cache Once (YOCO, https://arxiv.org/abs/2405.05254), self-decoder layers generate KV caches while cross-decoder layers use cross-attention and reuse the shared KV cache. Since only the self-decoder layers generate KV caches, the cross-decoder layers do not need to run prefill. Below is a figure from the YOCO paper illustrating the prefill optimization:

[Figure: prefill optimization in YOCO (from the YOCO paper)]

Design

In vLLM V1, the scheduler does not distinguish between prefill and decode. Instead, tokens from requests doing prefill and requests doing decode are batched together, as illustrated below (source: vLLM blog):

[Figure: continuous batching of prefill and decode tokens (source: vLLM blog)]

When we skip the prefill tokens in the cross-decoder layers, the batch size is therefore reduced partway through the model forward pass for those layers:

Without optimization enabled (baseline)

[Figure: batched tokens without the fast prefill optimization]

With optimization enabled (--kv-sharing-fast-prefill)

[Figure: batched tokens with the fast prefill optimization; prefill tokens are dropped before the cross-decoder layers]
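As a toy illustration of the batch-size reduction (the tensor names, shapes, and index layout here are hypothetical, not vLLM's internals):

```python
import torch

# Hypothetical mixed batch: request A is prefilling 4 prompt tokens,
# request B is decoding 1 token. query_start_loc marks each request's
# token range in the flattened batch.
query_start_loc = torch.tensor([0, 4, 5])
num_tokens = 5          # batch size seen by the self-decoder layers
hidden_size = 2048

# Only the last token of each request produces logits, so those are the
# only positions the cross-decoder layers need.
logits_indices = query_start_loc[1:] - 1        # tensor([3, 4])

hidden_states = torch.randn(num_tokens, hidden_size)
cross_decoder_input = hidden_states[logits_indices]   # shape [2, 2048]: reduced batch
```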

With this change, we can no longer compile the top-level model for 2 reasons:

  1. torch.compile in vLLM assumes the batch size stays the same within a single model forward pass. The traced graph is specialized on the batch size, which leads to silent incorrectness if the batch size changes mid-forward.
  2. CUDA graphs are shape-specialized, so replaying a graph captured at one batch size with a different batch size would produce incorrect results.

Solution: we split the layers into self-decoder and cross-decoder layers and compile + CUDA-graph capture them separately. For Gemma3n-E2B, which has 30 layers, the first 20 layers (self-decoder) and the remaining 10 layers (cross-decoder) are grouped into independently compiled and CUDA-graph-captured modules, as sketched below.
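A minimal sketch of this split, using plain torch.compile for illustration only (vLLM's own compilation and CUDA graph capture machinery is what the PR actually uses, and the class name below is an assumption):

```python
import torch
from torch import nn

class DecoderGroup(nn.Module):
    """Runs a contiguous slice of decoder layers as one compiled unit."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states

# Toy stand-ins for the 30 decoder layers of Gemma3n-E2B.
layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(30)])

# The self-decoder (first 20 layers) always sees the full batch, while the
# cross-decoder (last 10 layers) sees the reduced batch. Compiling and
# graph-capturing them as separate modules means each traced graph only ever
# sees a single batch size per forward.
self_decoder = torch.compile(DecoderGroup(layers[:20]))
cross_decoder = torch.compile(DecoderGroup(layers[20:]))
```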

Other changes required in this PR:

  • Build an attention metadata builder subclass for eligible layers so it can call make_kv_sharing_fast_prefill_common_attn_metadata to create attention metadata that excludes all prefill tokens. This requires passing logits_indices to CommonAttentionMetadata.
  • Create an attention metadata subclass for eligible layers that is an instance of KVSharingFastPrefillAttentionMetadata. It carries two additional fields (logits_indices_padded and num_logits_indices), which the model implementation needs in order to index into the hidden states and match the shapes that the new attention metadata expects.
  • Changes to the Gemma3n model implementation:
    • Change the hidden_states shape from [altup_num_inputs, num_tokens, hidden_size] to [num_tokens, hidden_size, altup_num_inputs] so that num_tokens (the batch size) is at dim 0. We cannot keep num_tokens at dim=1, because slicing along dim=1 would (a) make torch.compile tensor stride assertions fail, and (b) fixing that by calling contiguous() on the slice would copy memory and therefore violate the CUDA graph static-address constraint.
    • If the --kv-sharing-fast-prefill flag is passed, we take a separate self.fast_prefill_forward() path, which uses the logits_indices_padded metadata to index into the subset of tokens processed by the cross-decoder layers (i.e. with the reduced batch size). The cross-decoder output is then merged back into the self-decoder output to produce the final output (see the sketch after the attention group listings below).
    • If the --kv-sharing-fast-prefill flag is passed, we compile the self-decoder and cross-decoder submodules separately, and we also pre-allocate static buffers for CUDA graph replay. If the flag is not passed (the default), we still compile the top-level Gemma3TextModel.
    • Attention group changes. After this PR, the attention groups for Gemma3n-E2B look like this:
- attn_groups[0] (non-sliding window layers)
  - attn_groups[0][0] layers: 4, 9, 14, 19
  - attn_groups[0][1] layers: 24, 29

- attn_groups[1]
  - attn_groups[1][0] layers: 0, 1, 2, 3

- attn_groups[2]
  - attn_groups[2][0] layers: 5, 6, 7, 8

- attn_groups[3]
  - attn_groups[3][0] layers: 10, 11, 12, 13

- attn_groups[4] (sliding window layers)
  - attn_groups[4][0] layers: 15, 16, 17, 18
  - attn_groups[4][1] layers: 20, 21, 22, 23, 25, 26, 27, 28

Compared to trunk, the only difference is that attn_groups[0] and attn_groups[4] each gain an extra group for the layers that need a separate attention metadata builder for the fast prefill path. Previously the groups looked like this:

- attn_groups[0] (non-sliding window layers)
  - attn_groups[0][0] layers: 4, 9, 14, 19, 24, 29

- attn_groups[1], attn_groups[2], and attn_groups[3]: same as above

- attn_groups[4] (sliding window layers)
  - attn_groups[4][0] layers: 15, 16, 17, 18, 20, 21, 22, 23, 25, 26, 27, 28
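Returning to the fast_prefill_forward() path mentioned above, here is a minimal sketch of the gather-and-merge idea. The function signature, argument names, and merge-back strategy are illustrative assumptions, not the PR's exact implementation:

```python
import torch

def fast_prefill_forward_sketch(
    hidden_states: torch.Tensor,          # [num_tokens, hidden_size], output of the self-decoder layers
    logits_indices_padded: torch.Tensor,  # token positions the cross-decoder layers actually need
    cross_decoder,                        # callable that runs the cross-decoder layers on a reduced batch
) -> torch.Tensor:
    # Gather only the decode tokens and the last prefill token of each request,
    # so the cross-decoder layers run on a smaller batch.
    reduced_in = hidden_states[logits_indices_padded]   # [num_logits, hidden_size]
    reduced_out = cross_decoder(reduced_in)

    # Scatter the cross-decoder output back into a full-size tensor so the model
    # still returns one row per input token, matching the baseline path.
    merged = hidden_states.clone()
    merged[logits_indices_padded] = reduced_out
    return merged
```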

Important

When prompt_logprobs is enabled, we can no longer use the fast prefill optimization: by skipping all but the last prefill token of each request, the logits for the prompt tokens are no longer valid. For example, multiple-choice question (MCQ) evals use prompt_logprobs to get the logprobs of continuation tokens (e.g. lm-evaluation-harness), so using --kv-sharing-fast-prefill would yield inaccurate results. To prevent this, we issue a warning and disable the fast prefill optimization whenever at least one request in a scheduling round has prompt_logprobs set to a value other than None in its sampling params.
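As an illustration of this guard, here is a minimal sketch; the function and attribute names are assumptions rather than the PR's actual code:

```python
import warnings

def fast_prefill_allowed(scheduled_requests) -> bool:
    # Hypothetical per-batch check: fast prefill is only safe if no request in
    # this scheduling round asks for prompt logprobs, because the logits of
    # skipped prompt tokens would be invalid.
    for req in scheduled_requests:
        if getattr(req.sampling_params, "prompt_logprobs", None) is not None:
            warnings.warn(
                "Disabling kv-sharing fast prefill for this batch: "
                "prompt_logprobs was requested, and skipped prompt tokens "
                "would produce invalid logits."
            )
            return False
    return True
```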

Follow ups

Test Plan

Evals

PORT=8000
vllm serve google/gemma-3n-E2B-it --disable-log-requests
lm_eval --model local-completions --tasks $TASKNAME \
    --model_args model=google/gemma-3n-E2B-it,base_url=http://127.0.0.1:$PORT/v1/completions,num_concurrent=200,tokenized_requests=False --batch_size auto --apply_chat_template --fewshot_as_multiturn

Ran gsm8k, mmlu, and mmlu_pro.

Unit tests

pytest tests/v1/worker/test_gpu_model_runner.py -k "test_init_kv_cache"
pytest tests/v1/e2e/test_kv_sharing_fast_prefill.py::test_kv_sharing_fast_prefill

Performance

VLLM_DISABLE_COMPILE_CACHE=1 python -m vllm.entrypoints.openai.api_server \
    --model google/gemma-3n-E2B-it --disable-log-requests -tp 1 --port 8000 \
    --no-enable-prefix-caching --max-num-seqs 128 --max-model-len=32768 \
    --max-num-batched-tokens=8192 --kv-sharing-fast-prefill

Performed a sweep over max-concurrency and random-input-len, with:
num_reqs = 256
max-num-batched-tokens = 8192
max-num-seqs = 128

python benchmarks/benchmark_serving.py \
    --backend vllm \
    --ignore-eos \
    --port 8000 \
    --model google/gemma-3n-E2B-it \
    --dataset-name random \
    --max-concurrency 8 \
    --request-rate inf \
    --num-prompts $num_reqs \
    --random-input-len 8192 \
    --random-output-len 150

Test Result

Evals

Evals are on par with the baseline:

| Run | gsm_8k (5-shot, strict-match) | mmlu_pro (5-shot, custom_extract) | mmlu (0-shot, acc) |
|---|---|---|---|
| This PR (fast prefill) | 0.5466 | 0.3444 | 0.5558 (fast prefill disabled as prompt_logprobs=1) |
| This PR (full prefill) | 0.5413 | 0.3439 | 0.5560 |
| Base | 0.5474 | 0.3426 | 0.5560 |

Unit tests

All unit tests pass.

Performance

Mean TTFT and TPOT (ms):

[Figure: mean TTFT and TPOT (ms) across the benchmark sweep]

sarckk added 2 commits August 10, 2025 17:20
Signed-off-by: Yong Hoon Shin <[email protected]>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant performance optimization for Gemma3n models by enabling a fast prefill path, inspired by the YOCO paper. The implementation is well-thought-out, involving a refactoring of the Gemma3n model into self-decoder and cross-decoder modules to work with torch.compile and dynamic batch sizes. The changes to attention metadata and the use of a conditional compilation decorator are clean solutions. Overall, the changes are robust and the associated refactoring of KV cache sharing logic improves the codebase. However, there is a critical gap in testing that should be addressed.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@facebook-github-bot

@sarckk has imported this pull request. If you are a Meta employee, you can view this in D80013417.

Labels: speculative-decoding, tpu (Related to Google TPUs), v1