from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
+from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
@@ -486,10 +487,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        # Input embeddings
        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
+            names = ("embedding_size", "hidden_size")
+            embedding_dim = getattr_iter(self.text_config, names, None)
+            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
-                    self.text_config.hidden_size,
+                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                ))
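
For context, `getattr_iter` (imported above from `vllm.config.utils`) is used here to prefer `embedding_size` over `hidden_size` when sizing the input embeddings. A minimal sketch of the assumed helper semantics, not the actual vLLM implementation:

def getattr_iter(obj, names, default):
    # Assumed behaviour: return the first attribute in `names` that exists
    # on `obj`, falling back to `default` when none are present.
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default

# With ("embedding_size", "hidden_size"), configs that define a separate
# embedding width use it; all other configs fall back to hidden_size.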
@@ -645,7 +649,9 @@ def create_attention_instances(
                attn_type=attn_type)
        return attention_instances

-    def init_parameters(self, module: nn.Module):
+    def init_parameters(self,
+                        module: nn.Module,
+                        dtype: Optional[torch.dtype] = None):
        """
        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:
@@ -659,11 +665,11 @@ def init_parameters(self, module: nn.Module):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data,
-                                     dtype=self.model_config.dtype,
+                                     dtype=dtype or self.model_config.dtype,
                                     device=self.device_config.device))
                setattr(module, name, new_param)
        for child in module.children():
-            self.init_parameters(child)
+            self.init_parameters(child, dtype)

    def forward(
        self,
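
The new optional `dtype` argument lets a caller materialise the meta-device parameters of one submodule in a different dtype from the rest of the model. A hypothetical usage sketch (the `vision_tower` attribute and the float32 choice are illustrative, not part of this diff):

# Hypothetical caller inside a TransformersBase-style model wrapper.
# Only parameters still on the meta device are touched, so the per-module
# override must run before the model-wide pass.
vision_tower = getattr(self.model, "vision_tower", None)  # assumed attribute
if vision_tower is not None:
    self.init_parameters(vision_tower, dtype=torch.float32)
self.init_parameters(self.model)  # remaining params use self.model_config.dtype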
@@ -712,73 +718,6 @@ def load_weights(self, weights: Iterable[tuple[str,
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


-@support_torch_compile(enable_if=can_enable_torch_compile)
-class TransformersModel(TransformersBase):
-    hf_to_vllm_mapper = WeightsMapper(
-        orig_to_new_prefix={
-            # Handle BERT-like models
-            "bert": "model",
-            # Add `model.` prefix for base model checkpoints
-            "": "model.",
-            # Remove `model.` prefix if it was already there
-            "model.model.": "model.",
-            # Pooling adapters will be adjacent to `model`
-            "model.pooler": "pooler",
-            "model.score": "score",
-            # Classifier adapter's classifier layer is renamed to score
-            "model.classifier": "score",
-        },
-        orig_to_new_suffix={
-            # Replace legacy suffixes used for norms
-            ".gamma": ".weight",
-            ".beta": ".bias",
-        })
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-
-        # After creating a pooling model, `pooler` will be duplicated.
-        # The one inside `model` comes from the Transformers modelling code.
-        # The one after `model` is an adapter from vLLM.
-        # We want to use the adapter so we nullify the original pooler.
-        if getattr(self.model, "pooler", None) is not None:
-            self.skip_prefixes.append("pooler.")
-            self.model.pooler = torch.nn.Identity()
-
-        # Some encoder models have the position_ids buffer in the checkpoint.
-        # vLLM will always pass position_ids as an argument, so we skip loading
-        # the buffer if it exists
-        self.skip_substrs.append("position_ids")
-
-        # Some encoder models have the bias of the final classifier layer
-        # in the checkpoint. vLLM does not use this bias, so we skip loading
-        # it if it exists
-        self.skip_substrs.append("score.bias")
-
-    def create_attention_instances(
-            self, attn_type: AttentionType = AttentionType.DECODER):
-        # TODO(hmellor): Better way to detect encoder models
-        # In encoder models, the attention layers will have `is_causal=False`
-        is_encoder = lambda m: not getattr(m, "is_causal", True)
-        # vLLM does not support encoder-decoder models, so if any encoder layer
-        # is found, we assume the whole model is an encoder model
-        if any(is_encoder(m) for m in self.model.modules()):
-            attn_type = AttentionType.ENCODER_ONLY
-
-        # Check minimum transformers version for encoder models support
-        if attn_type == AttentionType.ENCODER_ONLY:
-            import transformers
-            from packaging.version import Version
-            installed = Version(transformers.__version__)
-            required = Version("4.57.0.dev0")
-            if installed < required:
-                raise ValueError(
-                    "Encoder models with the Transformers backend require "
-                    f"transformers>={required}, but got {installed}")
-
-        return super().create_attention_instances(attn_type)
-
-
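
For reviewers reading the removal above: the `orig_to_new_prefix`/`orig_to_new_suffix` dictionaries rewrite checkpoint weight names before they are loaded. A rough standalone approximation of a few of those rules (vLLM's `WeightsMapper` is the real implementation; this sketch only mirrors the commented intent):

def remap_name(name: str) -> str:
    # Handle BERT-like models / add the `model.` prefix for base checkpoints.
    if name.startswith("bert."):
        name = "model." + name[len("bert."):]
    else:
        name = "model." + name
    # Remove the duplicated `model.` prefix if the checkpoint already had one.
    if name.startswith("model.model."):
        name = name[len("model."):]
    # Replace legacy norm suffixes.
    for old, new in ((".gamma", ".weight"), (".beta", ".bias")):
        if name.endswith(old):
            name = name[:-len(old)] + new
    return name

# e.g. "bert.encoder.layer.0.output.LayerNorm.gamma"
#   -> "model.encoder.layer.0.output.LayerNorm.weight"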
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):