@@ -61,6 +61,19 @@
     make_layers)


+def _is_moe(config: PretrainedConfig) -> bool:
+    num_experts = getattr(config, "num_experts", None)
+    if isinstance(num_experts, int):
+        return num_experts > 1
+    if isinstance(num_experts, list) and num_experts:
+        # Ensure all elements are integers before calling max.
+        if all(isinstance(e, int) for e in num_experts):
+            return max(num_experts) > 1
+        else:
+            return False
+    return False
+
+
 def _get_cla_factor(config: PretrainedConfig) -> int:
     if not getattr(config, "use_cla", False):
         return 1
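A quick way to sanity-check `_is_moe` against the config shapes it handles: a scalar `num_experts`, a per-layer list, and the attribute missing entirely. The `SimpleNamespace` objects below are hypothetical stand-ins for `PretrainedConfig`:

from types import SimpleNamespace

# Scalar expert count: MoE iff more than one expert.
assert _is_moe(SimpleNamespace(num_experts=16)) is True
assert _is_moe(SimpleNamespace(num_experts=1)) is False
# Per-layer expert counts: MoE iff any layer has more than one expert.
assert _is_moe(SimpleNamespace(num_experts=[1, 8, 8])) is True
# Attribute absent: treated as dense.
assert _is_moe(SimpleNamespace()) is False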
|
@@ -140,8 +153,8 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        if hasattr(config, "head_dim"):
+
+        if hasattr(config, "head_dim") and config.head_dim:
             self.head_dim = config.head_dim
         elif hasattr(config, "attention_head_dim"):
             self.head_dim = config.attention_head_dim
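The substantive change here is the added truthiness check: a config that defines `head_dim` but leaves it `None` (or 0) previously matched the first branch and set a null head dimension, whereas it now falls through to the alternatives. A minimal sketch with a hypothetical config object:

from types import SimpleNamespace

cfg = SimpleNamespace(head_dim=None, attention_head_dim=128)
# Old guard: hasattr(cfg, "head_dim") is True, so head_dim became None.
# New guard: cfg.head_dim is falsy, so the elif branch is taken instead.
if hasattr(cfg, "head_dim") and cfg.head_dim:
    head_dim = cfg.head_dim
elif hasattr(cfg, "attention_head_dim"):
    head_dim = cfg.attention_head_dim
assert head_dim == 128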
|
@@ -490,12 +503,23 @@ def __init__(
         else:
             raise RuntimeError(f"Unsupported attention type: {attention_type}")

-        self.mlp = HunYuanSparseMoeBlock(
-            config=config,
-            quant_config=quant_config,
-            layer_id=layer_id,
-            prefix=f"{prefix}.mlp",
-        )
+        if _is_moe(config):
+            self.mlp = HunYuanSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+                prefix=f"{prefix}.mlp",
+            )
+        else:
+            self.mlp = HunYuanMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=self.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                bias=getattr(config, "mlp_bias", False),
+                prefix=f"{prefix}.mlp",
+            )
+
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
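With this dispatch, one decoder-layer class serves both model families: the checkpoint's config alone decides whether the feed-forward block is a sparse MoE or a plain gated MLP. A reduced sketch of the selection (stand-in string values keep it self-contained; the real code constructs the modules shown in the diff):

from types import SimpleNamespace

def select_mlp(config):
    # Mirrors the constructor's branch: sparse MoE block for MoE
    # configs, dense HunYuanMLP otherwise.
    return "HunYuanSparseMoeBlock" if _is_moe(config) else "HunYuanMLP"

assert select_mlp(SimpleNamespace(num_experts=64)) == "HunYuanSparseMoeBlock"
assert select_mlp(SimpleNamespace()) == "HunYuanMLP"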
|
@@ -642,15 +666,17 @@ def _split_qkv_weight(self, qkv: torch.Tensor):
         return torch.concat((q, k, v))

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts,
-        )
+        if _is_moe(self.config):
+            # Params for weights, fp8 weight scales, fp8 activation scales
+            # (param_name, weight_name, expert_id, shard_id)
+            return FusedMoE.make_expert_params_mapping(
+                ckpt_gate_proj_name="gate_proj",
+                ckpt_down_proj_name="down_proj",
+                ckpt_up_proj_name="up_proj",
+                num_experts=self.config.num_experts,
+            )
+        else:
+            return []

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         cla_factor = _get_cla_factor(self.config)
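Returning an empty list for dense configs keeps callers uniform: code that iterates the expert mapping simply does nothing for dense checkpoints, with no MoE-specific branching needed. A hypothetical consumer, for illustration only:

def route_expert_weights(expert_mapping):
    # With an empty mapping this loop is a no-op, so dense and MoE
    # models can share the same weight-loading path.
    for param_name, weight_name, expert_id, shard_id in expert_mapping:
        print(f"{weight_name} -> {param_name} (expert {expert_id}, shard {shard_id})")

route_expert_weights([])  # dense model: nothing to route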
|
@@ -815,7 +841,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loaded_params


-class HunYuanMoEV1ForCausalLM(nn.Module, SupportsLoRA):
+class HunYuanV1Base(nn.Module, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -901,3 +927,11 @@ def load_weights(self, weights: Iterable[tuple[str,

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
+
+
+class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
+    pass
+
+
+class HunYuanMoEV1ForCausalLM(HunYuanV1Base):
+    pass
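Factoring the shared implementation into `HunYuanV1Base` lets dense and MoE checkpoints register under separate architecture names while reusing one body. Presumably the model registry then maps both names to this module, along these lines (the entries are illustrative, not the actual registry contents):

# Hypothetical registry entries in the style of vLLM's model registry:
# architecture name -> (module name, class name).
_MODELS = {
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
}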