
Commit ad51030

Override attention metadata for fast prefill in some KV sharing setups (#21590)
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 366f6b3 commit ad51030

File tree: 6 files changed, +287 −26 lines changed

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
import random
from typing import Optional, Union

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors

from ...utils import fork_new_process_for_each_test


class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, **kwargs)
        attn_metadata = get_forward_context().attn_metadata
        # attn_metadata is None during dummy runs
        if (attn_metadata is not None
                and self.cache_config.kv_sharing_fast_prefill):
            assert isinstance(attn_metadata, dict)  # true in V1
            # Gemma3n-E2B has 30 layers, with the last 20 layers being
            # cross-decoder layers. Check that attention metadata is correct.
            for layer_name, metadata in attn_metadata.items():
                layer_idx = extract_layer_index(layer_name)
                if layer_idx >= 20:
                    assert hasattr(metadata, 'logits_indices_padded')
                    assert hasattr(metadata, 'num_logits_indices')
                else:
                    assert not hasattr(metadata, 'logits_indices_padded')
                    assert not hasattr(metadata, 'num_logits_indices')

            # The last layer will be a KV sharing layer
            layer_attn_metadata = attn_metadata[
                self.model.language_model.layers[-1].self_attn.attn.layer_name]
            logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
            assert logits_indices_padded is not None
            num_logits_indices = layer_attn_metadata.num_logits_indices
            assert num_logits_indices > 0

            # Reset hidden states to random values and
            # only set logits at logits_indices to valid values.
            # Because logits_indices are the only positions that are used
            # for output token sampling, this still produces the same outputs.
            logits_hs = hidden_states[logits_indices_padded]
            hidden_states = torch.randn_like(hidden_states)
            gen_indices = logits_indices_padded[:num_logits_indices]
            hidden_states[gen_indices] = logits_hs[:num_logits_indices]

        return hidden_states


@pytest.fixture
def test_prompts():
    """
    Adapted from tests/v1/e2e/test_spec_decode.py
    """
    prompt_types = ["repeat", "sentence"]
    # Setting a higher number of prompts increases the chance of numerics
    # mismatch, since matrix multiplication numerics depend on batch dimension
    num_prompts = 10
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""please repeat the word '{word}' 10 times."""
        elif kind == "sentence":
            prompt = f"""please give a ten-word sentence that
            uses the word {word} at least once."""
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append(prompt)

    return prompts


@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
    monkeypatch: pytest.MonkeyPatch,
    enforce_eager: bool,
    test_prompts: list[str],
):
    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
                                 TestGemma3nForConditionalGeneration)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    compilation_config = CompilationConfig(
        # This allows the vLLM compilation backend to handle allocating and
        # managing buffers for cudagraph
        cudagraph_copy_inputs=True,
        level=CompilationLevel.PIECEWISE
        if not enforce_eager else CompilationLevel.NO_COMPILATION)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            model="google/gemma-3n-E2B-it",
            enforce_eager=enforce_eager,
            compilation_config=compilation_config,
        )
        ref_responses = llm.generate(test_prompts, sampling_params)

        del llm
        gc.collect()
        torch.cuda.empty_cache()

        llm = LLM(model="google/gemma-3n-E2B-it",
                  enforce_eager=enforce_eager,
                  compilation_config=compilation_config,
                  kv_sharing_fast_prefill=True)
        optimized_responses = llm.generate(test_prompts, sampling_params)

        misses = 0

        for ref_response, optimized_response in zip(ref_responses,
                                                    optimized_responses):
            if ref_response.outputs[0].text != optimized_response.outputs[
                    0].text:
                misses += 1

        assert misses == 0
```

vllm/config.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -1795,6 +1795,16 @@ class CacheConfig:
     num_cpu_blocks: Optional[int] = field(default=None, init=False)
     """The number of blocks to allocate for CPU memory."""
 
+    kv_sharing_fast_prefill: bool = False
+    """This feature is work in progress; currently, no prefill optimization
+    takes place when this flag is enabled.
+
+    In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
+    some layers can skip tokens corresponding to prefill. This flag enables
+    attention metadata for eligible layers to be overridden with the metadata
+    necessary for implementing this optimization in some models (e.g. Gemma3n).
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -1836,6 +1846,11 @@ def _verify_args(self) -> Self:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
 
+        if self.kv_sharing_fast_prefill:
+            logger.warning_once(
+                "--kv-sharing-fast-prefill is currently work in progress "
+                "and not functional yet (i.e. no prefill savings)")
+
         return self
 
     def _verify_cache_dtype(self) -> None:
```
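To make the intent of the flag concrete, here is a minimal sketch (not part of this commit; all tensor values are made up) of what the two new metadata fields describe: a cross-decoder layer only needs valid outputs at the positions that feed the sampler, so the remaining prefill positions could in principle be skipped.

```python
import torch

# Hypothetical values standing in for what a cross-decoder layer would see;
# `logits_indices_padded` and `num_logits_indices` mirror the fields added
# in this commit.
num_tokens, hidden_size = 16, 8
hidden_states = torch.randn(num_tokens, hidden_size)  # layer output for a batch

logits_indices_padded = torch.tensor([3, 9, 15, 15])  # padded (e.g. for cudagraphs)
num_logits_indices = 3                                 # only the first 3 are real

# Only these rows are consumed for output-token sampling; every other prefill
# position's output is never read, which is the work a fast-prefill layer
# could avoid.
gen_indices = logits_indices_padded[:num_logits_indices]
sampler_input = hidden_states[gen_indices]
print(sampler_input.shape)  # torch.Size([3, 8])
```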

vllm/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -445,6 +445,9 @@ class EngineArgs:
     # DEPRECATED
     enable_prompt_adapter: bool = False
 
+    kv_sharing_fast_prefill: bool = \
+        CacheConfig.kv_sharing_fast_prefill
+
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
@@ -697,6 +700,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **cache_kwargs["cpu_offload_gb"])
         cache_group.add_argument("--calculate-kv-scales",
                                  **cache_kwargs["calculate_kv_scales"])
+        cache_group.add_argument("--kv-sharing-fast-prefill",
+                                 **cache_kwargs["kv_sharing_fast_prefill"])
 
         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1069,6 +1074,7 @@ def create_engine_config(
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
+            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
         )
 
         # Get the current placement group if Ray is initialized and
```
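Based on the engine-arg wiring above and the end-to-end test added in this commit, the flag can be enabled from Python or on the command line; a brief sketch (the CLI line is illustrative):

```python
from vllm import LLM

# Forwarded through EngineArgs.kv_sharing_fast_prefill into
# CacheConfig.kv_sharing_fast_prefill (see create_engine_config above).
llm = LLM(
    model="google/gemma-3n-E2B-it",
    kv_sharing_fast_prefill=True,
)

# Illustrative CLI equivalent, using the argument registered above:
#   vllm serve google/gemma-3n-E2B-it --kv-sharing-fast-prefill
```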

vllm/model_executor/models/gemma3n.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -793,6 +793,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         del lora_config  # Unused.
         super().__init__()
         self.config = config
+        self.cache_config = vllm_config.cache_config
         self.model = Gemma3nModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
         self.logits_processor = LogitsProcessor(
```

vllm/v1/attention/backends/utils.py

Lines changed: 33 additions & 2 deletions
```diff
@@ -3,8 +3,8 @@
 import abc
 import functools
 from abc import abstractmethod
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
+from dataclasses import dataclass, make_dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
 
 import numpy as np
 import torch
@@ -508,3 +508,34 @@ def reorder_batch_to_split_decodes_and_prefills(
             modified_batch = True
 
     return modified_batch
+
+
+KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
+    ('logits_indices_padded', Optional[torch.Tensor], None),
+    ('num_logits_indices', int, 0),
+]
+
+
+def subclass_attention_metadata(
+    name_prefix: str,
+    metadata_cls: Any,
+    fields: list[tuple[str, Any, Any]],
+) -> Any:
+    """
+    Return a new subclass of `metadata_cls` with additional fields
+    """
+    name: str = name_prefix + metadata_cls.__name__  # type: ignore
+    Wrapped = make_dataclass(name, fields, bases=(metadata_cls, ))
+    return Wrapped
+
+
+def make_kv_sharing_fast_prefill_attention_metadata(
+    metadata_cls: Any, ) -> Any:
+    """
+    Return a new subclass of `metadata_cls` for fast prefill
+    """
+    return subclass_attention_metadata(
+        name_prefix="KVSharingFastPrefill",
+        metadata_cls=metadata_cls,
+        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
+    )
```
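For illustration, a small sketch of how the new helper composes with a backend metadata class; `DummyAttentionMetadata` is a made-up stand-in here (real callers would pass an actual backend metadata dataclass):

```python
from dataclasses import dataclass

import torch

from vllm.v1.attention.backends.utils import (
    make_kv_sharing_fast_prefill_attention_metadata)


@dataclass
class DummyAttentionMetadata:  # hypothetical stand-in for a backend's metadata
    num_actual_tokens: int


# Build a dataclass subclass that carries the two extra fast-prefill fields.
FastPrefillMetadata = make_kv_sharing_fast_prefill_attention_metadata(
    metadata_cls=DummyAttentionMetadata)

meta = FastPrefillMetadata(
    num_actual_tokens=32,
    logits_indices_padded=torch.tensor([7, 15, 31, 31]),
    num_logits_indices=3,
)
assert isinstance(meta, DummyAttentionMetadata)
print(type(meta).__name__)  # KVSharingFastPrefillDummyAttentionMetadata
```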
