Commit c498483

Refactor sliding window configuration to Transformers best practice (#21927)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 2a84fb4 commit c498483

16 files changed: +124 -232 lines changed


docs/contributing/model/basic.md

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
 
 To support a model with interleaving sliding windows, we need to take care of the following details:
 
-- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
+- Make sure the model's `config.json` contains `layer_types`.
 - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
 
 With these two steps, interleave sliding windows should work with the model.
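The two documentation steps above reduce to a small per-layer lookup. A minimal, runnable sketch of that lookup (a hypothetical helper written only for illustration; it simply mirrors the pattern the model files below adopt):

```python
from typing import Optional


def per_layer_sliding_window(layer_types: list[str], sliding_window: int,
                             layer_idx: int) -> Optional[int]:
    """Return the window for one layer: the configured size for
    'sliding_attention' layers, None (full attention) otherwise."""
    if layer_types[layer_idx] == "sliding_attention":
        return sliding_window
    return None


# Example: a 5:1 local/global pattern similar to Gemma 3's layer_types.
layer_types = ["sliding_attention"] * 5 + ["full_attention"]
assert per_layer_sliding_window(layer_types, 512, 0) == 512   # local layer
assert per_layer_sliding_window(layer_types, 512, 5) is None  # global layer
```

The value returned for each layer is what the modeling code passes as `per_layer_sliding_window` to the `Attention` constructor.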

tests/test_config.py

Lines changed: 0 additions & 22 deletions
@@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
     assert model_config.max_model_len == expected
 
 
-def test_get_sliding_window():
-    TEST_SLIDING_WINDOW = 4096
-    # Test that the sliding window is correctly computed.
-    # For Qwen1.5/Qwen2, get_sliding_window() should be None
-    # when use_sliding_window is False.
-    qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")
-
-    qwen2_model_config.hf_config.use_sliding_window = False
-    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert qwen2_model_config.get_sliding_window() is None
-
-    qwen2_model_config.hf_config.use_sliding_window = True
-    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-    mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
-    mistral_model_config.hf_config.sliding_window = None
-    assert mistral_model_config.get_sliding_window() is None
-
-    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
 def test_get_pooling_config():

vllm/config/__init__.py

Lines changed: 31 additions & 80 deletions
@@ -40,8 +40,9 @@
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
-    maybe_override_with_speculators_target_model, try_get_generation_config,
-    try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
+    is_interleaved, maybe_override_with_speculators_target_model,
+    try_get_generation_config, try_get_safetensors_metadata,
+    try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 # yapf conflicts with isort for this block
@@ -714,53 +715,31 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
             revision=self.revision,
         )
 
-        # Workaround for Gemma 2 which uses interleaved sliding window
-        # attention, but it's not specified in its config.
-        # TODO: remove this when Gemma 2 config updated in HuggingFace.
-        if self.hf_text_config.model_type == "gemma2":
-            self.hf_text_config.sliding_window_pattern = 2
-
-        # TODO: remove this when Gemma 3n config updated in HuggingFace.
-        if self.hf_text_config.model_type == "gemma3n_text":
-            # 4 sliding window attention followed by 1 full attention
-            self.hf_text_config.sliding_window_pattern = "LLLLG"
-
-        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
-        sliding_window_pattern = getattr(self.hf_text_config,
-                                         "sliding_window_pattern", None)
-        has_interleaved_attention = sliding_window_pattern is not None or (
-            isinstance(sliding_window, list))
-
-        if not self.disable_sliding_window and has_interleaved_attention:
-            if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
-                                         ) in ("XFORMERS", "FLASHINFER"):
-                sliding_window_len_min = get_min_sliding_window(
-                    self.hf_text_config.sliding_window)
-
-                logger.warning_once(
-                    "%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).",  # noqa: E501
-                    self.hf_text_config.model_type,
-                    backend,
-                    sliding_window_len_min,
-                )
-                self.disable_sliding_window = True
-            else:
-                # for a model with interleaved attention,
-                # the scheduler and the model treat it as full attention
-                # (i.e., not dropping any tokens outside the window).
-                # only the attention layer itself is aware of the sliding
-                # window, and use the window size to compute the attention.
-                self.hf_text_config.interleaved_sliding_window = sliding_window
-
-                if hasattr(self.hf_text_config, "sliding_window"):
-                    delattr(self.hf_text_config, "sliding_window")
-
-                sliding_window = None
+        # Interleaved attention is not supported by some backends in V0
+        if (not self.disable_sliding_window
+                and is_interleaved(self.hf_text_config)
+                and not envs.VLLM_USE_V1
+                and (backend := envs.VLLM_ATTENTION_BACKEND)
+                in ("XFORMERS", "FLASHINFER")):
+            logger.warning_once(
+                "%s has interleaved attention, which is currently not "
+                "supported by the %s backend. Disabling sliding window and "
+                "capping the max length to the sliding window size (%d).",
+                self.hf_text_config.model_type,
+                backend,
+                self.hf_text_config.sliding_window,
+            )
+            self.disable_sliding_window = True
 
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
         self.multimodal_config = self._init_multimodal_config()
 
+        if self.disable_sliding_window:
+            # Set after get_and_verify_max_len to ensure that max_model_len
+            # can be correctly capped to sliding window size
+            self.hf_text_config.sliding_window = None
+
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()

@@ -1322,27 +1301,10 @@ def verify_with_parallel_config(
             if self.use_async_output_proc:
                 self.use_async_output_proc = False
 
-    def get_hf_config_sliding_window(
-            self) -> Union[Optional[int], list[Optional[int]]]:
-        """Get the sliding window size, or None if disabled."""
-
-        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
-        # addition to sliding window size. We check if that field is present
-        # and if it's False, return None.
-        if (hasattr(self.hf_text_config, "use_sliding_window")
-                and not self.hf_text_config.use_sliding_window):
-            return None
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
 
-    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
-        """Get the sliding window size, or None if disabled.
-        """
-        # If user disables sliding window, return None.
-        if self.disable_sliding_window:
-            return None
-        # Otherwise get the value from the hf config.
-        return self.get_hf_config_sliding_window()
-
     def get_vocab_size(self) -> int:
         return getattr(self.hf_text_config, "vocab_size", 0)

@@ -1762,7 +1724,7 @@ def get_and_verify_max_len(self, max_model_len: int):
             tokenizer_config=tokenizer_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
-            sliding_window_len=self.get_hf_config_sliding_window(),
+            sliding_window=self.get_sliding_window(),
             spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
         logger.info("Using max model len %s", max_model_len)
@@ -3305,7 +3267,7 @@ def _get_and_verify_max_len(
     tokenizer_config: Optional[dict],
     max_model_len: Optional[int],
     disable_sliding_window: bool,
-    sliding_window_len: Optional[Union[int, list[Optional[int]]]],
+    sliding_window: Optional[int],
     spec_target_max_model_len: Optional[int] = None,
     encoder_config: Optional[Any] = None,
 ) -> int:
@@ -3344,13 +3306,10 @@ def _get_and_verify_max_len(
 
     # If sliding window is manually disabled, max_length should be less
     # than the sliding window length in the model config.
-    if disable_sliding_window and sliding_window_len is not None:
-
-        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
-        max_len_key = "sliding_window" \
-            if sliding_window_len_min < derived_max_model_len else max_len_key
-        derived_max_model_len = min(derived_max_model_len,
-                                    sliding_window_len_min)
+    if (disable_sliding_window and sliding_window is not None
+            and sliding_window < derived_max_model_len):
+        max_len_key = "sliding_window"
+        derived_max_model_len = sliding_window
 
     # Consider model_max_length in tokenizer_config
     if tokenizer_config:
@@ -3451,14 +3410,6 @@ def _get_and_verify_max_len(
     return int(max_model_len)
 
 
-def get_min_sliding_window(
-        sliding_window: Union[int, list[Optional[int]]]) -> int:
-    if isinstance(sliding_window, list):
-        return min(s for s in sliding_window if s is not None)
-
-    return sliding_window
-
-
 def get_served_model_name(model: str,
                           served_model_name: Optional[Union[str, list[str]]]):
     """

vllm/engine/arg_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from vllm.ray.lazy_utils import is_ray_initialized
4040
from vllm.reasoning import ReasoningParserManager
4141
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
42+
from vllm.transformers_utils.config import is_interleaved
4243
from vllm.transformers_utils.utils import check_gguf_file
4344
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
4445
GiB_bytes, get_ip, is_in_ray_actor)
@@ -1081,14 +1082,21 @@ def create_engine_config(
                 "DualChunkFlashAttention is not supported on V1 engine. "
                 "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
 
+        sliding_window: Optional[int] = None
+        if not is_interleaved(model_config.hf_text_config):
+            # Only set CacheConfig.sliding_window if the model is all sliding
+            # window. Otherwise CacheConfig.sliding_window will override the
+            # global layers in interleaved sliding window models.
+            sliding_window = model_config.get_sliding_window()
+
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
             is_attention_free=model_config.is_attention_free,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
-            sliding_window=model_config.get_sliding_window(),
+            sliding_window=sliding_window,
             enable_prefix_caching=self.enable_prefix_caching,
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
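The `is_interleaved` helper imported above is defined in `vllm/transformers_utils/config.py` and is not shown in this diff. A plausible standalone sketch of the check it performs, assuming it only inspects `layer_types` (an assumption for illustration, not the actual implementation):

```python
from types import SimpleNamespace


def is_interleaved(text_config) -> bool:
    """Sketch only: treat a model as interleaved when its layer_types mix
    sliding and full attention. The real helper lives in
    vllm/transformers_utils/config.py and may differ."""
    layer_types = getattr(text_config, "layer_types", None)
    if not layer_types:
        return False
    return len(set(layer_types)) > 1


# Gemma-2-style config: alternating local/global layers -> interleaved,
# so CacheConfig.sliding_window stays None here.
assert is_interleaved(
    SimpleNamespace(layer_types=["sliding_attention", "full_attention"] * 2))

# Uniformly sliding model: not interleaved, so it is safe to set
# CacheConfig.sliding_window from model_config.get_sliding_window().
assert not is_interleaved(
    SimpleNamespace(layer_types=["sliding_attention"] * 4))
```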

vllm/model_executor/models/commandr.py

Lines changed: 7 additions & 15 deletions
@@ -182,21 +182,13 @@ def __init__(
         )
 
         # Model v2 has interleaved sliding windows, v1 does not
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             None)
-        self.v1 = interleaved_sliding_window is None
-
-        layer_idx = extract_layer_index(prefix)
-        layer_has_sliding_window = (
-            getattr(config, "sliding_window_pattern", False) and
-            (layer_idx + 1) % self.config.sliding_window_pattern
-            != 0) or (getattr(config, "layer_types", False)
-                      and config.layer_types[layer_idx] == "sliding_attention")
-
-        self.sliding_window = (interleaved_sliding_window
-                               or config.sliding_window
-                               if layer_has_sliding_window else None)
+        self.v1 = isinstance(config, CohereConfig)
+
+        self.sliding_window = None
+        if not self.v1:
+            layer_idx = extract_layer_index(prefix)
+            if config.layer_types[layer_idx] == "sliding_attention":
+                self.sliding_window = config.sliding_window
 
         self.attn = Attention(self.num_heads,
                               self.head_dim,

vllm/model_executor/models/exaone4.py

Lines changed: 4 additions & 17 deletions
@@ -159,25 +159,12 @@ def __init__(
         if quant_config is not None and quant_config.get_name() == "gguf":
             is_neox_style = False
 
-        self.apply_all_layers = False  # apply rotary embeddings to every layer.
         layer_idx = extract_layer_index(prefix)
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             4096)
-        sliding_window_pattern = getattr(config, "sliding_window_pattern",
-                                         "LLLG")
-
-        if sliding_window_pattern:
-            layer_has_sliding_window = (
-                layer_idx + 1) % sliding_window_pattern.__len__() != 0
-        else:
-            layer_has_sliding_window = False
-            self.apply_all_layers = True
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None
 
-        if layer_has_sliding_window:
-            self.sliding_window = interleaved_sliding_window
-        else:
-            self.sliding_window = None
+        # apply rotary embeddings to every layer
+        self.apply_all_layers = not is_sliding
 
         self.rotary_emb = get_rope(
             self.head_dim,

vllm/model_executor/models/gemma2.py

Lines changed: 3 additions & 6 deletions
@@ -144,13 +144,10 @@ def __init__(self,
             is_neox_style=True,
         )
 
-        # reference:
-        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312  # noqa
         layer_idx = extract_layer_index(prefix)
-        use_sliding_window = (layer_idx % 2 == 0 and getattr(
-            config, "interleaved_sliding_window", None) is not None)
-        sliding_window = config.interleaved_sliding_window if \
-            use_sliding_window else None
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if is_sliding else None
+
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,

vllm/model_executor/models/gemma3.py

Lines changed: 4 additions & 10 deletions
@@ -146,25 +146,19 @@ def __init__(self,
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
 
-        # TODO(woosuk): Add reference to the original HF implementation.
         layer_idx = extract_layer_index(prefix)
-        self.is_sliding = (getattr(
-            config, "interleaved_sliding_window", None) is not None and (bool(
-                (layer_idx + 1) % config.sliding_window_pattern))) or (
-                    getattr(config, "layer_types", None) is not None
-                    and config.layer_types[layer_idx] == "sliding_attention")
+        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if self.is_sliding else None
+
         # Initialize the rotary embedding.
         if self.is_sliding:
             # Local attention. Override the values in config.json.
             self.rope_theta = config.rope_local_base_freq
             self.rope_scaling = {"rope_type": "default"}
-            self.sliding_window = (config.interleaved_sliding_window
-                                   or config.sliding_window)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
             self.rope_scaling = config.rope_scaling
-            self.sliding_window = None
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -182,7 +176,7 @@ def __init__(self,
             cache_config=cache_config,
             quant_config=quant_config,
             logits_soft_cap=attn_logits_soft_cap,
-            per_layer_sliding_window=self.sliding_window,
+            per_layer_sliding_window=sliding_window,
             prefix=f"{prefix}.attn")
 
     def forward(

vllm/model_executor/models/gemma3_mm.py

Lines changed: 2 additions & 4 deletions
@@ -502,8 +502,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.sliding_window = getattr(config.text_config,
-                                      "interleaved_sliding_window", None)
 
         self.vision_tower = SiglipVisionModel(config.vision_config,
                                               quant_config,
@@ -690,11 +688,11 @@ def prepare_attn_masks(
             global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
             global_attn_masks.append(global_attn_mask)
 
-            if self.sliding_window is not None:
+            if (sliding_window := self.config.sliding_window) is not None:
                 # Create a local causal mask with sliding window (1024).
                 local_attn_mask = torch.ones_like(global_attn_mask)
                 local_attn_mask = torch.tril(local_attn_mask,
-                                             diagonal=-self.sliding_window)
+                                             diagonal=-sliding_window)
                 local_attn_mask = torch.where(local_attn_mask == 0,
                                               global_attn_mask, float("-inf"))
                 local_attn_masks.append(local_attn_mask)

vllm/model_executor/models/gemma3n.py

Lines changed: 6 additions & 7 deletions
@@ -313,17 +313,16 @@ def __init__(self,
                                      has_weight=False)
 
         layer_idx = extract_layer_index(prefix)
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None
 
-        is_sliding_window = (
-            getattr(config, "interleaved_sliding_window", None) is not None
-            and config.layer_types[layer_idx] == "sliding_attention")
-
-        if is_sliding_window:
-            self.sliding_window = config.interleaved_sliding_window
+        # Initialize the rotary embedding.
+        if is_sliding:
+            # Local attention. Override the values in config.json.
            rope_theta = config.rope_local_base_freq
            rope_scaling = {"rope_type": "default"}
         else:
-            self.sliding_window = None
+            # Global attention. Use the values in config.json.
            rope_theta = config.rope_theta
            rope_scaling = config.rope_scaling
