Skip to content

Commit 0d7c4c0

Browse files
committed
Temporary commit; should be squashed into the previous one
Signed-off-by: wujiaping <1608928702@qq.com>
1 parent 632b83e commit 0d7c4c0

File tree

4 files changed

+1092
-938
lines changed

4 files changed

+1092
-938
lines changed

vllm_omni/diffusion/models/hunyuan_image_3/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""hunyuan Image 3 diffusion model components."""
44

5-
from vllm_omni.diffusion.models.hunyuan_image_3.pipeline_hunyuan_image_3 import (
6-
HunyuanImage3Pipeline,
7-
)
85
from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_image_3_transformer import (
96
HunyuanImage3Model,
107
HunyuanImage3Text2ImagePipeline,
118
)
9+
from vllm_omni.diffusion.models.hunyuan_image_3.pipeline_hunyuan_image_3 import (
10+
HunyuanImage3Pipeline,
11+
)
1212

1313
__all__ = [
1414
"HunyuanImage3Pipeline",

vllm_omni/diffusion/models/hunyuan_image_3/autoencoder.py

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,21 @@
1111
# limitations under the License.
1212
# ==============================================================================
1313

14-
from dataclasses import dataclass
15-
from typing import Tuple, Optional
1614
import math
17-
import random
15+
from dataclasses import dataclass
16+
from typing import Optional, Tuple # noqa: UP035
17+
1818
import numpy as np
19-
from einops import rearrange
2019
import torch
21-
from torch import Tensor, nn
2220
import torch.nn.functional as F
23-
2421
from diffusers.configuration_utils import ConfigMixin, register_to_config
2522
from diffusers.models.modeling_outputs import AutoencoderKLOutput
2623
from diffusers.models.modeling_utils import ModelMixin
27-
from diffusers.utils.torch_utils import randn_tensor
2824
from diffusers.utils import BaseOutput
25+
from diffusers.utils.torch_utils import randn_tensor
26+
from einops import rearrange
27+
from torch import Tensor, nn
28+
2929

3030
class DiagonalGaussianDistribution(object):
3131
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
@@ -57,6 +57,7 @@ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTens
5757
x = self.mean + self.std * sample
5858
return x
5959

60+
6061
@dataclass
6162
class DecoderOutput(BaseOutput):
6263
sample: torch.FloatTensor
@@ -71,6 +72,7 @@ def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
7172
def create_custom_forward(module):
7273
def custom_forward(*inputs):
7374
return module(*inputs)
75+
7476
return custom_forward
7577

7678
if use_checkpointing:
@@ -81,7 +83,7 @@ def custom_forward(*inputs):
8183

8284
class Conv3d(nn.Conv3d):
8385
"""
84-
Perform Conv3d on patches with numerical differences from nn.Conv3d within 1e-5.
86+
Perform Conv3d on patches with numerical differences from nn.Conv3d within 1e-5.
8587
Only symmetric padding is supported.
8688
"""
8789

@@ -102,9 +104,9 @@ def forward(self, input):
102104
value=0,
103105
)
104106
if i > 0:
105-
padded_chunk[:, :, :self.padding[0]] = chunks[i - 1][:, :, -self.padding[0]:]
107+
padded_chunk[:, :, : self.padding[0]] = chunks[i - 1][:, :, -self.padding[0] :]
106108
if i < len(chunks) - 1:
107-
padded_chunk[:, :, -self.padding[0]:] = chunks[i + 1][:, :, :self.padding[0]]
109+
padded_chunk[:, :, -self.padding[0] :] = chunks[i + 1][:, :, : self.padding[0]]
108110
else:
109111
padded_chunk = chunks[i]
110112
padded_chunks.append(padded_chunk)
@@ -120,7 +122,8 @@ def forward(self, input):
120122

121123

122124
class AttnBlock(nn.Module):
123-
""" Attention with torch sdpa implementation. """
125+
"""Attention with torch sdpa implementation."""
126+
124127
def __init__(self, in_channels: int):
125128
super().__init__()
126129
self.in_channels = in_channels
@@ -178,6 +181,7 @@ def forward(self, x):
178181
x = self.nin_shortcut(x)
179182
return x + h
180183

184+
181185
class DownsampleDCAE(nn.Module):
182186
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
183187
super().__init__()
@@ -198,6 +202,7 @@ def forward(self, x: Tensor):
198202
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
199203
return h + shortcut
200204

205+
201206
class UpsampleDCAE(nn.Module):
202207
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
203208
super().__init__()
@@ -215,10 +220,12 @@ def forward(self, x: Tensor):
215220
shortcut = rearrange(shortcut, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
216221
return h + shortcut
217222

223+
218224
class Encoder(nn.Module):
219225
"""
220226
The encoder network of AutoencoderKLConv3D.
221227
"""
228+
222229
def __init__(
223230
self,
224231
in_channels: int,
@@ -251,8 +258,9 @@ def __init__(
251258
down.block = block
252259

253260
add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
254-
add_temporal_downsample = (add_spatial_downsample and
255-
bool(i_level >= np.log2(ffactor_spatial // ffactor_temporal)))
261+
add_temporal_downsample = add_spatial_downsample and bool(
262+
i_level >= np.log2(ffactor_spatial // ffactor_temporal)
263+
)
256264
if add_spatial_downsample or add_temporal_downsample:
257265
assert i_level < len(block_out_channels) - 1
258266
block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
@@ -280,7 +288,8 @@ def forward(self, x: Tensor) -> Tensor:
280288
for i_level in range(len(self.block_out_channels)):
281289
for i_block in range(self.num_res_blocks):
282290
h = forward_with_checkpointing(
283-
self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
291+
self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing
292+
)
284293
if hasattr(self.down[i_level], "downsample"):
285294
h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing)
286295

@@ -298,10 +307,12 @@ def forward(self, x: Tensor) -> Tensor:
298307
h += shortcut
299308
return h
300309

310+
301311
class Decoder(nn.Module):
302312
"""
303313
The decoder network of AutoencoderKLConv3D.
304314
"""
315+
305316
def __init__(
306317
self,
307318
z_channels: int,
@@ -380,10 +391,12 @@ def forward(self, z: Tensor) -> Tensor:
380391
h = self.conv_out(h)
381392
return h
382393

394+
383395
class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
384396
"""
385397
Autoencoder model with KL-regularized latent space based on 3D convolutions.
386398
"""
399+
387400
_supports_gradient_checkpointing = True
388401

389402
@register_to_config
@@ -402,8 +415,8 @@ def __init__(
402415
shift_factor: Optional[float] = None,
403416
downsample_match_channel: bool = True,
404417
upsample_match_channel: bool = True,
405-
only_encoder: bool = False, # only build encoder for saving memory
406-
only_decoder: bool = False, # only build decoder for saving memory
418+
only_encoder: bool = False, # only build encoder for saving memory
419+
only_decoder: bool = False, # only build decoder for saving memory
407420
):
408421
super().__init__()
409422
self.ffactor_spatial = ffactor_spatial
@@ -449,27 +462,29 @@ def __init__(
449462

450463
# use torch.compile for faster encode speed
451464
self.use_compile = False
452-
465+
453466
def _set_gradient_checkpointing(self, module, value=False):
454467
if isinstance(module, (Encoder, Decoder)):
455468
module.gradient_checkpointing = value
456-
469+
457470
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
458471
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
459472
for x in range(blend_extent):
460-
b[:, :, :, :, x] = \
461-
a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
473+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
474+
x / blend_extent
475+
)
462476
return b
463477

464478
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
465479
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
466480
for y in range(blend_extent):
467-
b[:, :, :, y, :] = \
468-
a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
481+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
482+
y / blend_extent
483+
)
469484
return b
470485

471486
def spatial_tiled_decode(self, z: torch.Tensor):
472-
""" spatial tailing for frames """
487+
"""spatial tiling for frames"""
473488
B, C, T, H, W = z.shape
474489
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
475490
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) # 256 * 0.25 = 64
@@ -479,7 +494,7 @@ def spatial_tiled_decode(self, z: torch.Tensor):
479494
for i in range(0, H, overlap_size):
480495
row = []
481496
for j in range(0, W, overlap_size):
482-
tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
497+
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
483498
decoded = self.decoder(tile)
484499
row.append(decoded)
485500
rows.append(row)
@@ -498,7 +513,7 @@ def spatial_tiled_decode(self, z: torch.Tensor):
498513
return dec
499514

500515
def temporal_tiled_decode(self, z: torch.Tensor):
501-
""" temporal tailing for frames """
516+
"""temporal tiling for frames"""
502517
B, C, T, H, W = z.shape
503518
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
504519
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) # 64 * 0.25 = 16
@@ -507,9 +522,10 @@ def temporal_tiled_decode(self, z: torch.Tensor):
507522

508523
row = []
509524
for i in range(0, T, overlap_size):
510-
tile = z[:, :, i: i + self.tile_latent_min_tsize, :, :]
525+
tile = z[:, :, i : i + self.tile_latent_min_tsize, :, :]
511526
if self.use_spatial_tiling and (
512-
tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
527+
tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
528+
):
513529
decoded = self.spatial_tiled_decode(tile)
514530
else:
515531
decoded = self.decoder(tile)
@@ -522,23 +538,27 @@ def temporal_tiled_decode(self, z: torch.Tensor):
522538
result_row.append(tile[:, :, :t_limit, :, :])
523539
dec = torch.cat(result_row, dim=-3)
524540
return dec
525-
541+
526542
def encode(self, x: Tensor, return_dict: bool = True):
527543
"""
528544
Encodes the input by passing through the encoder network.
529545
Support slicing and tiling for memory efficiency.
530546
"""
547+
531548
def _encode(x):
532549
if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize:
533550
return self.temporal_tiled_encode(x)
534551
if self.use_spatial_tiling and (
535-
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
552+
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
553+
):
536554
return self.spatial_tiled_encode(x)
537555

538556
if self.use_compile:
557+
539558
@torch.compile
540559
def encoder(x):
541560
return self.encoder(x)
561+
542562
return encoder(x)
543563
return self.encoder(x)
544564

@@ -567,17 +587,19 @@ def encoder(x):
567587
return (posterior,)
568588

569589
return AutoencoderKLOutput(latent_dist=posterior)
570-
590+
571591
def decode(self, z: Tensor, return_dict: bool = True, generator=None):
572592
"""
573593
Decodes the input by passing through the decoder network.
574594
Support slicing and tiling for memory efficiency.
575595
"""
596+
576597
def _decode(z):
577598
if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize:
578599
return self.temporal_tiled_decode(z)
579600
if self.use_spatial_tiling and (
580-
z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
601+
z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size
602+
):
581603
return self.spatial_tiled_decode(z)
582604
return self.decoder(z)
583605

0 commit comments

Comments
 (0)