Commit 6d729c4

noooop and maxdebayser authored
[Bugfix] Fix ModernBert load & Enable sliding window attention for bidirectional attention. (#22637)
Signed-off-by: wang.yuqi <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Co-authored-by: Max de Bayser <[email protected]>
1 parent 2f46579 commit 6d729c4

File tree: 4 files changed (+101, -59 lines)

tests/models/language/pooling/test_gte.py

Lines changed: 18 additions & 3 deletions
@@ -4,10 +4,11 @@
 
 import pytest
 
-from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
-                      LASTPoolingEmbedModelInfo, check_transformers_version)
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+                      EmbedModelInfo, LASTPoolingEmbedModelInfo,
+                      RerankModelInfo, check_transformers_version)
 from .embed_utils import correctness_test_embed_models
-from .mteb_utils import mteb_test_embed_models
+from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
 
 MODELS = [
     ########## BertModel
@@ -58,6 +59,14 @@
         enable_test=False),
 ]
 
+RERANK_MODELS = [
+    # classifier_pooling: mean
+    CLSPoolingRerankModelInfo(
+        "Alibaba-NLP/gte-reranker-modernbert-base",
+        architecture="ModernBertForSequenceClassification",
+        enable_test=True),
+]
+
 
 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
@@ -88,3 +97,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
 
     correctness_test_embed_models(hf_runner, vllm_runner, model_info,
                                   example_prompts, vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(hf_runner, vllm_runner,
+                            model_info: RerankModelInfo) -> None:
+    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
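
The new rerank entry can be selected locally with pytest's standard -k filter (a usage sketch; it assumes the MTEB dependencies used by mteb_test_rerank_models are installed):

    pytest tests/models/language/pooling/test_gte.py -k test_rerank_models_mteb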

vllm/model_executor/models/modernbert.py

Lines changed: 15 additions & 16 deletions
@@ -26,8 +26,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
 
-from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
-                         default_pooling_type)
+from .interfaces import SupportsCrossEncoding, default_pooling_type
 from .utils import WeightsMapper, maybe_prefix
 
 
@@ -93,16 +92,14 @@ def __init__(self,
             bias=config.attention_bias,
         )
 
+        sliding_window = None
         if layer_id % config.global_attn_every_n_layers != 0:
-            self.local_attention = (config.local_attention // 2,
-                                    config.local_attention // 2)
+            sliding_window = config.local_attention // 2
+            rope_theta = config.local_rope_theta if config.local_rope_theta \
+                is not None else config.global_rope_theta
         else:
-            self.local_attention = (-1, -1)
+            rope_theta = config.global_rope_theta
 
-        rope_theta = config.global_rope_theta
-        if self.local_attention != (
-                -1, -1) and config.local_rope_theta is not None:
-            rope_theta = config.local_rope_theta
         self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                     head_size=self.head_dim,
                                                     dim=self.head_dim,
@@ -111,7 +108,8 @@ def __init__(self,
                              self.head_dim,
                              self.scaling,
                              prefix=f"{layer_id}.attn",
-                             attn_type=AttentionType.ENCODER_ONLY)
+                             attn_type=AttentionType.ENCODER_ONLY,
+                             per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)
@@ -278,6 +276,7 @@ def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
         return self.pooling.get_pooling_updates(task)
 
     def _head(self, pooled_output: torch.Tensor):
+        pooled_output = pooled_output.to(self.dense.weight.dtype)
         return self.norm(self.act(self.dense(pooled_output)))
 
     def forward(
@@ -296,8 +295,7 @@ def forward(
 
 
 @default_pooling_type("CLS")
-class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
-                                          SupportsCrossEncoding):
+class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
 
     is_pooling_model = True
 
@@ -308,6 +306,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooling = ModernBertPooler(config)
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -317,14 +316,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             Pooler.for_encode(pooler_config),
             "classify":
             ClassifierPooler(
-                pooling=ModernBertPooler(config),
+                pooling=self.pooling,
                 classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_seq_cls(
                     vllm_config.model_config),
             ),
             "score":
             ClassifierPooler(
-                pooling=ModernBertPooler(config),
+                pooling=self.pooling,
                 classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                     vllm_config.model_config),
@@ -353,7 +352,7 @@ def weight_filter():
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            if name.startswith("head"):
-                param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
+                param = params_dict["pooling." + name[len("head") + 1:]]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
@@ -368,5 +367,5 @@ def forward(
         return self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
-            position_ids=positions,
+            positions=positions,
         )
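
The ModernBert change above replaces the old local_attention tuple with a per-layer sliding_window that is passed straight to the Attention layer: only every global_attn_every_n_layers-th layer keeps full attention, and local layers prefer local_rope_theta when the checkpoint defines it. A minimal standalone sketch of that per-layer selection, assuming hypothetical names (LayerAttnConfig, layer_attn_config) that are not part of vLLM:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LayerAttnConfig:
        sliding_window: Optional[int]  # None => full bidirectional attention
        rope_theta: float

    def layer_attn_config(layer_id: int,
                          global_attn_every_n_layers: int,
                          local_attention: int,
                          global_rope_theta: float,
                          local_rope_theta: Optional[float]) -> LayerAttnConfig:
        """Sketch of the per-layer choice made in ModernBertAttention.__init__."""
        if layer_id % global_attn_every_n_layers != 0:
            # Local layer: half of `local_attention` tokens on each side,
            # and the local RoPE base if the config provides one.
            window = local_attention // 2
            theta = (local_rope_theta
                     if local_rope_theta is not None else global_rope_theta)
            return LayerAttnConfig(sliding_window=window, rope_theta=theta)
        # Global layer: no window, global RoPE base.
        return LayerAttnConfig(sliding_window=None, rope_theta=global_rope_theta)

    # Illustrative values: with global_attn_every_n_layers=3 and local_attention=128,
    # layers 0, 3, 6, ... stay global and the rest get a 64-token half-window.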

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 0 deletions
@@ -384,6 +384,8 @@ def __init__(
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
             self.sliding_window = (-1, -1)
+        elif attn_type == AttentionType.ENCODER_ONLY:
+            self.sliding_window = (sliding_window - 1, sliding_window - 1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
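
The flash_attn change gives encoder-only layers a symmetric (left, right) window instead of the causal (left, 0) one. A small pure-Python sketch of what the two window shapes mean for the attention pattern (illustrative only; the FlashAttention kernel applies the window internally rather than materializing a boolean mask):

    def sliding_window_mask(seq_len: int, left: int, right: int) -> list[list[bool]]:
        """True where query position q may attend to key position k.

        (left, right) == (-1, -1) means no windowing at all.
        """
        if (left, right) == (-1, -1):
            return [[True] * seq_len for _ in range(seq_len)]
        return [[-left <= k - q <= right for k in range(seq_len)]
                for q in range(seq_len)]

    window = 4
    # Decoder-style local attention: look back only -> (window - 1, 0)
    causal_local = sliding_window_mask(8, window - 1, 0)
    # Encoder-only (bidirectional) local attention -> (window - 1, window - 1)
    bidir_local = sliding_window_mask(8, window - 1, window - 1)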

vllm/v1/worker/gpu_model_runner.py

Lines changed: 66 additions & 40 deletions
@@ -826,7 +826,8 @@ def _prepare_inputs(
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)
@@ -835,6 +836,8 @@
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata =\
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata
 
         # Prepare the attention metadata for each KV cache group and make layers
@@ -2683,30 +2686,41 @@ def create_attn_groups(
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():
 
             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
-
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
+
             else:
                 raise ValueError("Expected only encoder-only layers")
 
         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():
 
-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)
 
-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
+                "All or none of the layers are expected to be encoder-only"
             self.is_encoder_only_model = True
 
     def calculate_reorder_batch_threshold(self) -> None:
@@ -3071,7 +3085,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
     def _build_encoder_only_attn_metadata(
             self, scheduler_output: "SchedulerOutput") -> \
-            tuple[CommonAttentionMetadata, Any]:
+            dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.
 
         Args:
@@ -3088,33 +3102,45 @@
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)
 
-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
         dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                          dtype=torch.int32,
                                          device=self.device)
 
-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
 
-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
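
With sliding-window and full-attention layers now mixed inside one encoder-only model, the runner groups layers by their attention spec and builds one attention group (and one metadata set) per spec. A standalone sketch of that grouping idea, where SpecKey and group_layers_by_spec are hypothetical stand-ins for vLLM's AttentionSpec / FullAttentionSpec / SlidingWindowSpec classes:

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)  # frozen => hashable, so it can be used as a dict key
    class SpecKey:
        sliding_window: Optional[int]  # None means full attention
        num_kv_heads: int
        head_size: int

    def group_layers_by_spec(layers: dict[str, SpecKey]) -> dict[SpecKey, list[str]]:
        groups: dict[SpecKey, list[str]] = defaultdict(list)
        for layer_name, spec in layers.items():
            groups[spec].append(layer_name)
        return groups

    # ModernBERT-like layout: every 3rd layer global, the rest local (illustrative).
    layers = {
        f"layers.{i}.attn": SpecKey(None if i % 3 == 0 else 64, 12, 64)
        for i in range(6)
    }
    groups = group_layers_by_spec(layers)
    # Every layer ends up in exactly one group, mirroring the total_layers assert above.
    assert sum(len(names) for names in groups.values()) == len(layers)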
