Skip to content

Commit 1ca198e

Browse files
authored
[BugFix] Fix load_weights error when loading HunyuanImage3.0 (#1598)
Signed-off-by: Semmer2 <semmer@live.cn>
1 parent e37a89f commit 1ca198e

File tree

2 files changed

+13
-43
lines changed

2 files changed

+13
-43
lines changed

vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
import glob
54
import inspect
65
import logging
76
import math
8-
import os
97
from collections.abc import Callable, Iterable
108
from dataclasses import dataclass
119
from typing import Any, cast
@@ -22,7 +20,6 @@
2220
from diffusers.utils.torch_utils import randn_tensor
2321
from einops import rearrange
2422
from PIL import Image
25-
from safetensors.torch import load_file
2623
from torch import nn
2724
from torchvision import transforms
2825
from transformers import PretrainedConfig, Siglip2ImageProcessorFast
@@ -357,14 +354,6 @@ def build_batch_2d_rope(
357354
return stacked_cos, stacked_sin
358355

359356

360-
def get_full_state_dict(model_path):
361-
files = glob.glob(os.path.join(model_path, "*.safetensors"))
362-
full_sd = {}
363-
for f in files:
364-
full_sd.update(load_file(f))
365-
return full_sd
366-
367-
368357
def rotate_half(x):
369358
"""Rotates half the hidden dims of the input."""
370359
x1 = x[..., : x.shape[-1] // 2]
@@ -1744,7 +1733,6 @@ def __init__(self, config: HunyuanImage3Config, prefix: str = ""):
17441733
lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0
17451734
self.vocab_size = config.vocab_size + lora_vocab
17461735
self.org_vocab_size = config.vocab_size
1747-
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
17481736
if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank):
17491737
self.embed_tokens = VocabParallelEmbedding(
17501738
self.vocab_size,
@@ -2047,7 +2035,7 @@ def forward(
20472035
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
20482036

20492037
if inputs_embeds is None:
2050-
inputs_embeds = self.wte(input_ids)
2038+
inputs_embeds = self.embed_tokens(input_ids)
20512039

20522040
# embed positions
20532041
hidden_states = inputs_embeds

vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py

Lines changed: 12 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@
3737
UNetDown,
3838
UNetUp,
3939
build_batch_2d_rope,
40-
get_full_state_dict,
4140
real_batched_index_select,
4241
)
4342

@@ -113,10 +112,10 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None:
113112
self.vllm_config = get_current_vllm_config()
114113
self.post_init()
115114

116-
def pre_load(self):
117-
tp_rank = get_tensor_model_parallel_rank()
118-
state_dict = get_full_state_dict(self.od_config.model)
119-
non_layer_prefixes = [
115+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
116+
skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
117+
# List of unexpected keywords in weight names
118+
non_model_layer_prefixes = [
120119
"vae",
121120
"vision_model",
122121
"vision_aligner",
@@ -129,32 +128,15 @@ def pre_load(self):
129128
"time_embed_2",
130129
"final_layer.model",
131130
]
132-
filtered_sd = {}
133-
for k, v in state_dict.items():
134-
if any(k.startswith(prefix) for prefix in non_layer_prefixes):
135-
filtered_sd[k] = v
136-
137-
missing, unexpected = self.load_state_dict(filtered_sd, strict=False)
138-
139-
for prefix in non_layer_prefixes:
140-
if hasattr(self, prefix.split(".")[0]):
141-
module = dict(self.named_modules()).get(prefix)
142-
if module:
143-
module.to(f"{self.model.device.type}:{tp_rank}")
131+
tp_rank = get_tensor_model_parallel_rank()
132+
device_str = f"{self.model.device.type}:{tp_rank}"
133+
named_modules = dict(self.named_modules())
134+
for prefix in non_model_layer_prefixes:
135+
mod = named_modules.get(prefix)
136+
if mod:
137+
mod.to(device_str)
144138

145-
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
146-
self.pre_load()
147-
skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
148-
# List of unexpected keywords in weight names
149139
unexpected_keywords = [
150-
"vae",
151-
"vision_aligner",
152-
"vision_model",
153-
"final_layer",
154-
"patch_embed",
155-
"timestep_emb",
156-
"time_embed",
157-
"time_embed_2",
158140
"guidance_emb",
159141
"timestep_r_emb",
160142
]
@@ -892,7 +874,7 @@ def forward_call(
892874

893875
custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
894876

895-
inputs_embeds = self.model.wte(input_ids)
877+
inputs_embeds = self.model.embed_tokens(input_ids)
896878

897879
bsz, seq_len, n_embd = inputs_embeds.shape
898880

0 commit comments

Comments
 (0)