diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
index 04b961c9a4..0c2f9e290a 100644
--- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
+++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
@@ -1,11 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import glob
 import inspect
 import logging
 import math
-import os
 from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from typing import Any, cast
@@ -22,7 +20,6 @@
 from diffusers.utils.torch_utils import randn_tensor
 from einops import rearrange
 from PIL import Image
-from safetensors.torch import load_file
 from torch import nn
 from torchvision import transforms
 from transformers import PretrainedConfig, Siglip2ImageProcessorFast
@@ -357,14 +354,6 @@ def build_batch_2d_rope(
     return stacked_cos, stacked_sin
 
 
-def get_full_state_dict(model_path):
-    files = glob.glob(os.path.join(model_path, "*.safetensors"))
-    full_sd = {}
-    for f in files:
-        full_sd.update(load_file(f))
-    return full_sd
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -1744,7 +1733,6 @@ def __init__(self, config: HunyuanImage3Config, prefix: str = ""):
         lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank):
             self.embed_tokens = VocabParallelEmbedding(
                 self.vocab_size,
@@ -2047,7 +2035,7 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids)
 
         # embed positions
         hidden_states = inputs_embeds
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
index 9721f2edbc..e4b717a697 100644
--- a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
+++ b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
@@ -37,7 +37,6 @@
     UNetDown,
     UNetUp,
     build_batch_2d_rope,
-    get_full_state_dict,
     real_batched_index_select,
 )
 
@@ -113,10 +112,10 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None:
         self.vllm_config = get_current_vllm_config()
         self.post_init()
 
-    def pre_load(self):
-        tp_rank = get_tensor_model_parallel_rank()
-        state_dict = get_full_state_dict(self.od_config.model)
-        non_layer_prefixes = [
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
+        # List of unexpected keywords in weight names
+        non_model_layer_prefixes = [
             "vae",
             "vision_model",
             "vision_aligner",
@@ -129,32 +128,15 @@ def pre_load(self):
             "time_embed_2",
             "final_layer.model",
         ]
-        filtered_sd = {}
-        for k, v in state_dict.items():
-            if any(k.startswith(prefix) for prefix in non_layer_prefixes):
-                filtered_sd[k] = v
-
-        missing, unexpected = self.load_state_dict(filtered_sd, strict=False)
-
-        for prefix in non_layer_prefixes:
-            if hasattr(self, prefix.split(".")[0]):
-                module = dict(self.named_modules()).get(prefix)
-                if module:
-                    module.to(f"{self.model.device.type}:{tp_rank}")
+        tp_rank = get_tensor_model_parallel_rank()
+        device_str = f"{self.model.device.type}:{tp_rank}"
+        named_modules = dict(self.named_modules())
+        for prefix in non_model_layer_prefixes:
+            mod = named_modules.get(prefix)
+            if mod:
+                mod.to(device_str)
 
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        self.pre_load()
-        skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
-        # List of unexpected keywords in weight names
         unexpected_keywords = [
-            "vae",
-            "vision_aligner",
-            "vision_model",
-            "final_layer",
-            "patch_embed",
-            "timestep_emb",
-            "time_embed",
-            "time_embed_2",
             "guidance_emb",
             "timestep_r_emb",
         ]
@@ -892,7 +874,7 @@ def forward_call(
 
         custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)
 
-        inputs_embeds = self.model.wte(input_ids)
+        inputs_embeds = self.model.embed_tokens(input_ids)
 
         bsz, seq_len, n_embd = inputs_embeds.shape
 