
Commit b6ee234

Address comments
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent c447e1d commit b6ee234


4 files changed, +51 -82 lines changed


vllm/model_executor/models/gemma3n.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -47,8 +47,7 @@
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.v1.attention.backends.utils import (
-    KVSharingFastPrefillAttentionMetadata)
+from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
 
 from .interfaces import SupportsQuant
 from .utils import (AutoWeightsLoader, extract_layer_index,
@@ -866,8 +865,7 @@ def fast_prefill_forward(
         # Last layer is a KV sharing layer
         layer_attn_metadata = attn_metadata[
             self.layers[-1].self_attn.attn.layer_name]
-        if (isinstance(layer_attn_metadata,
-                       KVSharingFastPrefillAttentionMetadata)):
+        if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
             logits_indices_padded = (
                 layer_attn_metadata.logits_indices_padded)
             num_logits_indices = layer_attn_metadata.num_logits_indices
```
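The import change and the isinstance check go together: gemma3n.py no longer references a concrete backend metadata subclass, only the `@runtime_checkable` `KVSharingFastPrefillMetadata` Protocol defined in `vllm/v1/attention/backends/utils.py` below. A minimal standalone sketch of why that check works (the class name `SomeBackendMetadata` is made up for illustration, not vLLM code):

```python
from typing import Protocol, runtime_checkable

import torch


@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    logits_indices_padded: torch.Tensor
    num_logits_indices: int


class SomeBackendMetadata:
    """Toy stand-in for a backend-specific attention metadata object."""

    def __init__(self) -> None:
        self.logits_indices_padded = torch.tensor([0, 3, 7])
        self.num_logits_indices = 3


meta = SomeBackendMetadata()
# isinstance() against a runtime_checkable Protocol only checks that the
# declared attributes are present on the instance; no inheritance from the
# Protocol is required, so any backend's metadata class can satisfy it.
if isinstance(meta, KVSharingFastPrefillMetadata):
    print(meta.num_logits_indices)  # -> 3
```

Because the match is structural, any metadata object carrying the two attributes passes, which is exactly what the wrapper class added to `build()` below guarantees.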

vllm/v1/attention/backends/utils.py

Lines changed: 27 additions & 55 deletions
```diff
@@ -4,7 +4,6 @@
 import enum
 import functools
 from abc import abstractmethod
-from collections.abc import Hashable
 from dataclasses import dataclass, fields, make_dataclass
 from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
                     TypeVar)
```
```diff
@@ -67,11 +66,12 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
+    causal: bool = True
+
+    # Needed by FastPrefillAttentionBuilder
     logits_indices_padded: Optional[torch.Tensor] = None
     num_logits_indices: Optional[int] = None
 
-    causal: bool = True
-
 
 @dataclass
 class UbatchSlice:
```
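The field reorder above is behavior-preserving for keyword access because every field involved has a default. A toy reminder of the dataclass rule this relies on (illustrative only, not the real class):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Example:
    # Fields without defaults must precede defaulted ones...
    block_table_tensor: torch.Tensor
    # ...but defaulted fields may be reordered freely among themselves, so
    # moving `causal` above the two Optional fields changes nothing for
    # callers that pass these by keyword.
    causal: bool = True
    logits_indices_padded: Optional[torch.Tensor] = None
    num_logits_indices: Optional[int] = None
```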
```diff
@@ -557,9 +557,8 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
         # Skip computing fast prefill path
         return common_attn_metadata
 
-    if (common_attn_metadata.logits_indices_padded is None
-            or common_attn_metadata.num_logits_indices is None):
-        return common_attn_metadata
+    assert common_attn_metadata.logits_indices_padded is not None
+    assert common_attn_metadata.num_logits_indices is not None
 
     logits_indices_padded = common_attn_metadata.logits_indices_padded
     num_logits_indices = common_attn_metadata.num_logits_indices
```
```diff
@@ -750,59 +749,12 @@ def subclass_attention_metadata(
     return Wrapped
 
 
-@functools.lru_cache
-def make_kv_sharing_fast_prefill_attention_metadata(
-        metadata_cls: Hashable, ) -> Any:
-    """
-    Return a new subclass of `metadata_cls` for fast prefill
-    """
-    attn_metadata_dataclass = subclass_attention_metadata(
-        name_prefix="KVSharingFastPrefill",
-        metadata_cls=metadata_cls,
-        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
-    )
-    # Make attention metadata type inherit
-    # KVSharingFastPrefillAttentionMetadata type
-    fast_prefill_metadata_type = type(
-        attn_metadata_dataclass.__name__,
-        (
-            attn_metadata_dataclass,
-            KVSharingFastPrefillAttentionMetadata,
-        ),
-        {},
-    )
-    return fast_prefill_metadata_type
-
-
 @runtime_checkable
-class KVSharingFastPrefillAttentionMetadata(Protocol):
+class KVSharingFastPrefillMetadata(Protocol):
     logits_indices_padded: torch.Tensor
     num_logits_indices: int
 
 
-def create_kv_sharing_fast_prefill_attn_metadata_subclass(
-    metadata: Any,
-    common_attn_metadata: CommonAttentionMetadata,
-) -> Any:
-    # Dynamically create a a dataclass type that inherits
-    # from attention metadata type but includes additional
-    # fields logits_indices_padded and num_logits_indices
-    # which are required for prefill truncation
-    fast_prefill_metadata_type = (
-        make_kv_sharing_fast_prefill_attention_metadata(
-            metadata_cls=type(metadata), ))  # type: ignore
-    # Avoid deepcopy caused by dict.asdict
-    attn_metadata_fields = {}
-    for field in fields(metadata.__class__):
-        attn_metadata_fields[field.name] = getattr(metadata, field.name)
-    attn_metadata_i = fast_prefill_metadata_type(
-        **attn_metadata_fields,
-        logits_indices_padded=common_attn_metadata.logits_indices_padded,
-        num_logits_indices=common_attn_metadata.num_logits_indices,
-    )
-    return attn_metadata_i
-
-
 def create_fast_prefill_custom_backend(
     prefix: str,
     underlying_attn_backend: AttentionBackend,
```
```diff
@@ -820,7 +772,27 @@ def build(self,
                 make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
             metadata = super().build(common_prefix_len,
                                      new_common_attn_metadata, fast_build)
-            return create_kv_sharing_fast_prefill_attn_metadata_subclass(
+
+            class KVSharingFastPrefillAttentionMetadata(
+                    metadata.__class__, KVSharingFastPrefillMetadata):
+
+                def __init__(self, metadata, common_attention_metadata):
+                    # Shallow copy all fields in metadata cls
+                    for field in fields(metadata.__class__):
+                        setattr(self, field.name,
+                                getattr(metadata, field.name))
+
+                    # Set additional fields that will be used in model code
+                    assert (common_attn_metadata.logits_indices_padded
+                            is not None
+                            and common_attn_metadata.num_logits_indices
+                            is not None)
+                    self.logits_indices_padded = \
+                        common_attn_metadata.logits_indices_padded
+                    self.num_logits_indices = \
+                        common_attn_metadata.num_logits_indices
+
+            return KVSharingFastPrefillAttentionMetadata(
                 metadata, common_attn_metadata)
 
     attn_backend = subclass_attention_backend(
```
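This replaces the deleted `lru_cache`/`type()` machinery with a wrapper class defined inline in `build()`: it subclasses the concrete metadata instance's own class, shallow-copies its dataclass fields, and attaches the two fast-prefill attributes. The pattern in isolation (toy dataclass and hypothetical helper name `wrap_for_fast_prefill`, not vLLM code):

```python
from dataclasses import dataclass, fields


@dataclass
class BackendMetadata:
    """Toy stand-in for a backend attention metadata dataclass."""
    num_actual_tokens: int
    max_query_len: int


def wrap_for_fast_prefill(metadata, logits_indices_padded,
                          num_logits_indices):
    # Subclass the instance's own class so isinstance() checks against the
    # backend metadata type keep passing on the wrapped object.
    class FastPrefillMetadata(metadata.__class__):

        def __init__(self):
            # Shallow copy: field values (e.g. tensors) are shared with the
            # original object, not cloned.
            for field in fields(metadata.__class__):
                setattr(self, field.name, getattr(metadata, field.name))
            # Attach the extra attributes the model code reads.
            self.logits_indices_padded = logits_indices_padded
            self.num_logits_indices = num_logits_indices

    return FastPrefillMetadata()


wrapped = wrap_for_fast_prefill(BackendMetadata(8, 4), [0, 3, 7], 3)
print(isinstance(wrapped, BackendMetadata))  # True
print(wrapped.num_logits_indices)            # 3
```

One trade-off visible in the diff: the old `@functools.lru_cache` helper reused the generated type per metadata class, while this version creates a fresh class object on every `build()` call in exchange for much simpler code.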

vllm/v1/engine/async_llm.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -338,7 +338,9 @@ async def generate(
         if (self.vllm_config.cache_config.kv_sharing_fast_prefill
                 and sampling_params.prompt_logprobs):
             raise ValueError(
-                "Fast prefill produces incorrect logprobs for prompt tokens")
+                "--kv-sharing-fast-prefill produces incorrect logprobs for "
+                "prompt tokens, please disable it when the requests need "
+                "prompt logprobs")
 
         try:
             # We start the output_handler on the first call to generate() so
```
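A simplified, standalone illustration of the guard above (hypothetical function name, not the actual AsyncLLM code):

```python
from typing import Optional


def check_prompt_logprobs_allowed(kv_sharing_fast_prefill: bool,
                                  prompt_logprobs: Optional[int]) -> None:
    # Mirrors the generate() check: reject early rather than return
    # silently incorrect prompt logprobs.
    if kv_sharing_fast_prefill and prompt_logprobs:
        raise ValueError(
            "--kv-sharing-fast-prefill produces incorrect logprobs for "
            "prompt tokens, please disable it when the requests need "
            "prompt logprobs")


check_prompt_logprobs_allowed(True, None)  # OK: no prompt logprobs requested
# check_prompt_logprobs_allowed(True, 5)   # raises ValueError
```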

vllm/v1/worker/gpu_model_runner.py

Lines changed: 19 additions & 22 deletions
```diff
@@ -1456,6 +1456,12 @@ def execute_model(
             return self.kv_connector_no_forward(scheduler_output,
                                                 self.vllm_config)
 
+        if self.cache_config.kv_sharing_fast_prefill:
+            assert not self.input_batch.num_prompt_logprobs, (
+                "--kv-sharing-fast-prefill produces incorrect logprobs for "
+                "prompt tokens, please disable it when the requests need "
+                "prompt logprobs")
+
         # Prepare the decoder inputs.
         (attn_metadata, logits_indices, spec_decode_metadata,
          num_scheduled_tokens_np, spec_decode_common_attn_metadata,
```
```diff
@@ -3084,6 +3090,19 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups(
             self.runner_only_attn_layers,
         )
 
+        if self.cache_config.kv_sharing_fast_prefill:
+            # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
+            # similar KV sharing setups, only the layers that generate KV caches
+            # are involved in the prefill phase, enabling prefill to early exit.
+            attn_layers = get_layers_from_vllm_config(self.vllm_config,
+                                                      Attention)
+            for layer_name in reversed(attn_layers):
+                if layer_name in self.shared_kv_cache_layers:
+                    self.kv_sharing_fast_prefill_eligible_layers.add(
+                        layer_name)
+                else:
+                    break
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
```
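Only a trailing run of KV-sharing layers is marked eligible: prefill can exit early only if every layer after the exit point re-uses an existing KV cache. A toy walk-through of the reversed iteration (made-up layer names, not vLLM code):

```python
# Layers in model order; dicts preserve insertion order, and reversed()
# iterates a dict's keys last-to-first (Python 3.8+).
attn_layers = {
    "layers.0.attn": "writes KV",
    "layers.1.attn": "writes KV",
    "layers.2.attn": "shares KV",  # re-uses an earlier layer's cache
    "layers.3.attn": "shares KV",
}
shared_kv_cache_layers = {"layers.2.attn", "layers.3.attn"}

eligible = set()
for layer_name in reversed(attn_layers):
    if layer_name in shared_kv_cache_layers:
        eligible.add(layer_name)
    else:
        break  # stop at the last layer that still writes its own KV cache

print(sorted(eligible))  # ['layers.2.attn', 'layers.3.attn']
```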
```diff
@@ -3092,8 +3111,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         cache size of each layer
         """
         kv_cache_config = deepcopy(kv_cache_config)
-        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
-        self.maybe_add_kv_sharing_fast_prefill_layers(attn_layers)
         self.kv_cache_config = kv_cache_config
         self.may_reinitialize_input_batch(kv_cache_config)
         self.may_add_encoder_only_layers_to_kv_cache_config()
```
```diff
@@ -3137,26 +3154,6 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         self.kv_cache_config.kv_cache_groups.append(
             KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
 
-    def maybe_add_kv_sharing_fast_prefill_layers(self,
-                                                 attn_layers: dict[str,
-                                                                   Attention]):
-        """
-        In You Only Cache Once (https://arxiv.org/abs/2405.05254), or other
-        similar KV sharing setups, the layers that re-use the shared KV cache
-        (cross-decoder layers) can skip prefill, as only the earlier layers
-        that generate KV caches are involved in the prefill phase.
-        """
-        if not self.cache_config.kv_sharing_fast_prefill:
-            # Optimization disabled, return
-            return
-
-        # Iterate in reversed order and add layers that re-use KV cache
-        for layer_name in reversed(attn_layers):
-            if layer_name in self.shared_kv_cache_layers:
-                self.kv_sharing_fast_prefill_eligible_layers.add(layer_name)
-            else:
-                break
-
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
```
