
Commit 838ceda

[Bugfix] Get a specific type of layer from forward context (#17222)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 4283a28 commit 838ceda

5 files changed: +28 -23 lines changed

vllm/attention/backends/flashinfer.py

Lines changed: 2 additions & 4 deletions
@@ -38,7 +38,7 @@
                                            is_block_tables_empty)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
@@ -140,12 +140,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """

-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: Dict[str, PerLayerParameters] = {}

     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)

vllm/config.py

Lines changed: 15 additions & 1 deletion
@@ -3445,7 +3445,8 @@ def model_post_init(self, __context: Any) -> None:
     compilation_time: float = PrivateAttr

     # Per-model forward context
-    # Map from layer name to the attention cls
+    # Map from layer name to layer objects that need to be accessed outside
+    # model code, e.g., Attention, FusedMOE when dp_size>1.
     static_forward_context: dict[str, Any] = PrivateAttr

     def compute_hash(self) -> str:
@@ -4079,3 +4080,16 @@ def assert_hashable(text):
         f"vLLM tried to hash some configs that may have Python objects ids "
         f"in them. This is a bug, please file an issue. "
         f"Text being hashed: {text}")
+
+
+T = TypeVar("T")
+
+
+def get_layers_from_vllm_config(vllm_config: VllmConfig,
+                                layer_type: type[T]) -> dict[str, T]:
+    return {
+        layer_name: layer
+        for layer_name, layer in
+        vllm_config.compilation_config.static_forward_context.items()
+        if isinstance(layer, layer_type)
+    }
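Note: below is a minimal, self-contained sketch of the filtering behaviour the new get_layers_from_vllm_config helper implements. The real function takes a VllmConfig and reads its compilation_config.static_forward_context; here a plain dict and builtin types stand in for that mapping, purely for illustration.

# Minimal sketch of the helper's behaviour using builtin types. The real
# get_layers_from_vllm_config takes a VllmConfig and reads its
# compilation_config.static_forward_context; a plain dict stands in here.
from typing import TypeVar

T = TypeVar("T")


def filter_by_type(context: dict[str, object],
                   layer_type: type[T]) -> dict[str, T]:
    # Same comprehension as get_layers_from_vllm_config in vllm/config.py.
    return {name: obj for name, obj in context.items()
            if isinstance(obj, layer_type)}


ctx: dict[str, object] = {"a": 1, "b": "hello", "c": 2}
assert filter_by_type(ctx, int) == {"a": 1, "c": 2}   # typed as dict[str, int]
assert filter_by_type(ctx, str) == {"b": "hello"}     # typed as dict[str, str]

Because the return type is dict[str, T], call sites that pass Attention get a mapping whose values a type checker already knows are Attention, which is why the per-item assert isinstance(layer, Attention) lines are dropped in the diffs above and below.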

vllm/v1/attention/backends/flashinfer.py

Lines changed: 3 additions & 4 deletions
@@ -14,7 +14,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import (VllmConfig, get_current_vllm_config,
+                         get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """

-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: dict[str, PerLayerParameters] = {}

     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 10 deletions
@@ -12,13 +12,13 @@

 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1733,17 +1733,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         format. Layers that do not need KV cache are not included.
         """

-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            if isinstance(attn_module, FusedMoE):
-                continue
-
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
+            # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
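The model-runner change above is where the typed filter matters most: FusedMoE modules are registered in the static forward context when dp_size > 1, and get_kv_cache_spec previously had to skip them by hand and assert that everything else was Attention. A rough stand-in sketch of the simplified loop shape (ToyAttention, ToyFusedMoE, and the hand-built context below are hypothetical, not vLLM's classes):

# Hypothetical stand-ins; the real loop iterates over
# vllm.attention.layer.Attention objects returned by
# get_layers_from_vllm_config(self.vllm_config, Attention).
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyAttention:
    attn_type: str = "decoder"
    sliding_window: Optional[int] = None


class ToyFusedMoE:
    pass


forward_context: dict[str, object] = {
    "model.layers.0.self_attn.attn": ToyAttention(sliding_window=4096),
    "model.layers.0.mlp.experts": ToyFusedMoE(),  # present when dp_size > 1
}

# Filter once up front, so the loop only ever sees attention-like layers.
attn_layers = {name: m for name, m in forward_context.items()
               if isinstance(m, ToyAttention)}

for name, attn in attn_layers.items():
    # The old loop needed `if isinstance(m, FusedMoE): continue` plus an
    # `assert isinstance(m, Attention)` before touching these attributes.
    kind = "sliding_window" if attn.sliding_window is not None else "full"
    print(name, attn.attn_type, kind)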

vllm/v1/worker/tpu_model_runner.py

Lines changed: 3 additions & 4 deletions
@@ -17,7 +17,7 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -429,11 +429,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         format. Layers that do not need KV cache are not included.
         """

-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
