Skip to content

Commit 883a04c

Browse files
committed
[BugFix] Fix load_weights error when loading HunyuanImage3.0
Move the submodule weight-loading code of HunyuanImage3Pipeline into AutoWeightsLoader.load_weights, fixing a weights-not-initialized error. Signed-off-by: Semmer2 <semmer@live.cn>
1 parent fec0182 commit 883a04c

File tree

2 files changed

+13
-40
lines changed

2 files changed

+13
-40
lines changed

vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -357,14 +357,6 @@ def build_batch_2d_rope(
357357
return stacked_cos, stacked_sin
358358

359359

360-
def get_full_state_dict(model_path):
361-
files = glob.glob(os.path.join(model_path, "*.safetensors"))
362-
full_sd = {}
363-
for f in files:
364-
full_sd.update(load_file(f))
365-
return full_sd
366-
367-
368360
def rotate_half(x):
369361
"""Rotates half the hidden dims of the input."""
370362
x1 = x[..., : x.shape[-1] // 2]
@@ -1744,7 +1736,6 @@ def __init__(self, config: HunyuanImage3Config, prefix: str = ""):
17441736
lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0
17451737
self.vocab_size = config.vocab_size + lora_vocab
17461738
self.org_vocab_size = config.vocab_size
1747-
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
17481739
if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank):
17491740
self.embed_tokens = VocabParallelEmbedding(
17501741
self.vocab_size,
@@ -2047,7 +2038,7 @@ def forward(
20472038
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
20482039

20492040
if inputs_embeds is None:
2050-
inputs_embeds = self.wte(input_ids)
2041+
inputs_embeds = self.embed_tokens(input_ids)
20512042

20522043
# embed positions
20532044
hidden_states = inputs_embeds

vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
UNetDown,
3838
UNetUp,
3939
build_batch_2d_rope,
40-
get_full_state_dict,
4140
real_batched_index_select,
4241
)
4342

@@ -113,10 +112,10 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None:
113112
self.vllm_config = get_current_vllm_config()
114113
self.post_init()
115114

116-
def pre_load(self):
117-
tp_rank = get_tensor_model_parallel_rank()
118-
state_dict = get_full_state_dict(self.od_config.model)
119-
non_layer_prefixes = [
115+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
116+
skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
117+
# List of unexpected keywords in weight names
118+
non_model_layer_prefixes = [
120119
"vae",
121120
"vision_model",
122121
"vision_aligner",
@@ -129,32 +128,15 @@ def pre_load(self):
129128
"time_embed_2",
130129
"final_layer.model",
131130
]
132-
filtered_sd = {}
133-
for k, v in state_dict.items():
134-
if any(k.startswith(prefix) for prefix in non_layer_prefixes):
135-
filtered_sd[k] = v
136-
137-
missing, unexpected = self.load_state_dict(filtered_sd, strict=False)
138-
139-
for prefix in non_layer_prefixes:
140-
if hasattr(self, prefix.split(".")[0]):
141-
module = dict(self.named_modules()).get(prefix)
142-
if module:
143-
module.to(f"{self.model.device.type}:{tp_rank}")
131+
tp_rank = get_tensor_model_parallel_rank()
132+
device_str = f"{self.model.device.type}:{tp_rank}"
133+
named_modules = dict(self.named_modules())
134+
for prefix in non_model_layer_prefixes:
135+
mod = named_modules.get(prefix)
136+
if mod:
137+
mod.to(device_str)
144138

145-
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
146-
self.pre_load()
147-
skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
148-
# List of unexpected keywords in weight names
149139
unexpected_keywords = [
150-
"vae",
151-
"vision_aligner",
152-
"vision_model",
153-
"final_layer",
154-
"patch_embed",
155-
"timestep_emb",
156-
"time_embed",
157-
"time_embed_2",
158140
"guidance_emb",
159141
"timestep_r_emb",
160142
]
@@ -892,7 +874,7 @@ def forward_call(
892874

893875
custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
894876

895-
inputs_embeds = self.model.wte(input_ids)
877+
inputs_embeds = self.model.embed_tokens(input_ids)
896878

897879
bsz, seq_len, n_embd = inputs_embeds.shape
898880

0 commit comments

Comments (0)