Commit 6f59bea

antrec and gemini-code-assist[bot] authored

[Model] Add support for ModernBertForTokenClassification (#26340)

Signed-off-by: Antoine Recanati Le Goat <[email protected]>
Signed-off-by: antrec <[email protected]>
Co-authored-by: Antoine Recanati Le Goat <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

1 parent 41f1cf3 commit 6f59bea

File tree

5 files changed: +112 -2 lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions

@@ -576,6 +576,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
 | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
 |--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
 | `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |
+| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ |
 
 !!! note
     Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.
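For orientation, a minimal offline NER sketch against the newly listed checkpoint; the linked ner.py examples are the canonical reference. It assumes the standard `LLM.encode` pooling flow from this table and that the checkpoint's HF config carries an `id2label` mapping.

```python
# Hedged sketch, not the shipped example: assumes LLM.encode returns one
# logits row per input token under ALL pooling, and that the HF config
# provides id2label (true for most token-classification checkpoints).
from transformers import AutoConfig
from vllm import LLM

model_id = "disham993/electrical-ner-ModernBERT-base"
id2label = AutoConfig.from_pretrained(model_id).id2label

llm = LLM(model=model_id)
(output,) = llm.encode(["The relay tripped at 11 kV."])

# output.outputs.data holds per-token logits of shape (num_tokens, num_labels).
label_ids = output.outputs.data.argmax(dim=-1).tolist()
print([id2label[i] for i in label_ids])
```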

tests/models/language/pooling/test_token_classification.py

Lines changed: 32 additions & 1 deletion

@@ -11,7 +11,38 @@
 # The float32 is required for this tiny model to pass the test.
 @pytest.mark.parametrize("dtype", ["float"])
 @torch.inference_mode
-def test_models(
+def test_bert_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.encode(example_prompts)
+
+    with hf_runner(
+        model, dtype=dtype, auto_cls=AutoModelForTokenClassification
+    ) as hf_model:
+        tokenizer = hf_model.tokenizer
+        hf_outputs = []
+        for prompt in example_prompts:
+            inputs = tokenizer([prompt], return_tensors="pt")
+            inputs = hf_model.wrap_device(inputs)
+            output = hf_model.model(**inputs)
+            hf_outputs.append(softmax(output.logits[0]))
+
+    # check logits difference
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output).cpu().float()
+        vllm_output = torch.tensor(vllm_output).cpu().float()
+        assert torch.allclose(hf_output, vllm_output, 1e-2)
+
+
+@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
+@pytest.mark.parametrize("dtype", ["float"])
+@torch.inference_mode
+def test_modernbert_models(
     hf_runner,
     vllm_runner,
     example_prompts,
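With the split into `test_bert_models` and `test_modernbert_models`, the new case can be selected by name filter; a sketch of the programmatic equivalent of `pytest ... -k modernbert`:

```python
# Sketch: run only the new ModernBERT parity test from the repo root.
import pytest

pytest.main(
    [
        "tests/models/language/pooling/test_token_classification.py",
        "-k", "modernbert",
    ]
)
```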

tests/models/registry.py

Lines changed: 3 additions & 0 deletions

@@ -527,6 +527,9 @@ def check_available_online(
     "ModernBertForSequenceClassification": _HfExamplesInfo(
         "Alibaba-NLP/gte-reranker-modernbert-base"
     ),
+    "ModernBertForTokenClassification": _HfExamplesInfo(
+        "disham993/electrical-ner-ModernBERT-base"
+    ),
     "RobertaForSequenceClassification": _HfExamplesInfo(
         "cross-encoder/quora-roberta-base"
     ),

vllm/model_executor/models/modernbert.py

Lines changed: 72 additions & 1 deletion

@@ -6,6 +6,7 @@
 import torch
 from torch import nn
 from transformers import ModernBertConfig
+from transformers.activations import ACT2FN
 
 from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
@@ -29,7 +30,7 @@
 
 from .interfaces import SupportsCrossEncoding
 from .interfaces_base import default_pooling_type
-from .utils import WeightsMapper, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
 class ModernBertEmbeddings(nn.Module):
@@ -379,3 +380,73 @@ def forward(
             inputs_embeds=inputs_embeds,
             positions=positions,
         )
+
+
+class ModernBertPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.dense = nn.Linear(
+            config.hidden_size, config.hidden_size, bias=config.classifier_bias
+        )
+        self.act = ACT2FN[config.classifier_activation]
+        self.norm = nn.LayerNorm(
+            config.hidden_size,
+            eps=getattr(config, "norm_eps", 1e-5),
+            bias=getattr(config, "norm_bias", True),
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.norm(self.act(self.dense(hidden_states)))
+
+
+@default_pooling_type("ALL")
+class ModernBertForTokenClassification(nn.Module):
+    is_pooling_model = True
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        self.head_dtype = vllm_config.model_config.head_dtype
+        self.num_labels = config.num_labels
+        self.model = ModernBertModel(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")
+        )
+        self.head = ModernBertPredictionHead(config)
+        self.classifier = nn.Linear(
+            config.hidden_size, config.num_labels, dtype=self.head_dtype
+        )
+
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler(
+            {
+                "encode": Pooler.for_encode(pooler_config),
+            }
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
+        loaded_params = loader.load_weights(weights)
+        return loaded_params
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+            intermediate_tensors=intermediate_tensors,
+        )
+        hidden_states = self.head(hidden_states)
+        hidden_states = hidden_states.to(self.head_dtype)
+        return self.classifier(hidden_states)
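The head mirrors the upstream HF ModernBERT design: each token's hidden state passes through a shared dense, activation, and LayerNorm block, then a per-token linear classifier. Because of `@default_pooling_type("ALL")` no positions are pooled away, so one logits row comes out per token. A standalone sketch of that data flow (sizes are illustrative; `nn.GELU` stands in for `ACT2FN[config.classifier_activation]`):

```python
# Illustrative sketch of the per-token classification path above; the real
# sizes and activation come from the model config.
import torch
from torch import nn

hidden_size, num_labels, num_tokens = 768, 9, 16

head = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),  # ModernBertPredictionHead.dense
    nn.GELU(),                            # stand-in for .act
    nn.LayerNorm(hidden_size),            # .norm
)
classifier = nn.Linear(hidden_size, num_labels)

hidden_states = torch.randn(num_tokens, hidden_size)  # encoder output, one row per token
logits = classifier(head(hidden_states))              # (num_tokens, num_labels)
```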

vllm/model_executor/models/registry.py

Lines changed: 4 additions & 0 deletions

@@ -225,6 +225,10 @@
         "modernbert",
         "ModernBertForSequenceClassification",
     ),
+    "ModernBertForTokenClassification": (
+        "modernbert",
+        "ModernBertForTokenClassification",
+    ),
     "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"),
     "XLMRobertaForSequenceClassification": (
         "roberta",
