Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import glob
import inspect
import logging
import math
import os
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, cast
Expand All @@ -22,7 +20,6 @@
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from PIL import Image
from safetensors.torch import load_file
from torch import nn
from torchvision import transforms
from transformers import PretrainedConfig, Siglip2ImageProcessorFast
Expand Down Expand Up @@ -357,14 +354,6 @@ def build_batch_2d_rope(
return stacked_cos, stacked_sin


def get_full_state_dict(model_path):
    """Load and merge all ``*.safetensors`` shards under *model_path*.

    Each shard is read with ``safetensors.torch.load_file`` and folded into a
    single dict; when shards share a key, the shard iterated later wins
    (plain ``dict.update`` semantics).
    """
    merged = {}
    for shard_path in glob.glob(os.path.join(model_path, "*.safetensors")):
        merged.update(load_file(shard_path))
    return merged


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
Expand Down Expand Up @@ -1744,7 +1733,6 @@ def __init__(self, config: HunyuanImage3Config, prefix: str = ""):
lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.wte = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
Expand Down Expand Up @@ -2047,7 +2035,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
inputs_embeds = self.embed_tokens(input_ids)

# embed positions
hidden_states = inputs_embeds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
UNetDown,
UNetUp,
build_batch_2d_rope,
get_full_state_dict,
real_batched_index_select,
)

Expand Down Expand Up @@ -113,10 +112,10 @@ def __init__(self, od_config: OmniDiffusionConfig) -> None:
self.vllm_config = get_current_vllm_config()
self.post_init()

def pre_load(self):
tp_rank = get_tensor_model_parallel_rank()
state_dict = get_full_state_dict(self.od_config.model)
non_layer_prefixes = [
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
# List of unexpected keywords in weight names
non_model_layer_prefixes = [
"vae",
"vision_model",
"vision_aligner",
Expand All @@ -129,32 +128,15 @@ def pre_load(self):
"time_embed_2",
"final_layer.model",
]
filtered_sd = {}
for k, v in state_dict.items():
if any(k.startswith(prefix) for prefix in non_layer_prefixes):
filtered_sd[k] = v

missing, unexpected = self.load_state_dict(filtered_sd, strict=False)

for prefix in non_layer_prefixes:
if hasattr(self, prefix.split(".")[0]):
module = dict(self.named_modules()).get(prefix)
if module:
module.to(f"{self.model.device.type}:{tp_rank}")
tp_rank = get_tensor_model_parallel_rank()
device_str = f"{self.model.device.type}:{tp_rank}"
named_modules = dict(self.named_modules())
for prefix in non_model_layer_prefixes:
mod = named_modules.get(prefix)
if mod:
mod.to(device_str)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
self.pre_load()
skip_prefixes = ["lm_head."] if self.hf_config.tie_word_embeddings else []
# List of unexpected keywords in weight names
unexpected_keywords = [
"vae",
"vision_aligner",
"vision_model",
"final_layer",
"patch_embed",
"timestep_emb",
"time_embed",
"time_embed_2",
"guidance_emb",
"timestep_r_emb",
]
Expand Down Expand Up @@ -892,7 +874,7 @@ def forward_call(

custom_pos_emb = self.get_pos_emb(custom_pos_emb, position_ids)

inputs_embeds = self.model.wte(input_ids)
inputs_embeds = self.model.embed_tokens(input_ids)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid calling embed_tokens on non-first PP ranks

forward_call now unconditionally does self.model.embed_tokens(input_ids), but HunyuanImage3Model.__init__ only creates embed_tokens on the first PP rank (or last when tied embeddings); other pipeline-parallel ranks get PPMissingLayer. With pipeline_parallel_size > 1 and default tie_word_embeddings=False, this change makes non-first ranks invoke a missing layer and fail during inference, whereas the previous self.model.wte path existed on every rank.

Useful? React with 👍 / 👎.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid calling embed_tokens on non-first PP ranks

forward_call now unconditionally does self.model.embed_tokens(input_ids), but HunyuanImage3Model.__init__ only creates embed_tokens on the first PP rank (or last when tied embeddings); other pipeline-parallel ranks get PPMissingLayer. With pipeline_parallel_size > 1 and default tie_word_embeddings=False, this change makes non-first ranks invoke a missing layer and fail during inference, whereas the previous self.model.wte path existed on every rank.

Useful? React with 👍 / 👎.

The current model does not support PP, so PP.is_first_rank is always true. No need to check for now.


bsz, seq_len, n_embd = inputs_embeds.shape

Expand Down