Commit 9ef64fd

Fix rebase conflicts
Signed-off-by: Yong Hoon Shin <[email protected]>
1 parent 2a54824 commit 9ef64fd

4 files changed: +25 -73 lines

vllm/attention/layers/chunked_local_attention.py

Lines changed: 1 addition & 2 deletions
@@ -11,8 +11,7 @@
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig, QuantizationConfig
 from vllm.v1.attention.backends.utils import (
-    CommonAttentionMetadata,
-    make_local_attention_virtual_batches,
+    CommonAttentionMetadata, make_local_attention_virtual_batches,
     subclass_attention_backend)
 
 from ..layer import Attention

vllm/v1/attention/backends/utils.py

Lines changed: 16 additions & 45 deletions
@@ -6,8 +6,8 @@
 from abc import abstractmethod
 from collections.abc import Hashable
 from dataclasses import dataclass, fields, make_dataclass
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
-                    Protocol, TypeVar)
+from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
+                    TypeVar)
 
 import numpy as np
 import torch

@@ -613,29 +613,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
     return common_attn_metadata
 
 
-def subclass_attention_metadata_builder(
-        name_prefix: str,
-        builder_cls: type[AttentionMetadataBuilder[M]],
-        build: Callable[
-            [AttentionMetadataBuilder[M], int, CommonAttentionMetadata, bool],
-            AttentionMetadata,
-        ],
-) -> type[AttentionMetadataBuilder[M]]:
-    """
-    Return a new subclass of `builder_cls` whose .build(...) method
-    is monkey patched to a custom build function.
-    """
-    name: str = name_prefix + builder_cls.__name__  # type: ignore
-
-    Wrapped = type(
-        name,
-        (builder_cls, ),  # inherit from the original
-        {
-            "build": build,
-        })
-    return Wrapped  # type: ignore
-
-
 def subclass_attention_backend(
         name_prefix: str, attention_backend_cls: type[AttentionBackend],
         builder_cls: type[AttentionMetadataBuilder[M]]

@@ -826,35 +803,29 @@ def create_kv_sharing_fast_prefill_attn_metadata_subclass(
     return attn_metadata_i
 
 
-@functools.lru_cache
 def create_fast_prefill_custom_backend(
     prefix: str,
     underlying_attn_backend: AttentionBackend,
 ) -> type[AttentionBackend]:
 
-    def build(self,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              fast_build: bool = False) -> AttentionMetadata:
-        new_common_attn_metadata =\
-            make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
-        metadata = super(self.__class__,
-                         self).build(common_prefix_len,
+    underlying_builder = underlying_attn_backend.get_builder_cls()
+
+    class FastPrefillAttentionBuilder(underlying_builder):  # type: ignore
+
+        def build(self,
+                  common_prefix_len: int,
+                  common_attn_metadata: CommonAttentionMetadata,
+                  fast_build: bool = False) -> AttentionMetadata:
+            new_common_attn_metadata =\
+                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
+            metadata = super().build(common_prefix_len,
                                      new_common_attn_metadata, fast_build)
-        return create_kv_sharing_fast_prefill_attn_metadata_subclass(
-            metadata, common_attn_metadata)
+            return create_kv_sharing_fast_prefill_attn_metadata_subclass(
+                metadata, common_attn_metadata)
 
-    # Dynamically create a new attention backend that wraps the
-    # underlying attention backend but applies
-    # `build_preproces_fn` before calling `build(...)`
-    builder_cls = subclass_attention_metadata_builder(
-        name_prefix=prefix,
-        builder_cls=underlying_attn_backend.get_builder_cls(),
-        build=build,
-    )
     attn_backend = subclass_attention_backend(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
-        builder_cls=builder_cls)
+        builder_cls=FastPrefillAttentionBuilder)
 
     return attn_backend

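Note on the change above: the dynamically generated builder (built with type() and a monkey-patched build, which is why the old code needed super(self.__class__, self)) is replaced by an ordinary subclass defined inside create_fast_prefill_custom_backend, where a plain super() call resolves normally. A minimal standalone sketch of the two patterns, using placeholder names (BaseBuilder, FastPrefillBuilder) rather than the real vLLM classes:

# Minimal sketch (placeholder classes, not vLLM APIs) contrasting the old
# type()-based wrapper with the new plain subclass used in the diff above.
class BaseBuilder:

    def build(self, common_prefix_len: int) -> str:
        return f"base(prefix={common_prefix_len})"


# Old pattern: create the subclass dynamically and monkey-patch `build`.
# Zero-argument super() is unavailable in a free function, hence
# super(type(self), self).
def _patched_build(self, common_prefix_len: int) -> str:
    return "fast_prefill+" + super(type(self), self).build(common_prefix_len)


WrappedBuilder = type("FastPrefillBaseBuilder", (BaseBuilder, ),
                      {"build": _patched_build})


# New pattern: an ordinary subclass; super() resolves normally and the class
# can be handed to subclass_attention_backend() like any builder class.
class FastPrefillBuilder(BaseBuilder):

    def build(self, common_prefix_len: int) -> str:
        return "fast_prefill+" + super().build(common_prefix_len)


assert WrappedBuilder().build(1) == FastPrefillBuilder().build(1)
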
vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 23 deletions
@@ -2993,9 +2993,7 @@ def _reshape_kv_cache_tensors(
         for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
             attn_backend = group.backend
             for layer_name in group.layer_names:
-                if (
-                    layer_name in self.runner_only_attn_layers
-                ):
+                if layer_name in self.runner_only_attn_layers:
                     continue
                 raw_tensor = kv_cache_raw_tensors[layer_name]
                 assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0

@@ -3110,26 +3108,12 @@ def initialize_kv_cache_tensors(
         kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
                                                    kv_cache_raw_tensors)
 
-        # Setup `kv_cache_config` and `kv_caches` for models
-        # with cross-layer KV sharing
-        if self.shared_kv_cache_layers:
-            initialize_kv_cache_for_kv_sharing(
-                self.shared_kv_cache_layers,
-                kv_cache_config.kv_cache_groups,
-                kv_caches,
-                self.attn_groups,
-                self.runner_only_attn_layers,
-            )
-            attn_layers = get_layers_from_vllm_config(self.vllm_config,
-                                                      Attention)
-            # Iterate in reversed order and add layers that re-use KV cache
-            # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n)
-            for layer_name in reversed(attn_layers):
-                if layer_name in self.shared_kv_cache_layers:
-                    self.kv_sharing_fast_prefill_eligible_layers.add(
-                        layer_name)
-                else:
-                    break
+        # Set up cross-layer KV cache sharing
+        for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+        ):
+            logger.debug("%s reuses KV cache of %s", layer_name,
+                         target_layer_name)
+            kv_caches[layer_name] = kv_caches[target_layer_name]
 
         bind_kv_cache(kv_caches,
                       self.compilation_config.static_forward_context,

@@ -3149,6 +3133,7 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups(
         add_kv_sharing_layers_to_kv_cache_groups(
             self.shared_kv_cache_layers,
             kv_cache_config.kv_cache_groups,
+            self.runner_only_attn_layers,
         )
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:

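Note on the runner change above: the removed initialize_kv_cache_for_kv_sharing call is replaced by a direct loop that aliases each consumer layer's KV cache entry to the tensor of the layer it shares with. A small self-contained sketch of that aliasing (layer names and the tensor shape are made up for illustration):

# Standalone sketch of the cross-layer KV sharing setup: consumer layers
# simply reuse (alias) the cache tensor of the layer they share with.
# Layer names and the tensor shape below are illustrative only.
import torch

kv_caches = {
    "model.layers.0.attn": torch.zeros(2, 16, 8),  # producer owns the storage
}
shared_kv_cache_layers = {
    # consumer layer name -> layer whose KV cache it reuses
    "model.layers.1.attn": "model.layers.0.attn",
}

for layer_name, target_layer_name in shared_kv_cache_layers.items():
    # No extra allocation: both dict entries point at the same tensor.
    kv_caches[layer_name] = kv_caches[target_layer_name]

assert (kv_caches["model.layers.1.attn"].data_ptr() ==
        kv_caches["model.layers.0.attn"].data_ptr())
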
vllm/v1/worker/utils.py

Lines changed: 0 additions & 3 deletions
@@ -201,9 +201,6 @@ def gather_mm_placeholders(
 def add_kv_sharing_layers_to_kv_cache_groups(
     shared_kv_cache_layers: dict[str, str],
     kv_cache_groups: list[KVCacheGroupSpec],
-    kv_caches: dict[str, torch.Tensor],
-    # Optional for now to avoid breaking TPU
-    attn_groups: Optional[list[list[AttentionGroup]]] = None,
     runner_only_attn_layers: Optional[set[str]] = None,
 ) -> None:
     """

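For context on the signature change above: the kv_caches and attn_groups parameters are no longer needed because cache aliasing now happens in gpu_model_runner.py (see the hunk above), leaving the helper responsible only for placing KV-sharing layers into cache groups. The rough sketch below illustrates that idea under simplified assumptions: groups are modeled as plain lists of layer names instead of KVCacheGroupSpec, and the skip rule for runner_only_attn_layers is an assumption, not the real implementation.

# Rough sketch only, not the actual vLLM helper.
from typing import Optional


def add_kv_sharing_layers_to_kv_cache_groups_sketch(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[list[str]],
    runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        if runner_only_attn_layers and layer_name in runner_only_attn_layers:
            # Assumed rule: layers handled entirely inside the runner stay
            # out of the cache groups.
            continue
        for group in kv_cache_groups:
            if target_layer_name in group:
                # The consumer joins the same group as the layer it reuses.
                group.append(layer_name)
                break


groups = [["model.layers.0.attn"]]
add_kv_sharing_layers_to_kv_cache_groups_sketch(
    {"model.layers.1.attn": "model.layers.0.attn"}, groups)
assert groups == [["model.layers.0.attn", "model.layers.1.attn"]]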