|
6 | 6 | import torch
|
7 | 7 | from torch import nn
|
8 | 8 | from transformers import ModernBertConfig
|
| 9 | +from transformers.activations import ACT2FN |
9 | 10 |
|
10 | 11 | from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
11 | 12 | from vllm.compilation.decorators import support_torch_compile
|
|
29 | 30 |
|
30 | 31 | from .interfaces import SupportsCrossEncoding
|
31 | 32 | from .interfaces_base import default_pooling_type
|
32 |
| -from .utils import WeightsMapper, maybe_prefix |
| 33 | +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix |
33 | 34 |
|
34 | 35 |
|
35 | 36 | class ModernBertEmbeddings(nn.Module):
|
@@ -379,3 +380,73 @@ def forward(
|
379 | 380 | inputs_embeds=inputs_embeds,
|
380 | 381 | positions=positions,
|
381 | 382 | )
|
| 383 | + |
| 384 | + |
| 385 | +class ModernBertPredictionHead(nn.Module): |
| 386 | + def __init__(self, config): |
| 387 | + super().__init__() |
| 388 | + self.config = config |
| 389 | + self.dense = nn.Linear( |
| 390 | + config.hidden_size, config.hidden_size, bias=config.classifier_bias |
| 391 | + ) |
| 392 | + self.act = ACT2FN[config.classifier_activation] |
| 393 | + self.norm = nn.LayerNorm( |
| 394 | + config.hidden_size, |
| 395 | + eps=getattr(config, "norm_eps", 1e-5), |
| 396 | + bias=getattr(config, "norm_bias", True), |
| 397 | + ) |
| 398 | + |
| 399 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 400 | + return self.norm(self.act(self.dense(hidden_states))) |
| 401 | + |
| 402 | + |
| 403 | +@default_pooling_type("ALL") |
| 404 | +class ModernBertForTokenClassification(nn.Module): |
| 405 | + is_pooling_model = True |
| 406 | + |
| 407 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 408 | + super().__init__() |
| 409 | + config = vllm_config.model_config.hf_config |
| 410 | + self.head_dtype = vllm_config.model_config.head_dtype |
| 411 | + self.num_labels = config.num_labels |
| 412 | + self.model = ModernBertModel( |
| 413 | + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert") |
| 414 | + ) |
| 415 | + self.head = ModernBertPredictionHead(config) |
| 416 | + self.classifier = nn.Linear( |
| 417 | + config.hidden_size, config.num_labels, dtype=self.head_dtype |
| 418 | + ) |
| 419 | + |
| 420 | + pooler_config = vllm_config.model_config.pooler_config |
| 421 | + assert pooler_config is not None |
| 422 | + |
| 423 | + self.pooler = DispatchPooler( |
| 424 | + { |
| 425 | + "encode": Pooler.for_encode(pooler_config), |
| 426 | + } |
| 427 | + ) |
| 428 | + |
| 429 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 430 | + return self.model.get_input_embeddings(input_ids) |
| 431 | + |
| 432 | + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): |
| 433 | + loader = AutoWeightsLoader(self, skip_prefixes=["drop"]) |
| 434 | + loaded_params = loader.load_weights(weights) |
| 435 | + return loaded_params |
| 436 | + |
| 437 | + def forward( |
| 438 | + self, |
| 439 | + input_ids: Optional[torch.Tensor], |
| 440 | + positions: torch.Tensor, |
| 441 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 442 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 443 | + ) -> torch.Tensor: |
| 444 | + hidden_states = self.model( |
| 445 | + input_ids=input_ids, |
| 446 | + positions=positions, |
| 447 | + inputs_embeds=inputs_embeds, |
| 448 | + intermediate_tensors=intermediate_tensors, |
| 449 | + ) |
| 450 | + hidden_states = self.head(hidden_states) |
| 451 | + hidden_states = hidden_states.to(self.head_dtype) |
| 452 | + return self.classifier(hidden_states) |
0 commit comments