Skip to content

Commit 2671334

Browse files
authored
[Model] add Hunyuan V1 Dense Model support. (#21368)
Signed-off-by: Asher Zhang <[email protected]>
1 parent 2cc5016 commit 2671334

File tree

4 files changed

+57
-19
lines changed

4 files changed

+57
-19
lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ th {
363363
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
364364
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
365365
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
366+
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ |
366367
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ |
367368
| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
368369
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def check_available_online(
199199
trust_remote_code=True),
200200
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct",
201201
trust_remote_code=True),
202+
"HunYuanDenseV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124",
203+
trust_remote_code=True),
202204
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
203205
trust_remote_code=True),
204206
"InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",

vllm/model_executor/models/hunyuan_v1_moe.py renamed to vllm/model_executor/models/hunyuan_v1.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@
6161
make_layers)
6262

6363

64+
def _is_moe(config: PretrainedConfig) -> bool:
65+
num_experts = getattr(config, "num_experts", None)
66+
if isinstance(num_experts, int):
67+
return num_experts > 1
68+
if isinstance(num_experts, list) and num_experts:
69+
# Ensure all elements are integers before calling max.
70+
if all(isinstance(e, int) for e in num_experts):
71+
return max(num_experts) > 1
72+
else:
73+
return False
74+
return False
75+
76+
6477
def _get_cla_factor(config: PretrainedConfig) -> int:
6578
if not getattr(config, "use_cla", False):
6679
return 1
@@ -140,8 +153,8 @@ def __init__(
140153
# the KV heads across multiple tensor parallel GPUs.
141154
assert tp_size % self.total_num_kv_heads == 0
142155
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
143-
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
144-
if hasattr(config, "head_dim"):
156+
157+
if hasattr(config, "head_dim") and config.head_dim:
145158
self.head_dim = config.head_dim
146159
elif hasattr(config, "attention_head_dim"):
147160
self.head_dim = config.attention_head_dim
@@ -490,12 +503,23 @@ def __init__(
490503
else:
491504
raise RuntimeError(f"Unsupported attention type: {attention_type}")
492505

493-
self.mlp = HunYuanSparseMoeBlock(
494-
config=config,
495-
quant_config=quant_config,
496-
layer_id=layer_id,
497-
prefix=f"{prefix}.mlp",
498-
)
506+
if _is_moe(config):
507+
self.mlp = HunYuanSparseMoeBlock(
508+
config=config,
509+
quant_config=quant_config,
510+
layer_id=layer_id,
511+
prefix=f"{prefix}.mlp",
512+
)
513+
else:
514+
self.mlp = HunYuanMLP(
515+
hidden_size=self.hidden_size,
516+
intermediate_size=self.intermediate_size,
517+
hidden_act=config.hidden_act,
518+
quant_config=quant_config,
519+
bias=getattr(config, "mlp_bias", False),
520+
prefix=f"{prefix}.mlp",
521+
)
522+
499523
self.input_layernorm = RMSNorm(config.hidden_size,
500524
eps=config.rms_norm_eps)
501525
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -642,15 +666,17 @@ def _split_qkv_weight(self, qkv: torch.Tensor):
642666
return torch.concat((q, k, v))
643667

644668
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
645-
646-
# Params for weights, fp8 weight scales, fp8 activation scales
647-
# (param_name, weight_name, expert_id, shard_id)
648-
return FusedMoE.make_expert_params_mapping(
649-
ckpt_gate_proj_name="gate_proj",
650-
ckpt_down_proj_name="down_proj",
651-
ckpt_up_proj_name="up_proj",
652-
num_experts=self.config.num_experts,
653-
)
669+
if _is_moe(self.config):
670+
# Params for weights, fp8 weight scales, fp8 activation scales
671+
# (param_name, weight_name, expert_id, shard_id)
672+
return FusedMoE.make_expert_params_mapping(
673+
ckpt_gate_proj_name="gate_proj",
674+
ckpt_down_proj_name="down_proj",
675+
ckpt_up_proj_name="up_proj",
676+
num_experts=self.config.num_experts,
677+
)
678+
else:
679+
return []
654680

655681
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
656682
cla_factor = _get_cla_factor(self.config)
@@ -815,7 +841,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
815841
return loaded_params
816842

817843

818-
class HunYuanMoEV1ForCausalLM(nn.Module, SupportsLoRA):
844+
class HunYuanV1Base(nn.Module, SupportsLoRA):
819845
packed_modules_mapping = {
820846
"qkv_proj": [
821847
"q_proj",
@@ -901,3 +927,11 @@ def load_weights(self, weights: Iterable[tuple[str,
901927

902928
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
903929
return self.model.get_expert_mapping()
930+
931+
932+
class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
933+
pass
934+
935+
936+
class HunYuanMoEV1ForCausalLM(HunYuanV1Base):
937+
pass

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@
7979
"GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501
8080
"GritLM": ("gritlm", "GritLM"),
8181
"Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
82-
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
82+
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
83+
"HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
8384
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
8485
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
8586
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),

0 commit comments

Comments
 (0)