Skip to content

Commit 31060b2

Browse files
authored
[V1][BugFix] Detect interleaved sliding window attention (#14896)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent fc1f677 commit 31060b2

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -82,8 +82,15 @@ def __init__(
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]

-        self.is_multimodal_model = model_config.is_multimodal_model
+        # NOTE(woosuk): sliding_window is None for models with interleaved
+        # attention. Use interleaved_sliding_window instead.
         self.sliding_window = model_config.get_sliding_window()
+        self.interleaved_sliding_window = getattr(
+            model_config.hf_text_config, "interleaved_sliding_window", None)
+        self.window_size = (self.sliding_window
+                            or self.interleaved_sliding_window)
+
+        self.is_multimodal_model = model_config.is_multimodal_model
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
@@ -674,7 +681,7 @@ def _compute_cascade_attn_prefix_len(
             num_query_heads=self.num_query_heads,
             num_kv_heads=self.num_kv_heads,
             use_alibi=False,  # FIXME
-            use_sliding_window=self.sliding_window is not None,
+            use_sliding_window=self.window_size is not None,
             num_sms=self.num_sms,
         )
         return common_prefix_len if use_cascade else 0
```

0 commit comments

Comments (0)