Skip to content

Commit 01276c1

Browse files
committed
[BugFix] Fix load_weights error when loading HunyuanImage3.0
Moved the submodule weight-loading code from HunyuanImage3Pipeline's pre_load into load_weights (invoked via AutoWeightsLoader.load_weights), fixing a "weights not initialized" error when loading HunyuanImage 3.0. Signed-off-by: Semmer2 <semmer@live.cn>
1 parent fec0182 commit 01276c1

File tree

2 files changed

+13
-31
lines changed

2 files changed

+13
-31
lines changed

vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,7 +1744,6 @@ def __init__(self, config: HunyuanImage3Config, prefix: str = ""):
17441744
lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0
17451745
self.vocab_size = config.vocab_size + lora_vocab
17461746
self.org_vocab_size = config.vocab_size
1747-
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
17481747
if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank):
17491748
self.embed_tokens = VocabParallelEmbedding(
17501749
self.vocab_size,
@@ -2047,7 +2046,7 @@ def forward(
20472046
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
20482047

20492048
if inputs_embeds is None:
2050-
inputs_embeds = self.wte(input_ids)
2049+
inputs_embeds = self.embed_tokens(input_ids)
20512050

20522051
# embed positions
20532052
hidden_states = inputs_embeds

vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None:
113113
self.vllm_config = get_current_vllm_config()
114114
self.post_init()
115115

116-
def pre_load(self):
117-
tp_rank = get_tensor_model_parallel_rank()
118-
state_dict = get_full_state_dict(self.od_config.model)
119-
non_layer_prefixes = [
116+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
117+
skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
118+
# List of unexpected keywords in weight names
119+
non_model_layer_prefixes = [
120120
"vae",
121121
"vision_model",
122122
"vision_aligner",
@@ -129,32 +129,15 @@ def pre_load(self):
129129
"time_embed_2",
130130
"final_layer.model",
131131
]
132-
filtered_sd = {}
133-
for k, v in state_dict.items():
134-
if any(k.startswith(prefix) for prefix in non_layer_prefixes):
135-
filtered_sd[k] = v
136-
137-
missing, unexpected = self.load_state_dict(filtered_sd, strict=False)
138-
139-
for prefix in non_layer_prefixes:
140-
if hasattr(self, prefix.split(".")[0]):
141-
module = dict(self.named_modules()).get(prefix)
142-
if module:
143-
module.to(f"{self.model.device.type}:{tp_rank}")
132+
tp_rank = get_tensor_model_parallel_rank()
133+
device_str = f"{self.model.device.type}:{tp_rank}"
134+
named_modules = dict(self.named_modules())
135+
for prefix in non_model_layer_prefixes:
136+
mod = named_modules.get(prefix)
137+
if mod:
138+
mod.to(device_str)
144139

145-
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
146-
self.pre_load()
147-
skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
148-
# List of unexpected keywords in weight names
149140
unexpected_keywords = [
150-
"vae",
151-
"vision_aligner",
152-
"vision_model",
153-
"final_layer",
154-
"patch_embed",
155-
"timestep_emb",
156-
"time_embed",
157-
"time_embed_2",
158141
"guidance_emb",
159142
"timestep_r_emb",
160143
]
@@ -892,7 +875,7 @@ def forward_call(
892875

893876
custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
894877

895-
inputs_embeds = self.model.wte(input_ids)
878+
inputs_embeds = self.model.embed_tokens(input_ids)
896879

897880
bsz, seq_len, n_embd = inputs_embeds.shape
898881

0 commit comments

Comments
 (0)