|
12 | 12 |
|
13 | 13 | from vllm.attention import AttentionType, get_attn_backend
|
14 | 14 | from vllm.attention.layer import Attention
|
15 |
| -from vllm.config import CompilationLevel, VllmConfig |
| 15 | +from vllm.config import (CompilationLevel, VllmConfig, |
| 16 | + get_layers_from_vllm_config) |
16 | 17 | from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
17 | 18 | has_kv_transfer_group)
|
18 | 19 | from vllm.distributed.parallel_state import get_pp_group, graph_capture
|
19 | 20 | from vllm.forward_context import set_forward_context
|
20 | 21 | from vllm.logger import init_logger
|
21 |
| -from vllm.model_executor.layers.fused_moe import FusedMoE |
22 | 22 | from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
23 | 23 | from vllm.model_executor.model_loader import get_model
|
24 | 24 | from vllm.multimodal import MULTIMODAL_REGISTRY
|
@@ -1733,17 +1733,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
1733 | 1733 | format. Layers that do not need KV cache are not included.
|
1734 | 1734 | """
|
1735 | 1735 |
|
1736 |
| - forward_ctx = self.vllm_config.compilation_config.static_forward_context |
| 1736 | + layers = get_layers_from_vllm_config(self.vllm_config, Attention) |
1737 | 1737 | block_size = self.vllm_config.cache_config.block_size
|
1738 | 1738 | use_mla = self.vllm_config.model_config.use_mla
|
1739 | 1739 | kv_cache_spec: dict[str, KVCacheSpec] = {}
|
1740 |
| - for layer_name, attn_module in forward_ctx.items(): |
1741 |
| - if isinstance(attn_module, FusedMoE): |
1742 |
| - continue |
1743 |
| - |
1744 |
| - # TODO: Support other attention modules, e.g., sliding window, |
1745 |
| - # cross-attention |
1746 |
| - assert isinstance(attn_module, Attention) |
| 1740 | + for layer_name, attn_module in layers.items(): |
| 1741 | + # TODO: Support other attention modules, e.g., cross-attention |
1747 | 1742 | if attn_module.attn_type == AttentionType.DECODER:
|
1748 | 1743 | if attn_module.sliding_window is not None:
|
1749 | 1744 | kv_cache_spec[layer_name] = SlidingWindowSpec(
|
|
0 commit comments