
Commit 905e91e
Revert "[Model] use AutoWeightsLoader for deepseek_v2, internlm2" (#16453)
1 parent f8f9c0b · commit 905e91e
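
In practical terms, the revert replaces the AutoWeightsLoader-based load_weights methods with the hand-written per-tensor loops shown in the diffs below. The sketch that follows is a stand-in written for this note, not vLLM's AutoWeightsLoader: PrefixSkippingLoader and its mock parameter dict are invented names, and it only illustrates the delegation shape that is being removed.

# Illustrative stand-in only; NOT vLLM code. The real AutoWeightsLoader applies
# skip_prefixes while recursing through submodules; this flat mock approximates
# that with a substring check just to show the filtering idea.
from typing import Dict, Iterable, List, Set, Tuple

class PrefixSkippingLoader:
    """Hypothetical helper mirroring the skip_prefixes delegation shape."""

    def __init__(self, params_dict: Dict[str, list],
                 skip_prefixes: List[str]) -> None:
        self.params_dict = params_dict
        self.skip_prefixes = skip_prefixes

    def load_weights(self, weights: Iterable[Tuple[str, list]]) -> Set[str]:
        loaded: Set[str] = set()
        for name, tensor in weights:
            if any(p in name for p in self.skip_prefixes):
                continue  # e.g. rotary_emb.inv_freq is never loaded
            self.params_dict[name] = tensor  # real code calls weight_loader
            loaded.add(name)
        return loaded

params: Dict[str, list] = {}
loader = PrefixSkippingLoader(params, skip_prefixes=["rotary_emb.inv_freq"])
print(loader.load_weights([("lm_head.weight", [0.1]),
                           ("rotary_emb.inv_freq", [0.2])]))
# -> {'lm_head.weight'}

The removed method bodies in both files consisted of exactly this delegation: construct the loader with skip_prefixes=["rotary_emb.inv_freq"] and return loader.load_weights(weights).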

File tree: 2 files changed, +107 -112 lines
  vllm/model_executor/models/deepseek_v2.py
  vllm/model_executor/models/internlm2.py

vllm/model_executor/models/deepseek_v2.py

Lines changed: 71 additions & 73 deletions
@@ -53,7 +53,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -668,6 +668,73 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+
+class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = DeepseekV2Model(vllm_config=vllm_config,
+                                     prefix=maybe_prefix(prefix, "model"))
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -687,6 +754,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
@@ -754,78 +824,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = DeepseekV2Model(vllm_config=vllm_config,
-                                     prefix=maybe_prefix(prefix, "model"))
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
-        else:
-            self.lm_head = PPMissingLayer()
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
-        return loader.load_weights(weights)
-
-
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass

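One consequence visible in the deepseek_v2.py hunks: with the AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) call gone, the rotary-cache skip is restated as an explicit guard at the top of the restored loop (the @@ -687,6 +754,9 @@ hunk). A tiny self-contained illustration of that guard, with made-up checkpoint names rather than real DeepSeek-V2 ones:

# Sketch only; names are hypothetical examples.
checkpoint_names = [
    "model.layers.0.self_attn.rotary_emb.inv_freq",
    "model.layers.0.mlp.gate_proj.weight",
    "lm_head.weight",
]

loadable = []
for name in checkpoint_names:
    if "rotary_emb.inv_freq" in name:
        # Restored guard: rotary inverse-frequency buffers are rebuilt from the
        # model config at init, so they are skipped rather than loaded.
        continue
    loadable.append(name)

print(loadable)  # ['model.layers.0.mlp.gate_proj.weight', 'lm_head.weight']
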
vllm/model_executor/models/internlm2.py

Lines changed: 36 additions & 39 deletions
@@ -32,7 +32,7 @@
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -306,42 +306,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "w1", 0),
-            ("gate_up_proj", "w3", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
-
 
 class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     packed_modules_mapping = {
@@ -409,8 +373,41 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
-        return loader.load_weights(weights)
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w1", 0),
+            ("gate_up_proj", "w3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class InternLM2ForRewardModel(InternLM2ForCausalLM):

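The restored InternLM2 load_weights above leans on two things worth calling out: the stacked_params_mapping that folds the checkpoint's w1/w3 shards into the fused gate_up_proj parameter, and Python's for/else, where the else branch runs only if the inner loop never hit break. The following is a self-contained sketch with mock parameter objects and hypothetical weight names, not vLLM code, that exercises the same control flow:

# Self-contained sketch of the hand-written loading pattern restored above.
from typing import Dict, Iterable, List, Set, Tuple

class MockParam:
    """Stand-in for a vLLM parameter exposing a weight_loader hook."""

    def __init__(self) -> None:
        self.shards: List[Tuple[object, list]] = []

    def weight_loader(self, param: "MockParam", tensor: list,
                      shard_id: object = None) -> None:
        self.shards.append((shard_id, tensor))

def load_weights(params_dict: Dict[str, MockParam],
                 weights: Iterable[Tuple[str, list]]) -> Set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id) -- mirrors the restored mapping
        ("gate_up_proj", "w1", 0),
        ("gate_up_proj", "w3", 1),
    ]
    loaded_params: Set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue  # rotary caches are rebuilt at init, never loaded
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No stacked mapping matched: plain one-to-one load.
            param = params_dict[name]
            param.weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params

params = {
    "model.layers.0.feed_forward.gate_up_proj": MockParam(),
    "model.layers.0.attention.wo": MockParam(),
}
checkpoint = [
    ("model.layers.0.feed_forward.w1", [1.0]),   # -> gate_up_proj shard 0
    ("model.layers.0.feed_forward.w3", [2.0]),   # -> gate_up_proj shard 1
    ("model.layers.0.attention.wo", [3.0]),      # plain load
    ("model.layers.0.attention.rotary_emb.inv_freq", [9.9]),  # skipped
]
print(sorted(load_weights(params, checkpoint)))
# ['model.layers.0.attention.wo', 'model.layers.0.feed_forward.gate_up_proj']

Running the sketch loads the two mock parameters (the fused one receiving both shards) and drops the rotary cache entry, which mirrors the loaded_params set the real method returns.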