Skip to content

Commit dd9e888

Browse files
committed
vllm_0.15.0
1 parent c57dc7f commit dd9e888

File tree

3 files changed

+288
-305
lines changed

3 files changed

+288
-305
lines changed

vllm_omni/diffusion/models/hunyuan/hunyuan_image3_utils.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,285 @@
88
import glob
import inspect
import math
import os
from typing import Any, Iterable, List, Optional, Tuple, Union, cast

import torch
from safetensors.torch import load_file
from torch import nn
from transformers import PretrainedConfig
16+
17+
def _is_moe(config: PretrainedConfig) -> bool:
    """Return True when the config describes a mixture-of-experts model.

    A model counts as MoE when ``num_experts`` is an int greater than 1, or
    a non-empty list of ints whose maximum exceeds 1.  Anything else
    (attribute missing, empty list, mixed element types) is treated as dense.
    """
    experts = getattr(config, "num_experts", None)
    if isinstance(experts, int):
        return experts > 1
    # Ensure all elements are integers before calling max.
    if (isinstance(experts, list) and experts
            and all(isinstance(item, int) for item in experts)):
        return max(experts) > 1
    return False
28+
29+
30+
def _get_cla_factor(config: PretrainedConfig) -> int:
    """Return the cross-layer-attention share factor; 1 when CLA is disabled."""
    if getattr(config, "use_cla", False):
        return getattr(config, "cla_share_factor", 1)
    return 1
34+
35+
36+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Call ``scheduler.set_timesteps`` and return the resulting schedule.

    Handles custom timestep or sigma schedules; any extra ``kwargs`` are
    forwarded to ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a
            pre-trained model. If used, `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`,
            the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. If
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. If
            passed, `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule held by the
        scheduler after the call, and the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or the
            scheduler's `set_timesteps` does not accept the requested custom
            schedule argument.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # The scheduler must explicitly accept a `timesteps` keyword.
        accepted = inspect.signature(scheduler.set_timesteps).parameters
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        schedule = scheduler.timesteps
        return schedule, len(schedule)

    if sigmas is not None:
        # The scheduler must explicitly accept a `sigmas` keyword.
        accepted = inspect.signature(scheduler.set_timesteps).parameters
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        schedule = scheduler.timesteps
        return schedule, len(schedule)

    # Default path: let the scheduler derive the schedule from the step count.
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
93+
94+
def real_batched_index_select(t, dim, idx):
    """``index_select`` applied per batch element.

    Both ``t`` and ``idx`` carry a leading batch dimension; ``dim`` counts
    axes of ``t`` including that batch axis, so each sample is selected
    along ``dim - 1``.
    """
    assert t.ndim >= 2 and idx.ndim >= 2, f"{t.ndim=} {idx.ndim=}"
    assert len(t) == len(idx), f"{len(t)=} != {len(idx)=}"
    selected = [
        torch.index_select(sample, dim - 1, sample_idx)
        for sample, sample_idx in zip(t, idx)
    ]
    return torch.stack(selected)
99+
100+
101+
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: spatial dimensionality (1, 2 or 3).
    :raises ValueError: for any other value of ``dims``.
    """
    for ndim, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == ndim:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
112+
11113

114+
def normalization(channels, **kwargs):
    """
    Make a standard normalization layer (GroupNorm with 32 groups).

    :param channels: number of input channels.
    :return: a nn.Module for normalization.
    """
    num_groups = 32
    return nn.GroupNorm(num_groups, channels, **kwargs)
122+
123+
124+
def linear(*args, **kwargs):
    """
    Create a linear module (thin alias around ``nn.Linear``).
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
129+
130+
131+
def zero_module(module):
    """
    Zero out the parameters of a module in place and return it.
    """
    # no_grad keeps the in-place zeroing out of autograd, same effect as
    # zeroing through a detached view.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
138+
139+
140+
def _to_tuple(x, dim=2):
141+
if isinstance(x, int):
142+
return (x,) * dim
143+
elif len(x) == dim:
144+
return x
145+
else:
146+
raise ValueError(f"Expected length {dim} or int, but got {x}")
147+
148+
149+
def get_meshgrid_nd(start, *args, dim=2):
    """Build a ``dim``-dimensional meshgrid stacked along a leading axis.

    Call patterns (each axis sampled like ``np.linspace(..., endpoint=False)``):
      * ``get_meshgrid_nd(size)``            -> grid over ``[0, size)``
      * ``get_meshgrid_nd(start, stop)``     -> unit-step grid over ``[start, stop)``
      * ``get_meshgrid_nd(start, stop, num)``-> ``num`` samples over ``[start, stop)``

    Returns a float32 tensor of shape ``[dim, *num]``.
    """
    def as_tuple(value):
        # Broadcast a bare int to every axis; otherwise require length `dim`.
        if isinstance(value, int):
            return (value,) * dim
        if len(value) == dim:
            return value
        raise ValueError(f"Expected length {dim} or int, but got {value}")

    if not args:
        # Single argument is the grid size; start from the origin.
        num = as_tuple(start)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # (start, stop) with an implicit step of 1 per axis.
        start = as_tuple(start)
        stop = as_tuple(args[0])
        span = [stop[i] - start[i] for i in range(dim)]
        rounded = [int(v) for v in span]
        # The per-axis extents must be whole numbers.
        assert (torch.tensor(span) == torch.tensor(rounded)).all(), f"num should be int, but got {span}"
        num = rounded
    elif len(args) == 2:
        # (start, stop, num): explicit sample count per axis.
        start = as_tuple(start)  # Left-Top eg: 12,0
        stop = as_tuple(args[0])  # Right-Bottom eg: 20,32
        num = as_tuple(args[1])  # Target Size eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False):
    # sample num+1 points and drop the endpoint.
    axes = [
        torch.linspace(start[i], stop[i], num[i] + 1, dtype=torch.float32)[: num[i]]
        for i in range(dim)
    ]
    mesh = torch.meshgrid(*axes, indexing="ij")  # dim x [H, W]
    return torch.stack(mesh, dim=0)  # [dim, H, W]
182+
183+
def build_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[Tuple[slice, Tuple[int, int]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    """Build 2D (axial) RoPE cos/sin tables for a mixed text/image sequence.

    Text tokens advance both the y and x position together (diagonal
    positions); each image section instead contributes an (h, w) grid of
    positions centred on its sequence offset, so rotary phases encode 2D
    layout within images.

    Args:
        seq_len: Total sequence length of the sample.
        n_elem: Rotary embedding dimension; must be divisible by 4 (half the
            channels rotate with y positions, half with x).
        image_infos: Per-image ``(slice, (h, w))`` entries, where the slice
            locates the image tokens in the sequence. ``h`` may be ``None``
            for special interleave token runs (see inline comments). ``None``
            means a pure-text sequence.
        device: Device for the returned tensors.
        base: RoPE frequency base.
        base_rescale_factor: NTK-style rescaling applied to ``base`` when
            not 1.0.
        return_all_pos: Also return the raw ``[seq_len, 1, 2]`` (y, x)
            position tensor.

    Returns:
        ``(cos, sin)`` of shape ``[seq_len, n_elem]``, plus ``all_pos`` when
        ``return_all_pos`` is True.
    """

    assert n_elem % 4 == 0, f"n_elem must be divisible by 4, but got {n_elem}."

    # theta: per-pair inverse frequencies, paired up as (y, x) channels.
    if base_rescale_factor != 1.0:
        base *= base_rescale_factor ** (n_elem / (n_elem - 2))
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    theta = theta.reshape(1, n_elem // 4, 2)  # [1, half_d, 2]

    # position indices
    if image_infos is None:
        image_infos = []

    # Single-sample path: wrap into one-element lists so the loop below is
    # written per-sample (batching is handled by build_batch_2d_rope).
    image_infos_list = [image_infos]
    sample_seq_lens = [seq_len]

    # Prepare position indices for each sample
    x_sections = []
    y_sections = []
    for sample_id, sample_image_infos in enumerate(image_infos_list):
        last_pos = 0
        for sec_slice, (h, w) in sample_image_infos:
            L = sec_slice.start  # start from 0, so image_slice.start is just L
            # previous text
            if last_pos < L:
                # Diagonal positions for the text run preceding this image.
                y_sections.append(torch.arange(last_pos, L))
                x_sections.append(torch.arange(last_pos, L))
            elif h is None:
                # Interleave data has overlapped positions for <boi> <size> <ratio> <timestep> <eoi> tokens.
                y_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                x_sections.append(torch.arange(sec_slice.start, sec_slice.stop))
                continue
            else:
                # Interleave data has overlapped positions for noised image and the successive clean image,
                # leading to last_pos (= last text end L + noise w * h) > L (last text end L).
                pass
            # NOTE(review): if a section with h=None arrives while last_pos < L,
            # the first branch runs and control falls through to `w * h` below
            # with h=None (TypeError). Presumably such sections always start at
            # last_pos — confirm against the callers that build image_infos.
            # current image: grid offsets chosen so the image grid is centred
            # around its sequence position L.
            beta_y = L + (w * h - h) / 2
            beta_x = L + (w * h - w) / 2
            grid = get_meshgrid_nd((beta_y, beta_x), (beta_y + h, beta_x + w))  # [2, h, w]
            grid = grid.reshape(2, -1)  # (y, x)
            y_sections.append(grid[0])
            x_sections.append(grid[1])
            # step
            last_pos = L + w * h
        # final text
        y_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))
        x_sections.append(torch.arange(last_pos, sample_seq_lens[sample_id]))

    x_pos = torch.cat(x_sections).long()
    y_pos = torch.cat(y_sections).long()
    # If there are overlap positions, we need to remove them.
    x_pos = x_pos[:seq_len]
    y_pos = y_pos[:seq_len]
    all_pos = torch.stack((y_pos, x_pos), dim=1).unsqueeze(1).to(device)  # [seq_len, 1, 2]

    # calc rope: broadcast positions against theta -> [seq_len, n_elem//4, 2],
    # flatten to [seq_len, n_elem//2], then duplicate for the rotate-half layout.
    idx_theta = (all_pos * theta).reshape(all_pos.shape[0], n_elem // 2).repeat(1, 2)

    cos = torch.cos(idx_theta)
    sin = torch.sin(idx_theta)

    if return_all_pos:
        return cos, sin, all_pos

    return cos, sin
254+
255+
256+
def build_batch_2d_rope(
    seq_len: int, n_elem: int, image_infos: Optional[List[List[Tuple[slice, Tuple[int, int]]]]] = None,
    device: Optional[torch.device] = None, base: int = 10000, base_rescale_factor: float = 1.0,
    return_all_pos: bool = False,
):
    """Batched wrapper over :func:`build_2d_rope`.

    Builds one RoPE table per sample in ``image_infos`` (a single ``None``
    sample when omitted) and stacks cos/sin along a new batch dimension.
    Position tensors, when requested, are returned as a list (one entry per
    sample, ``None`` for samples without positions).
    """
    if image_infos is None:
        image_infos = [None]

    cos_parts, sin_parts, pos_parts = [], [], []
    for per_sample_infos in image_infos:
        result = build_2d_rope(
            seq_len, n_elem, image_infos=per_sample_infos, device=device,
            base=base, base_rescale_factor=base_rescale_factor,
            return_all_pos=return_all_pos,
        )
        # Defensive check on the helper's contract before unpacking.
        if not isinstance(result, tuple) or len(result) not in (2, 3):
            raise ValueError(
                "build_2d_rope must return a tuple of length 2 or 3 "
                f"when return_all_pos={return_all_pos}, got: {type(result)} with length "
                f"{len(result) if isinstance(result, tuple) else 'N/A'}"
            )
        if len(result) == 3:
            cos, sin, pos = result
        else:
            cos, sin = result
            pos = None
        cos_parts.append(cos)
        sin_parts.append(sin)
        pos_parts.append(pos)

    batch_cos = torch.stack(cos_parts, dim=0)
    batch_sin = torch.stack(sin_parts, dim=0)
    if return_all_pos:
        return batch_cos, batch_sin, pos_parts
    return batch_cos, batch_sin
12290

13291
def get_full_state_dict(model_path):
14292
files = glob.glob(os.path.join(model_path, "*.safetensors"))

vllm_omni/diffusion/models/hunyuan/hunyuan_image_3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525
TimestepEmbedder,
2626
UNetDown,
2727
UNetUp,
28+
CausalMMOutputWithPast,
29+
)
30+
from .hunyuan_image3_utils import (
2831
build_batch_2d_rope,
2932
real_batched_index_select,
30-
CausalMMOutputWithPast,
3133
)
3234
from .autoencoder_kl_3d import AutoencoderKLConv3D
3335
from .siglip2 import Siglip2VisionTransformer, LightProjector

0 commit comments

Comments
 (0)