@@ -82,8 +82,15 @@ def __init__(
82
82
self .kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE [
83
83
cache_config .cache_dtype ]
84
84
85
- self .is_multimodal_model = model_config .is_multimodal_model
85
+ # NOTE(woosuk): sliding_window is None for models with interleaved
86
+ # attention. Use interleaved_sliding_window instead.
86
87
self .sliding_window = model_config .get_sliding_window ()
88
+ self .interleaved_sliding_window = getattr (
89
+ model_config .hf_text_config , "interleaved_sliding_window" , None )
90
+ self .window_size = (self .sliding_window
91
+ or self .interleaved_sliding_window )
92
+
93
+ self .is_multimodal_model = model_config .is_multimodal_model
87
94
self .block_size = cache_config .block_size
88
95
self .max_model_len = model_config .max_model_len
89
96
self .max_num_blocks_per_req = cdiv (self .max_model_len , self .block_size )
@@ -674,7 +681,7 @@ def _compute_cascade_attn_prefix_len(
674
681
num_query_heads = self .num_query_heads ,
675
682
num_kv_heads = self .num_kv_heads ,
676
683
use_alibi = False , # FIXME
677
- use_sliding_window = self .sliding_window is not None ,
684
+ use_sliding_window = self .window_size is not None ,
678
685
num_sms = self .num_sms ,
679
686
)
680
687
return common_prefix_len if use_cascade else 0
0 commit comments