
Commit 429e4e2

Isotr0py authored
[Bugfix] Fix ModernBert cuda graph capturing in v1 (#21901)
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent 35afe1b commit 429e4e2

File tree

5 files changed: +39 -42 lines changed

tests/models/language/pooling/mteb_utils.py
Lines changed: 3 additions & 2 deletions

@@ -162,7 +162,8 @@ def mteb_test_embed_models(hf_runner,
                            vllm_runner,
                            model_info: EmbedModelInfo,
                            vllm_extra_kwargs=None,
-                           hf_model_callback=None):
+                           hf_model_callback=None,
+                           atol=MTEB_RERANK_TOL):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
@@ -198,7 +199,7 @@ def mteb_test_embed_models(hf_runner,
     print("SentenceTransformers:", st_dtype, st_main_score)
     print("Difference:", st_main_score - vllm_main_score)
 
-    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
+    assert st_main_score == pytest.approx(vllm_main_score, abs=atol)
 
 
 def run_mteb_rerank(cross_encoder, tasks, languages):
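
The new atol parameter lets an individual model test relax the embedding-score comparison without editing the shared helper. A hypothetical caller, sketched after the pattern of the existing pooling tests (the model entry, the tolerance value, and the import paths are illustrative, not part of this commit):

import pytest

from ...utils import EmbedModelInfo
from .mteb_utils import mteb_test_embed_models

# Placeholder model list; real tests curate these per model family.
MODELS = [EmbedModelInfo("intfloat/e5-small", enable_test=True)]


@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner, model_info) -> None:
    # Pass a looser absolute tolerance than the previously hard-coded
    # MTEB_EMBED_TOL; 1e-3 here is purely illustrative.
    mteb_test_embed_models(hf_runner, vllm_runner, model_info, atol=1e-3)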

vllm/model_executor/models/bert.py
Lines changed: 1 addition & 1 deletion

@@ -466,7 +466,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: torch.Tensor,
         positions: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,

vllm/model_executor/models/bert_with_rope.py
Lines changed: 21 additions & 25 deletions

@@ -8,13 +8,15 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionType
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                    get_act_fn)
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    fused_topk, torch_vllm_outplace_fused_experts)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -284,15 +286,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.router(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w1,
-                                        self.w2,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=False,
-                                        inplace=False,
-                                        activation=self.hidden_act,
-                                        is_act_and_mul=False)
+        # FIXME(Isotr0py): This implementation is too tricky,
+        # we should use FusedMoE instead in the future
+        # after supporting ungated activation for it.
+        topk_weights, topk_ids, _ = fused_topk(hidden_states,
+                                               router_logits,
+                                               self.top_k,
+                                               renormalize=False)
+        final_hidden_states = torch_vllm_outplace_fused_experts(
+            hidden_states=hidden_states,
+            w1=self.w1,
+            w2=self.w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=self.hidden_act,
+            is_act_and_mul=False,
+        )
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -391,6 +400,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class BertWithRope(nn.Module, SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
 
@@ -407,7 +417,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
@@ -554,20 +564,6 @@ class JinaRobertaModel(BertWithRope):
         "norm2": "mlp_ln",
     })
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return super().forward(input_ids=input_ids,
-                               positions=position_ids,
-                               intermediate_tensors=intermediate_tensors,
-                               inputs_embeds=inputs_embeds,
-                               token_type_ids=token_type_ids)
-
     @torch.inference_mode()
     def jina_merge_lora_weights(self, weights: Iterable[tuple[str,
                                                               torch.Tensor]]):
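
Functionally, the fused_topk plus torch_vllm_outplace_fused_experts pair above computes the same ungated top-k MoE that the old single fused_moe call expressed, just decomposed into an explicit routing step followed by the out-of-place expert kernel. A slow reference sketch of that computation in plain PyTorch, assuming the fused kernels' (num_experts, intermediate, hidden) / (num_experts, hidden, intermediate) weight layout; ref_ungated_moe is an illustrative name, not vLLM API:

import torch
import torch.nn.functional as F


def ref_ungated_moe(hidden_states: torch.Tensor,
                    w1: torch.Tensor,
                    w2: torch.Tensor,
                    router_logits: torch.Tensor,
                    top_k: int,
                    act=F.gelu) -> torch.Tensor:
    # Routing: softmax over experts, then top-k; renormalize=False means the
    # raw softmax probabilities are used as mixing weights.
    probs = router_logits.softmax(dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)

    out = torch.zeros_like(hidden_states)
    for expert in range(w1.shape[0]):
        token_idx, slot = (topk_ids == expert).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        # is_act_and_mul=False: plain (ungated) MLP, act(x @ W1^T) @ W2^T
        expert_out = act(x @ w1[expert].t()) @ w2[expert].t()
        weight = topk_weights[token_idx, slot].unsqueeze(-1).to(expert_out.dtype)
        out.index_add_(0, token_idx, expert_out * weight)
    return out

The FIXME in the diff already flags this decomposition as temporary; the stated intent is to move to the FusedMoE layer once it supports ungated activations.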

vllm/model_executor/models/modernbert.py
Lines changed: 11 additions & 11 deletions

@@ -8,6 +8,7 @@
 from transformers import ModernBertConfig
 
 from vllm.attention import Attention, AttentionType
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -46,7 +47,7 @@ def forward(
         input_ids: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds:
+        if inputs_embeds is not None:
             return self.norm(inputs_embeds)
         else:
             inputs_embeds = self.tok_embeddings(input_ids)
@@ -117,7 +118,7 @@ def __init__(self,
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_ids: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.Wqkv(hidden_states)
         q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
@@ -169,9 +170,9 @@ def __init__(self,
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_ids: Optional[torch.LongTensor] = None,
-    ):
-        attn_outputs = self.attn(self.attn_norm(hidden_states),
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        attn_outputs = self.attn(hidden_states=self.attn_norm(hidden_states),
                                  position_ids=position_ids)
         hidden_states = hidden_states + attn_outputs
         mlp_output = self.mlp(self.mlp_norm(hidden_states))
@@ -192,13 +193,14 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_ids: torch.Tensor,
     ) -> torch.Tensor:
         for i, layer in enumerate(self.layers):
             hidden_states = layer(hidden_states, position_ids)
         return hidden_states
 
 
+@support_torch_compile
 class ModernBertModel(nn.Module):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"layers.": "encoder_layer.layers."})
@@ -234,13 +236,11 @@ def load_weights(self, weights: Iterable[tuple[str,
 
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        positions: Optional[torch.Tensor] = None,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
-        position_ids = positions if positions is not None else position_ids
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
@@ -249,7 +249,7 @@ def forward(
 
         outputs = self.encoder_layer(
             hidden_states=hidden_states,
-            position_ids=position_ids,
+            position_ids=positions,
         )
         norm_outputs = self.final_norm(outputs)
         return norm_outputs
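
One of the changes above replaces "if inputs_embeds:" with "if inputs_embeds is not None:" in the embedding layer. The difference is not cosmetic: truth-testing a tensor with more than one element raises a RuntimeError in PyTorch, while the None check keeps the branch structural rather than dependent on tensor values. A standalone illustration with toy shapes, unrelated to vLLM's own code paths:

import torch

inputs_embeds = torch.randn(4, 8)

try:
    if inputs_embeds:  # ambiguous: truth value of a multi-element tensor
        pass
except RuntimeError as err:
    print("tensor truthiness failed:", err)

if inputs_embeds is not None:  # structural check, independent of values
    print("using precomputed embeddings:", tuple(inputs_embeds.shape))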

vllm/model_executor/models/roberta.py
Lines changed: 3 additions & 3 deletions

@@ -105,7 +105,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
     def forward(
         self,
-        input_ids: Optional[torch.Tensor],
+        input_ids: torch.Tensor,
         positions: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
@@ -119,8 +119,8 @@ def forward(
             position_ids=positions,
             padding_idx=self.padding_idx)
 
-        return self.model(input_ids=input_ids,
-                          position_ids=positions,
+        return self.model(input_ids,
+                          positions,
                           token_type_ids=token_type_ids,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)
