@@ -47,7 +47,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -321,6 +321,45 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                               SupportsQuant):
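A note on the mapping in the method added above: Baichuan checkpoints store `gate_proj` and `up_proj` as separate matrices, while the vLLM module fuses them into a single `gate_up_proj` parameter, so `stacked_params_mapping` rewrites each checkpoint name and tags the tensor with the shard it fills. The `for`/`else` construct means the plain one-to-one path runs only when no mapping entry matched (the loop never hit `break`). A minimal, self-contained sketch of the rename logic, with hypothetical weight names:

```python
# Standalone sketch of the stacked-parameter rename; weight names are made up.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def resolve(name: str):
    """Return (target param name, shard id) for a checkpoint weight name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            # e.g. "...mlp.up_proj.weight" fills shard 1 of the fused
            # "...mlp.gate_up_proj.weight" parameter.
            return name.replace(weight_name, param_name), shard_id
    # No entry matched: the weight loads one-to-one under its own name.
    return name, None

print(resolve("layers.0.mlp.up_proj.weight"))
# -> ('layers.0.mlp.gate_up_proj.weight', 1)
print(resolve("layers.0.input_layernorm.weight"))
# -> ('layers.0.input_layernorm.weight', None)
```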
@@ -353,6 +392,7 @@ def __init__(
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
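The single line added above relies on vLLM's loading convention: when a weight is copied in, the loader looks for a `weight_loader` attribute on the target parameter and falls back to `default_weight_loader` (a plain copy) when none is attached. A toy sketch of that dispatch with stand-in loaders, not vLLM's real classes:

```python
# Toy illustration of the per-parameter weight_loader convention.
import torch
import torch.nn as nn

def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    # Default behavior: copy the checkpoint tensor into the parameter.
    param.data.copy_(loaded_weight)

def normalizing_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    # A custom loader may transform the tensor first, as
    # lm_head_weight_loader does for Baichuan2 in the next hunk.
    default_weight_loader(param, torch.nn.functional.normalize(loaded_weight))

param = nn.Parameter(torch.empty(4, 8))
param.weight_loader = normalizing_loader  # attached in __init__, as above

# The loading machinery dispatches through the attribute when present:
loader = getattr(param, "weight_loader", default_weight_loader)
loader(param, torch.randn(4, 8))
print(param.data.norm(dim=1))  # every row now has unit L2 norm
```

Note the ordering in `__init__`: when `tie_word_embeddings` is set, `lm_head.weight` is rebound to the embedding parameter afterwards, so the hook attached here only takes effect for an untied head.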
@@ -393,53 +433,22 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if name == "lm_head.weight":
-                # Unlike Baichuan, Baichuan2 normalizes the head weights.
-                # Refer to:
-                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
-                # Distinguish between Baichuan and Baichuan2 by checking the
-                # vocab size. This is suggested by
-                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
-                is_baichuan2 = self.config.vocab_size == 125696
-                if is_baichuan2:
-                    loaded_weight = torch.nn.functional.normalize(
-                        loaded_weight)
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
+
+    def lm_head_weight_loader(self, param: nn.Parameter,
+                              loaded_weight: torch.Tensor):
+        # Unlike Baichuan, Baichuan2 normalizes the head weights.
+        # Refer to:
+        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+        # Distinguish between Baichuan and Baichuan2 by checking the
+        # vocab size. This is suggested by
+        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
+        is_baichuan2 = self.config.vocab_size == 125696
+        if is_baichuan2:
+            loaded_weight = torch.nn.functional.normalize(loaded_weight)
+
+        default_weight_loader(param, loaded_weight)
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
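With the model-level `load_weights` in place, the top-level loader shrinks to a delegation: `AutoWeightsLoader` routes each checkpoint weight to the submodule or parameter that owns it, which is why the Baichuan2 head normalization could move out of the old monolithic loop into the `lm_head_weight_loader` hook. As a rough mental model only — this assumes nothing about vLLM's actual `AutoWeightsLoader` beyond the two calls visible in this diff, and the class below is a made-up simplification:

```python
# Rough sketch of prefix-based delegation; NOT vLLM's AutoWeightsLoader.
from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn

class PrefixDelegatingLoader:

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        loaded: Set[str] = set()
        params = dict(self.module.named_parameters())
        for name, tensor in weights:
            prefix, _, rest = name.partition(".")
            child = getattr(self.module, prefix, None)
            if child is not None and hasattr(child, "load_weights"):
                # e.g. "model.*" tensors go to BaiChuanModel.load_weights.
                for sub in child.load_weights([(rest, tensor)]):
                    loaded.add(f"{prefix}.{sub}")
                continue
            param = params[name]
            # Per-parameter hooks (e.g. lm_head_weight_loader) take priority
            # over the plain copy.
            loader = getattr(param, "weight_loader",
                             lambda p, w: p.data.copy_(w))
            loader(param, tensor)
            loaded.add(name)
        return loaded
```

The normalization itself is unchanged by the move: `torch.nn.functional.normalize(loaded_weight)` L2-normalizes each row of the head matrix before the copy, and the `vocab_size == 125696` check still distinguishes Baichuan2 checkpoints from original Baichuan ones, as the inline comments and linked discussion explain.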