
Commit cb293f6

[V1] Enable prefill optimization for Gemma3n (#22628)
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent: 7ffbf27

File tree

9 files changed: +583 −228 lines


tests/v1/e2e/test_kv_sharing_fast_prefill.py

Lines changed: 0 additions & 57 deletions
@@ -2,75 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import random
-from typing import Optional, Union
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.distributed import cleanup_dist_env_and_memory
-from vllm.forward_context import get_forward_context
-from vllm.model_executor.models.gemma3n_mm import (
-    Gemma3nForConditionalGeneration)
-from vllm.model_executor.models.registry import ModelRegistry
-from vllm.model_executor.models.utils import extract_layer_index
-from vllm.sequence import IntermediateTensors
 
 from ...utils import fork_new_process_for_each_test
 
 # global seed
 SEED = 42
 
 
-class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = super().forward(input_ids, positions,
-                                        intermediate_tensors, inputs_embeds,
-                                        **kwargs)
-        attn_metadata = get_forward_context().attn_metadata
-        # attn_metadata is None during dummy runs
-        if (attn_metadata is not None
-                and self.language_model.cache_config.kv_sharing_fast_prefill):
-            assert isinstance(attn_metadata, dict)  # true in V1
-            # Gemma3n-E2B has 30 layers, with last 20 layers being
-            # cross-decoder layers. Check attention metadata is correct
-            for layer_name, metadata in attn_metadata.items():
-                layer_idx = extract_layer_index(layer_name)
-                if layer_idx >= 20:
-                    assert hasattr(metadata, 'logits_indices_padded')
-                    assert hasattr(metadata, 'num_logits_indices')
-                else:
-                    assert not hasattr(metadata, 'logits_indices_padded')
-                    assert not hasattr(metadata, 'num_logits_indices')
-
-            # Last layer will be a KV sharing layer
-            layer_attn_metadata = attn_metadata[
-                self.language_model.model.layers[-1].self_attn.attn.layer_name]
-            logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
-            assert logits_indices_padded is not None
-            num_logits_indices = layer_attn_metadata.num_logits_indices
-            assert num_logits_indices > 0
-            # Reset hidden states to random values and
-            # only set logits at logits_indices to valid values
-            # Because logits_indices are the only positions that are used
-            # for output token sampling, this still produces same outputs
-            logits_hs = hidden_states[logits_indices_padded]
-            hidden_states = torch.randn_like(hidden_states)
-            gen_indices = logits_indices_padded[:num_logits_indices]
-            hidden_states[gen_indices] = logits_hs[:num_logits_indices]
-
-        return hidden_states
-
-
 @pytest.fixture
 def test_prompts():
     """
@@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
     enforce_eager: bool,
     test_prompts: list[str],
 ):
-    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
-                                 TestGemma3nForConditionalGeneration)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     compilation_config = CompilationConfig(
         # This allows vLLM compilation backend to handle allocating and
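
The deleted subclass verified the fast-prefill contract end to end: in the cross-decoder layers, the attention metadata carries logits_indices_padded and num_logits_indices, and only the hidden states at those indices feed token sampling, so every other row can be scrambled without changing outputs. Below is a minimal standalone sketch of that scramble-and-restore check; the shapes and index values are made up for illustration and are not vLLM's API.

# Standalone sketch of the scramble-and-restore check from the deleted test.
# Shapes and index values below are illustrative, not taken from vLLM.
import torch

num_tokens, hidden_size = 16, 8
hidden_states = torch.randn(num_tokens, hidden_size)

# Padded positions whose logits are actually sampled; only the first
# num_logits_indices entries are valid (hypothetical values).
logits_indices_padded = torch.tensor([3, 7, 15, 0])
num_logits_indices = 3

# Save the sampled rows, randomize everything, restore only the valid rows.
logits_hs = hidden_states[logits_indices_padded]
scrambled = torch.randn_like(hidden_states)
gen_indices = logits_indices_padded[:num_logits_indices]
scrambled[gen_indices] = logits_hs[:num_logits_indices]

# Sampled positions are intact; the rest of the batch was never needed.
assert torch.equal(scrambled[gen_indices], hidden_states[gen_indices])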

vllm/config/cache.py

Lines changed: 7 additions & 5 deletions
@@ -145,12 +145,19 @@ def __post_init__(self) -> None:
 
         self._verify_cache_dtype()
         self._verify_prefix_caching()
+        self._verify_kv_sharing_fast_prefill()
 
     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
         # metrics info
         return {key: str(value) for key, value in self.__dict__.items()}
 
+    def _verify_kv_sharing_fast_prefill(self) -> None:
+        if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
+            raise NotImplementedError(
+                "Fast prefill optimization for KV sharing is not supported "
+                "in V0 currently.")
+
     @model_validator(mode='after')
     def _verify_args(self) -> Self:
         if self.cpu_offload_gb < 0:
@@ -162,11 +169,6 @@ def _verify_args(self) -> Self:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
 
-        if self.kv_sharing_fast_prefill:
-            logger.warning_once(
-                "--kv-sharing-fast-prefill is currently work in progress "
-                "and not functional yet (i.e. no prefill savings)")
-
         return self
 
     def _verify_cache_dtype(self) -> None:
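
This change replaces the old warn-and-continue behavior with a fail-fast check at config construction time: requesting kv_sharing_fast_prefill on the V0 engine now raises in __post_init__ instead of silently doing nothing. A minimal sketch of the same validation pattern, using a hypothetical standalone config class (the module-level VLLM_USE_V1 below stands in for envs.VLLM_USE_V1):

# Sketch of a fail-fast constructor check; CacheConfigSketch is a
# hypothetical stand-in, not vLLM's actual CacheConfig.
from dataclasses import dataclass

VLLM_USE_V1 = False  # stand-in for envs.VLLM_USE_V1

@dataclass
class CacheConfigSketch:
    kv_sharing_fast_prefill: bool = False

    def __post_init__(self) -> None:
        self._verify_kv_sharing_fast_prefill()

    def _verify_kv_sharing_fast_prefill(self) -> None:
        if self.kv_sharing_fast_prefill and not VLLM_USE_V1:
            raise NotImplementedError(
                "Fast prefill optimization for KV sharing is not supported "
                "in V0 currently.")

# With V1 disabled, enabling the flag fails at construction:
try:
    CacheConfigSketch(kv_sharing_fast_prefill=True)
except NotImplementedError as e:
    print(e)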
