
Commit f35850c

address reviews
1 parent 0c1ebc3 commit f35850c

5 files changed: 110 additions & 150 deletions

docs/source/en/api/pipelines/cogview3.md

Lines changed: 1 addition & 40 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 -->

-# CogVideoX
+# CogView3Plus

 [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.

@@ -29,45 +29,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

 This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

-## Inference
-
-Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
-
-First, load the pipeline:
-
-```python
-import torch
-from diffusers import CogView3PlusPipeline
-from diffusers.utils import export_to_video,load_image
-
-pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b").to("cuda") # or "THUDM/CogVideoX-2b"
-```
-
-Then change the memory layout of the `transformer` and `vae` components to `torch.channels_last`:
-
-```python
-pipe.transformer.to(memory_format=torch.channels_last)
-pipe.vae.to(memory_format=torch.channels_last)
-```
-
-Compile the components and run inference:
-
-```python
-pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
-pipe.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
-
-# CogVideoX works well with long and well-described prompts
-prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
-video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
-```
-
-The [benchmark](TODO) results on an 80GB A100 machine are:
-
-```
-Without torch.compile(): Average inference time: TODO seconds.
-With torch.compile(): Average inference time: TODO seconds.
-```
-
 ## CogView3PlusPipeline

 [[autodoc]] CogView3PlusPipeline
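
For reference, a minimal text-to-image sketch for the pipeline documented in this file (not part of the diff). The checkpoint id is taken from the removed snippet above and may differ from the final published repository name; the dtype and sampler settings are illustrative assumptions:

```python
import torch
from diffusers import CogView3PlusPipeline

# Assumed checkpoint id (copied from the removed snippet); the published repo name may differ.
pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.bfloat16).to("cuda")

prompt = "A photograph of a red panda sipping tea in a misty bamboo forest"
# CogView3Plus generates images, so the output exposes `.images` rather than `.frames`.
image = pipe(prompt=prompt, guidance_scale=7.0, num_inference_steps=50).images[0]
image.save("cogview3.png")
```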

src/diffusers/models/embeddings.py

Lines changed: 61 additions & 57 deletions
@@ -442,6 +442,60 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         return embeds


+class CogView3PlusPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+        pos_embed_max_size: int = 128,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+        self.patch_size = patch_size
+        self.text_hidden_size = text_hidden_size
+        self.pos_embed_max_size = pos_embed_max_size
+        # Linear projection for image patches
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+        # Linear projection for text embeddings
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+
+        if height % self.patch_size != 0 or width % self.patch_size != 0:
+            raise ValueError("Height and width must be divisible by patch size")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+        # Project the patches
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # Calculate text_length
+        text_length = encoder_hidden_states.shape[1]
+
+        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+        text_pos_embed = torch.zeros(
+            (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
+
+        return (hidden_states + pos_embed).to(hidden_states.dtype)
+
+
 def get_3d_rotary_pos_embed(
     embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -714,58 +768,6 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         return freqs_cos, freqs_sin


-class CogView3PlusPatchEmbed(nn.Module):
-    def __init__(
-        self,
-        in_channels: int = 16,
-        hidden_size: int = 2560,
-        patch_size: int = 2,
-        text_hidden_size: int = 4096,
-        pos_embed_max_size: int = 128,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.hidden_size = hidden_size
-        self.patch_size = patch_size
-        self.text_hidden_size = text_hidden_size
-        self.pos_embed_max_size = pos_embed_max_size
-        # Linear projection for image patches
-        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
-
-        # Linear projection for text embeddings
-        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
-
-        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
-        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
-        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
-
-    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None) -> torch.Tensor:
-        batch_size, channel, height, width = hidden_states.shape
-        if height % self.patch_size != 0 or width % self.patch_size != 0:
-            raise ValueError("Height and width must be divisible by patch size")
-        height = height // self.patch_size
-        width = width // self.patch_size
-        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
-        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
-
-        # Project the patches
-        hidden_states = self.proj(hidden_states)
-        encoder_hidden_states = self.text_proj(encoder_hidden_states)
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-        # Calculate text_length
-        text_length = encoder_hidden_states.shape[1]
-
-        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
-        text_pos_embed = torch.zeros(
-            (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
-        )
-        pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
-
-        return (hidden_states + pos_embed).to(hidden_states.dtype)
-
-
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
@@ -1090,11 +1092,11 @@ def forward(self, timestep, class_labels, hidden_dtype=None):


 class CombinedTimestepTextProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim, timesteps_dim=256):
+    def __init__(self, embedding_dim, pooled_projection_dim):
         super().__init__()

-        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
-        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
         self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

     def forward(self, timestep, pooled_projection):
@@ -1132,7 +1134,7 @@ def forward(self, timestep, guidance, pooled_projection):
         return conditioning


-class CogView3CombinedTimestepConditionEmbeddings(nn.Module):
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()

@@ -1154,9 +1156,11 @@ def forward(
         original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
         crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
         target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+        # (B, 3 * condition_dim)
         condition_proj = torch.cat(
             [original_size_proj, crop_coords_proj, target_size_proj], dim=1
-        ) # (B, 3 * condition_dim)
+        )

         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
         condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
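
The relocated `CogView3PlusPatchEmbed.forward` turns a `(B, C, H, W)` latent into `(H / p) * (W / p)` image tokens and prepends the projected text tokens before adding positional embeddings. A minimal sketch of that shape arithmetic with illustrative sizes (the 128x128 latent and the text length of 224 are assumptions for the example, not values read from a model config):

```python
import torch

batch_size, in_channels, patch_size = 2, 16, 2
latent_h = latent_w = 128  # e.g. a 1024x1024 image with an 8x VAE downscale
text_len = 224             # illustrative text sequence length

latents = torch.randn(batch_size, in_channels, latent_h, latent_w)

# Same reshape/permute sequence as the forward pass above:
# (B, C, H, W) -> (B, (H/p) * (W/p), C * p * p)
h, w = latent_h // patch_size, latent_w // patch_size
patches = latents.view(batch_size, in_channels, h, patch_size, w, patch_size)
patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous()
patches = patches.view(batch_size, h * w, in_channels * patch_size**2)

print(patches.shape)     # torch.Size([2, 4096, 64]) -> 4096 image tokens of dim 64
print(text_len + h * w)  # 4320 tokens seen by the transformer after text is prepended
```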

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 38 additions & 12 deletions
@@ -29,7 +29,7 @@
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import is_torch_version, logging
-from ..embeddings import CogView3CombinedTimestepConditionEmbeddings, CogView3PlusPatchEmbed
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..normalization import CogView3PlusAdaLayerNormZeroTextImage

@@ -133,12 +133,27 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
             The number of channels in each head.
         num_attention_heads (`int`, defaults to `64`):
             The number of heads to use for multi-head attention.
-        out_channels (`int`, *optional*, defaults to `16`):
+        out_channels (`int`, defaults to `16`):
             The number of channels in the output.
         text_embed_dim (`int`, defaults to `4096`):
             Input dimension of text embeddings from the text encoder.
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
+        condition_dim (`int`, defaults to `256`):
+            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, crop_coords).
+        pooled_projection_dim (`int`, defaults to `1536`):
+            The overall pooled dimension by concatenating SDXL-style resolution conditions. As 3 additional conditions are
+            used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 * condition_dim`,
+            we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep embeddings will be projected
+            to this dimension as well.
+            TODO(yiyi): Do we need this parameter based on the above explanation?
+        pos_embed_max_size (`int`, defaults to `128`):
+            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added to input
+            patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 means that the maximum
+            supported height and width for image generation is `128 * vae_scale_factor * patch_size => 128 * 8 * 2 => 2048`.
+        sample_size (`int`, defaults to `128`):
+            The base resolution of input latents. If height/width is not provided during generation, this value is used to determine
+            the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
     """

     _supports_gradient_checkpointing = True
@@ -163,15 +178,15 @@ def __init__(
         self.out_channels = out_channels
         self.inner_dim = num_attention_heads * attention_head_dim

-        self.pos_embed = CogView3PlusPatchEmbed(
+        self.patch_embed = CogView3PlusPatchEmbed(
             in_channels=in_channels,
             hidden_size=self.inner_dim,
             patch_size=patch_size,
             text_hidden_size=text_embed_dim,
             pos_embed_max_size=pos_embed_max_size,
         )

-        self.time_condition_embed = CogView3CombinedTimestepConditionEmbeddings(
+        self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
             embedding_dim=time_embed_dim,
             condition_dim=condition_dim,
             pooled_projection_dim=pooled_projection_dim,
@@ -318,20 +333,31 @@ def forward(
         The [`CogView3PlusTransformer2DModel`] forward method.

         Args:
-            hidden_states (`torch.Tensor`): Input `hidden_states`.
-            timestep (`torch.LongTensor`): Indicates denoising step.
-            y (`torch.LongTensor`, *optional*): Label input, used to obtain label embeddings.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors for residuals.
-            joint_attention_kwargs (`dict`, *optional*): Additional kwargs for the attention processor.
-            return_dict (`bool`, *optional*, defaults to `True`): Whether to return a `Transformer2DModelOutput`.
+            hidden_states (`torch.Tensor`):
+                Input `hidden_states` of shape `(batch size, channel, height, width)`.
+            encoder_hidden_states (`torch.Tensor`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts)
+                of shape `(batch_size, sequence_len, text_embed_dim)`
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            original_size (`torch.Tensor`):
+                CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`torch.Tensor`):
+                CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crop_coords (`torch.Tensor`):
+                CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.

         Returns:
-            Output tensor or `Transformer2DModelOutput`.
+            `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
+                The denoised latents using provided inputs as conditioning.
         """
         height, width = hidden_states.shape[-2:]
         text_seq_length = encoder_hidden_states.shape[1]

-        hidden_states = self.pos_embed(
+        hidden_states = self.patch_embed(
             hidden_states, encoder_hidden_states
         ) # takes care of adding positional embeddings too.
         emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
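
The resolution and conditioning numbers quoted in the new docstring can be checked with a few lines of arithmetic (a restatement of the formulas above, not code from this diff):

```python
# Each of the three SDXL-style conditions (original_size, target_size, crop_coords)
# is an (h, w) pair, and each value is embedded sinusoidally with `condition_dim` channels.
condition_dim = 256
pooled_projection_dim = 2 * condition_dim * 3
print(pooled_projection_dim)  # 1536

pos_embed_max_size, vae_scale_factor, patch_size, sample_size = 128, 8, 2, 128
print(pos_embed_max_size * vae_scale_factor * patch_size)  # 2048 -> maximum supported height/width
print(sample_size * vae_scale_factor)                      # 1024 -> default generation resolution
```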

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@
         "CogVideoXVideoToVideoPipeline",
     ]
     _import_structure["cogview3"] = [
-        "CogView3PlusPipeline",
+        "CogView3PlusPipeline"
     ]
     _import_structure["controlnet"].extend(
         [
