 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -448,6 +448,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                        (lora_config.max_loras or 1)) if lora_config else 0)
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
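+        # Keep handles for load_weights, which now lives on PhiMoEModel
+        # (e.g. quant_config is consulted for KV-cache quantization scales).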
+        self.config = config
+        self.quant_config = quant_config

         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
@@ -504,85 +506,6 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states

-
-class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
-    fall_back_to_pt_during_load = False
-
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        lora_config = vllm_config.lora_config
-        self.config = config
-        self.lora_config = lora_config
-        self.quant_config = vllm_config.quant_config
-
-        self.model = PhiMoEModel(vllm_config=vllm_config,
-                                 prefix=maybe_prefix(prefix, "model"))
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=(
-                DEFAULT_VOCAB_PADDING_SIZE
-                # We need bigger padding if using lora for kernel
-                # compatibility
-                if not lora_config else lora_config.lora_vocab_padding_size),
-            quant_config=None,
-            bias=True,
-        )
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size)
-        self.sampler = get_sampler()
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -601,9 +524,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if (self.quant_config is not None and
                     (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache quantization scales
@@ -667,3 +587,90 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         return loaded_params
+
+
+class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    fall_back_to_pt_during_load = False
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        lora_config = vllm_config.lora_config
+        self.config = config
+        self.lora_config = lora_config
+        self.quant_config = vllm_config.quant_config
+
+        self.model = PhiMoEModel(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=(
+                DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size),
+            quant_config=None,
+            bias=True,
+        )
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
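+        # AutoWeightsLoader delegates checkpoint loading to the submodules
+        # (here PhiMoEModel.load_weights); "rotary_emb.inv_freq" buffers are
+        # skipped because they are recomputed at init rather than loaded.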
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=["rotary_emb.inv_freq"],
+        )
+        return loader.load_weights(weights)