[V1] Enable prefill optimization for Gemma3n #22628
Purpose
This PR adds an option to enable prefill optimization for the Gemma3n model with `--kv-sharing-fast-prefill`.

Background
In You Only Cache Once (https://arxiv.org/abs/2405.05254), self-decoder layers generate KV caches while cross-decoder layers use cross-attention and reuse the shared KV cache. As only self-decoder layers generate KV caches, cross-decoder layers don't need to do prefill. Below is a figure from the YOCO paper on the prefill optimization:
Design
In vLLM V1, the scheduler does not distinguish between prefill and decode. Instead, tokens for requests doing prefill and decode are batched together, as illustrated below (source: vLLM blog):
When we skip the prefill tokens in the cross-decoder layers, the batch size is therefore reduced during the model forward pass for those layers:
Without optimization enabled (baseline)
With optimization enabled (`--kv-sharing-fast-prefill`)
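To make the effect concrete, here is a minimal standalone sketch (the tensor sizes and the `logits_indices` values are made up for illustration; this is not the PR's code) of how selecting only the tokens that still need computation shrinks the batch seen by the cross-decoder layers:

```python
import torch

# Hypothetical flattened batch: two prefill requests (5 and 3 prompt tokens)
# followed by two decode requests (1 token each) -> 10 tokens total.
hidden_states = torch.randn(10, 2048)  # [num_tokens, hidden_size]

# Only the last prefill token of each prefilling request plus every decode
# token needs to go through the cross-decoder layers.
logits_indices = torch.tensor([4, 7, 8, 9])

# Cross-decoder layers see a reduced batch of 4 tokens instead of 10.
cross_decoder_input = hidden_states[logits_indices]
print(cross_decoder_input.shape)  # torch.Size([4, 2048])
```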
With this change, we can no longer compile the top-level model for 2 reasons:
Solution: we split the layers into self-decoder and cross-decoder layers, and compile + CUDA graph capture them separately. For Gemma3n-E2B, which has 30 layers, the first 20 layers and the remaining 10 layers are grouped into independently compiled and CUDA graph captured modules.
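Roughly, the split looks like the sketch below. The wrapper class and function names are hypothetical and use plain `torch.compile`; the PR relies on vLLM's own compilation and CUDA graph capture machinery instead.

```python
import torch
import torch.nn as nn

class LayerGroup(nn.Module):
    """Hypothetical wrapper so each group of layers can be compiled on its own."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states

def split_and_compile(all_layers: nn.ModuleList, num_self_decoder_layers: int = 20):
    # The first 20 layers generate KV caches (self-decoder); the remaining 10
    # reuse them (cross-decoder). Compiling the two groups separately lets each
    # run with its own batch size.
    self_decoder = LayerGroup(nn.ModuleList(all_layers[:num_self_decoder_layers]))
    cross_decoder = LayerGroup(nn.ModuleList(all_layers[num_self_decoder_layers:]))
    return torch.compile(self_decoder), torch.compile(cross_decoder)
```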
Other changes required in this PR:
- Add `make_kv_sharing_fast_prefill_common_attn_metadata` to create an attention metadata that excludes all prefill tokens. This requires passing `logits_indices` to `CommonAttentionMetadata`.
- Add `KVSharingFastPrefillAttentionMetadata`. This carries two additional metadata fields (`logits_indices_padded` and `num_logits_indices`) which are required for indexing into hidden states in the model implementation to match the shapes that the new attention metadata expects.
- Change the `hidden_states` shape from `[altup_num_inputs, num_tokens, hidden_size]` to `[num_tokens, hidden_size, altup_num_inputs]` to ensure `num_tokens` (batch size) comes at dim 0. We cannot have `num_tokens` on dim=1 because creating a slice along dim=1 would a) cause torch.compile tensor stride assertions to fail, and b) resolving this by calling `contiguous()` on the slice would cause a memory copy and therefore violate the CUDA graph static address constraint (see the contiguity sketch after this list).
- If the `--kv-sharing-fast-prefill` flag is passed, we take a different `self.fast_prefill_forward()` path which uses the `logits_indices_padded` metadata to index into the subset of tokens for the cross-decoder layers (i.e. the batch size is reduced). We then merge the result back into the output of the self-decoder to get the final output (see the fast-prefill sketch after this list).
- If the `--kv-sharing-fast-prefill` flag is passed, we compile the self-decoder and cross-decoder submodules separately, and we also need to pre-allocate static buffers for CUDA graph replay. If it is not passed (default), we still compile the top-level `Gemma3TextModel`.
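On the dim-0 requirement for `num_tokens`: this small demo (shapes made up) shows why slicing along dim=1 is problematic, since it yields a non-contiguous view, whereas after moving `num_tokens` to dim 0 the same slice stays contiguous:

```python
import torch

x = torch.randn(2, 16, 8)            # [altup_num_inputs, num_tokens, hidden_size]
view = x[:, :4, :]                   # slice along dim=1 (num_tokens)
print(view.is_contiguous())          # False -> strides differ from a fresh tensor

y = x.permute(1, 2, 0).contiguous()  # [num_tokens, hidden_size, altup_num_inputs]
view2 = y[:4]                        # slice along dim=0 (num_tokens)
print(view2.is_contiguous())         # True -> no contiguous() copy needed
```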
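And a rough fast-prefill sketch. The function and argument names here are simplified placeholders; the actual `fast_prefill_forward()` in the PR also threads through attention metadata, AltUp inputs, and the pre-allocated CUDA graph buffers.

```python
import torch

def fast_prefill_forward_sketch(
    hidden_states: torch.Tensor,          # [num_tokens, hidden_size]
    logits_indices_padded: torch.Tensor,  # padded to a static length for CUDA graphs
    num_logits_indices: int,              # number of valid (non-padding) indices
    self_decoder,                         # callable over the full batch
    cross_decoder,                        # callable over the reduced batch
) -> torch.Tensor:
    # Self-decoder layers run over the full batch and populate the KV caches.
    self_out = self_decoder(hidden_states)

    # Cross-decoder layers only see the tokens whose outputs are needed
    # (last prefill token per request + all decode tokens).
    reduced = self_out[logits_indices_padded]
    cross_out = cross_decoder(reduced)

    # Merge the cross-decoder results back into the self-decoder output so the
    # final hidden states keep the original [num_tokens, hidden_size] shape.
    merged = self_out.clone()
    valid = logits_indices_padded[:num_logits_indices]
    merged[valid] = cross_out[:num_logits_indices]
    return merged
```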
Compared to trunk, the only difference is that there are extra groups in `attn_groups[0]` and `attn_groups[4]` for layers which need a separate attention metadata builder for the fast prefill path. Previously it looked like this:
Important
When `prompt_logprobs` is enabled, we can no longer use the fast prefill optimization: by skipping all but the last prefill tokens, the logits for the prompt tokens would no longer be valid. For example, multiple choice question (MCQ) evals use `prompt_logprobs` to get the logprobs of continuation tokens (e.g. lm-evaluation-harness), so using `--kv-sharing-fast-prefill` would yield inaccurate results. To prevent this, we issue a warning and disable the fast prefill optimization whenever a scheduling round contains at least one request whose sampling params set `prompt_logprobs` to a value other than `None`.
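A minimal sketch of that guard, assuming a hypothetical helper name and a per-batch list of sampling params (the real check lives in the scheduler / model runner path):

```python
def batch_needs_prompt_logprobs(sampling_params_per_request) -> bool:
    # Fast prefill must be disabled for this batch if any request asks for
    # prompt logprobs, since skipped prefill tokens would otherwise produce
    # invalid logits for the prompt positions.
    return any(
        params.prompt_logprobs is not None
        for params in sampling_params_per_request
    )

# Usage sketch:
# if batch_needs_prompt_logprobs(batch_sampling_params):
#     logger.warning("prompt_logprobs requested; falling back to normal prefill")
#     use_fast_prefill = False
```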
Follow ups
- … (when `--kv-sharing-fast-prefill` is set)
- `self_decoder_hidden_state.clone()` in Gemma3n

Test Plan
Evals

Ran gsm8k, mmlu, mmlu pro.

Unit tests

Performance

Perform sweep over `max-concurrency` and `random-input-len`, with `$num_reqs = 256`, `max-num-batched-tokens = 8192`, `max-num-seqs = 128`.
Test Result
Evals
Evals on par
Unit tests
Unit tests all pass
Performance
Mean TTFT and TPOT (ms):