@@ -53,7 +53,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -668,6 +668,73 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+
+class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = DeepseekV2Model(vllm_config=vllm_config,
+                                     prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -687,6 +754,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
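The three `+` lines above reproduce, inline, the `skip_prefixes=["rotary_emb.inv_freq"]` filtering that the deleted `AutoWeightsLoader` call provided (see the removed class in the next hunk): `inv_freq` is a rotary-embedding buffer recomputed from the config at construction time, so a checkpoint that serializes it has no matching entry in `params_dict`. A minimal standalone sketch of the same pattern (hypothetical helper, not part of this change):

def skip_inv_freq(weights):
    # weights: iterable of (name, tensor) pairs streamed from a checkpoint
    for name, tensor in weights:
        if "rotary_emb.inv_freq" in name:
            continue  # recomputed at runtime; never a registered parameter
        yield name, tensor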
@@ -754,78 +824,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = DeepseekV2Model(vllm_config=vllm_config,
-                                     prefix=maybe_prefix(prefix, "model"))
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
-        return loader.load_weights(weights)
-
-
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
 
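Net effect of the diff: `DeepseekV2ForCausalLM` now implements `load_weights` directly (keeping the `stacked_params_mapping` logic) rather than delegating to `AutoWeightsLoader`, and `DeepseekV3ForCausalLM` inherits the same path unchanged. A quick smoke test of the loading path through the public API, assuming an available checkpoint (the model name is illustrative, not taken from this diff):

from vllm import LLM, SamplingParams

# Any DeepSeek-V2/V3 checkpoint exercises the load_weights path patched above.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)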