Commit dd6568b

use with cogview4 transformers forward twice of u and uc
1 parent bf7f322 commit dd6568b
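
The commit message refers to classifier-free guidance: during sampling the transformer is run forward twice per step, once on the conditional ("c") prompt embeddings and once on the unconditional ("uc") embeddings, and the two predictions are blended with a guidance scale. The sketch below is not code from this commit; it is a minimal illustration of that pattern against the single-stream `forward` signature restored in `transformer_cogview3plus.py` further down (`cfg_denoise` and `guidance_scale` are illustrative names).

```python
# Minimal classifier-free-guidance sketch (illustrative, not code from this commit).
# The transformer is called twice per denoising step: once with the conditional
# prompt embeddings and once with the unconditional ones.
def cfg_denoise(
    transformer,
    latents,
    prompt_embeds,
    negative_prompt_embeds,
    timestep,
    original_size,
    target_size,
    crop_coords,
    guidance_scale=5.0,
):
    noise_pred_cond = transformer(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        timestep=timestep,
        original_size=original_size,
        target_size=target_size,
        crop_coords=crop_coords,
        return_dict=False,
    )[0]
    noise_pred_uncond = transformer(
        hidden_states=latents,
        encoder_hidden_states=negative_prompt_embeds,
        timestep=timestep,
        original_size=original_size,
        target_size=target_size,
        crop_coords=crop_coords,
        return_dict=False,
    )[0]
    # Standard CFG combination of the two passes.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```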

File tree

10 files changed: +600, -179 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -276,6 +276,8 @@
     title: ConsisIDTransformer3DModel
   - local: api/models/cogview3plus_transformer2d
     title: CogView3PlusTransformer2DModel
+  - local: api/models/cogview4_transformer2d
+    title: CogView4Transformer2DModel
   - local: api/models/dit_transformer2d
     title: DiTTransformer2DModel
   - local: api/models/flux_transformer
```
docs/source/en/api/models/cogview4_transformer2d.md

Lines changed: 30 additions & 0 deletions

```diff
@@ -0,0 +1,30 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# CogView4Transformer2DModel
+
+A Diffusion Transformer model for 2D data from [CogView4]()
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import CogView4Transformer2DModel
+
+transformer = CogView4Transformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## CogView4Transformer2DModel
+
+[[autodoc]] CogView4Transformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
```
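
The doc page above only loads the bare transformer. End-to-end generation would normally go through `CogView4Pipeline`, which the conversion script below imports; the following is a hedged sketch that assumes the pipeline exposes the usual `DiffusionPipeline` interface and that a converted checkpoint is available at `THUDM/CogView4-6B`.

```python
import torch

from diffusers import CogView4Pipeline

# Sketch only: assumes CogView4Pipeline follows the standard DiffusionPipeline
# calling convention and that the converted weights live at THUDM/CogView4-6B.
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
image = pipe(prompt="a photo of an astronaut riding a horse on the moon").images[0]
image.save("cogview4_sample.png")
```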

scripts/convert_cogview4_to_diffusers.py

Lines changed: 19 additions & 15 deletions

```diff
@@ -31,7 +31,7 @@
 from accelerate import init_empty_weights
 from transformers import PreTrainedTokenizerFast, GlmForCausalLM
 
-from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView3PlusTransformer2DModel
+from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
 from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
 from diffusers.utils.import_utils import is_accelerate_available
 
@@ -168,7 +168,7 @@ def main(args):
     converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
         args.transformer_checkpoint_path
     )
-    transformer = CogView3PlusTransformer2DModel(
+    transformer = CogView4Transformer2DModel(
         patch_size=2,
         in_channels=16,
         num_layers=28,
@@ -209,23 +209,27 @@ def main(args):
     if dtype is not None:
         vae = vae.to(dtype=dtype)
 
-    # text_encoder_id = "THUDM/glm-4-9b-hf"
-    # tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
-    # text_encoder = GlmForCausalLM.from_pretrained(
-    #     text_encoder_id,
-    #     cache_dir=args.text_encoder_cache_dir,
-    #     torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
-    # )
-    from transformers import AutoTokenizer, AutoModel
-    text_encoder_id = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/glm-4-9b"
-    tokenizer = AutoTokenizer.from_pretrained(text_encoder_id, trust_remote_code=True)
-    text_encoder = AutoModel.from_pretrained(
+    text_encoder_id = "/share/home/zyx/Models/glm-4-9b-hf"
+    tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
+    text_encoder = GlmForCausalLM.from_pretrained(
         text_encoder_id,
         cache_dir=args.text_encoder_cache_dir,
         torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
-        trust_remote_code=True,
     )
-    # Apparently, the conversion does not work anymore without this :shrug:
+
+    # TODO: This is for Older GLM-4 as https://huggingface.co/THUDM/glm-4-9b, will use https://huggingface.co/THUDM/glm-4-9b-hf for new transformers version format.
+    # TODO: Remove it later
+
+    # from transformers import AutoTokenizer, AutoModel
+    # text_encoder_id = "/share/home/zyx/Models/Megatron-VLM/examples/dit/ckpts/glm-4-9b"
+    # tokenizer = AutoTokenizer.from_pretrained(text_encoder_id, trust_remote_code=True)
+    # text_encoder = AutoModel.from_pretrained(
+    #     text_encoder_id,
+    #     cache_dir=args.text_encoder_cache_dir,
+    #     torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
+    #     trust_remote_code=True,
+    # )
+
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
 
```
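
The TODO in the hunk above says the hard-coded local paths are stand-ins for the published GLM-4 checkpoints. Below is a hedged sketch of the intended hub-based loading with the same `PreTrainedTokenizerFast`/`GlmForCausalLM` classes the script now uses; the `THUDM/glm-4-9b-hf` repo id is taken from the TODO and should be treated as an assumption.

```python
import torch
from transformers import GlmForCausalLM, PreTrainedTokenizerFast

# Sketch only: load the GLM-4 text encoder from the Hub instead of a local path.
text_encoder_id = "THUDM/glm-4-9b-hf"  # repo id mentioned in the TODO above (assumption)
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmForCausalLM.from_pretrained(text_encoder_id, torch_dtype=torch.bfloat16)

# The conversion script makes parameters contiguous before saving; same idea here.
for param in text_encoder.parameters():
    param.data = param.data.contiguous()
```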
src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -92,6 +92,7 @@
     "AutoencoderTiny",
     "CogVideoXTransformer3DModel",
     "CogView3PlusTransformer2DModel",
+    "CogView4Transformer2DModel",
     "ConsisIDTransformer3DModel",
     "ConsistencyDecoderVAE",
     "ControlNetModel",
@@ -606,6 +607,7 @@
     AutoencoderTiny,
     CogVideoXTransformer3DModel,
     CogView3PlusTransformer2DModel,
+    CogView4Transformer2DModel,
     ConsisIDTransformer3DModel,
     ConsistencyDecoderVAE,
     ControlNetModel,
```

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -68,6 +68,7 @@
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
     _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
+    _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
     _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -130,6 +131,7 @@
         AuraFlowTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
+        CogView4Transformer2DModel,
         ConsisIDTransformer3DModel,
         DiTTransformer2DModel,
         DualTransformer2DModel,
```

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -18,6 +18,7 @@
 from .transformer_2d import Transformer2DModel
 from .transformer_allegro import AllegroTransformer3DModel
 from .transformer_cogview3plus import CogView3PlusTransformer2DModel
+from .transformer_cogview4 import CogView4Transformer2DModel
 from .transformer_flux import FluxTransformer2DModel
 from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
 from .transformer_ltx import LTXVideoTransformer3DModel
```
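
Taken together, the three `__init__.py` edits register the new class with diffusers' lazy-import machinery, so it resolves from both the top-level package and the models subpackage. A quick hedged check, assuming a build of this branch is installed:

```python
# Assumes this branch of diffusers is installed; both import paths should
# resolve to the same class once the entries above are in place.
from diffusers import CogView4Transformer2DModel
from diffusers.models import CogView4Transformer2DModel as CogView4FromModels

assert CogView4Transformer2DModel is CogView4FromModels
```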

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 42 additions & 142 deletions

```diff
@@ -28,7 +28,7 @@
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import is_torch_version, logging
-from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed, CogView4PatchEmbed
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
 
@@ -84,7 +84,6 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
-        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -104,7 +103,7 @@ def forward(
 
         # attention
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )
 
         hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
@@ -167,7 +166,8 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed", "CogView4PlusPatchEmbed"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -192,16 +192,7 @@ def __init__(
         # Each of these are sincos embeddings of shape 2 * condition_dim
         self.pooled_projection_dim = 3 * 2 * condition_dim
 
-        self.max_h = 256
-        self.max_w = 256
-        self.rope = self.prepare_rope(
-            embed_dim=self.config.attention_head_dim,
-            max_h=self.max_h,
-            max_w=self.max_w,
-            rotary_base=10000
-        )
-
-        self.patch_embed = CogView4PatchEmbed(
+        self.patch_embed = CogView3PlusPatchEmbed(
             in_channels=in_channels,
             hidden_size=self.inner_dim,
             patch_size=patch_size,
@@ -232,8 +223,7 @@ def __init__(
             embedding_dim=self.inner_dim,
             conditioning_embedding_dim=time_embed_dim,
             elementwise_affine=False,
-            # eps=1e-6,
-            eps=1e-5,
+            eps=1e-6,
         )
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
@@ -303,55 +293,10 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
-    @staticmethod
-    def prepare_rope(embed_dim, max_h, max_w, rotary_base):
-        dim_h = embed_dim // 2
-        dim_w = embed_dim // 2
-        h_inv_freq = 1.0 / (
-            rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
-        )
-        w_inv_freq = 1.0 / (
-            rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
-        )
-        h_seq = torch.arange(max_h, dtype=h_inv_freq.dtype)
-        w_seq = torch.arange(max_w, dtype=w_inv_freq.dtype)
-        freqs_h = torch.outer(h_seq, h_inv_freq)
-        freqs_w = torch.outer(w_seq, w_inv_freq)
-        return (freqs_h, freqs_w)
-
-    def get_rope_embedding(self, height, width, target_h, target_w, device):
-        # Get pre-computed frequencies
-        freqs_h, freqs_w = self.rope
-
-        h_idx = torch.arange(height)
-        w_idx = torch.arange(width)
-        inner_h_idx = (h_idx * self.max_h) // target_h
-        inner_w_idx = (w_idx * self.max_w) // target_w
-
-        freqs_h = freqs_h[inner_h_idx].to(device)
-        freqs_w = freqs_w[inner_w_idx].to(device)
-
-        # Create position matrices for height and width
-        # [height, 1, dim//4] and [1, width, dim//4]
-        freqs_h = freqs_h.unsqueeze(1)
-        freqs_w = freqs_w.unsqueeze(0)
-        # Broadcast freqs_h and freqs_w to [height, width, dim//4]
-        freqs_h = freqs_h.expand(height, width, -1)
-        freqs_w = freqs_w.expand(height, width, -1)
-
-        # Concatenate along last dimension to get [height, width, dim//2]
-        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
-
-        freqs = torch.cat([freqs, freqs], dim=-1)  # [height, width, dim]
-        freqs = freqs.reshape(height * width, -1)
-
-        return freqs.cos(), freqs.sin()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
-        prompt_embeds: torch.Tensor,
-        negative_prompt_embeds: torch.Tensor | None,
+        encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         original_size: torch.Tensor,
         target_size: torch.Tensor,
@@ -386,103 +331,58 @@ def forward(
            `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
                The denoised latents using provided inputs as conditioning.
         """
-        batch_size, channel, height, width = hidden_states.shape
-        patch_height, patch_width = height // self.config.patch_size, width // self.config.patch_size
-        do_cfg = negative_prompt_embeds is not None
-
-        if do_cfg:
-            assert batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0], "batch size mismatch in CFG mode"
-        else:
-            assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
+        height, width = hidden_states.shape[-2:]
+        text_seq_length = encoder_hidden_states.shape[1]
 
-        hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
-            hidden_states, prompt_embeds, negative_prompt_embeds
-        )
+        hidden_states = self.patch_embed(
+            hidden_states, encoder_hidden_states
+        )  # takes care of adding positional embeddings too.
         emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
 
-        hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
-        emb_cond, emb_uncond = emb.chunk(2)
-
-        # prepare image_rotary__emb
-        image_rotary_emb = self.get_rope_embedding(
-            patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
-        )
-
-        ######################
-        # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
-        # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
-        prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
-        negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
+        encoder_hidden_states = hidden_states[:, :text_seq_length]
+        hidden_states = hidden_states[:, text_seq_length:]
 
-        hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
-        hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
-        emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
-        emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
-        ######################
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
 
-        encoder_hidden_states_cond = prompt_embeds
-        encoder_hidden_states_uncond = negative_prompt_embeds
+                    return custom_forward
 
-        for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                ...
-            else:
-                hidden_states_cond, encoder_hidden_states_cond = block(
-                    hidden_states=hidden_states_cond,
-                    encoder_hidden_states=encoder_hidden_states_cond,
-                    emb=emb_cond,  # refactor later
-                    image_rotary_emb=image_rotary_emb,
-                    # image_rotary_emb=None,
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    emb,
+                    **ckpt_kwargs,
                 )
-                ###########################
-                # hidden_states_cond, encoder_hidden_states_cond = (
-                #     self.norm_out.norm(hidden_states_cond),
-                #     self.norm_out.norm(encoder_hidden_states_cond),
-                # )
-                ###########################
-
-                hidden_states_uncond, encoder_hidden_states_uncond = block(
-                    hidden_states=hidden_states_uncond,
-                    encoder_hidden_states=encoder_hidden_states_uncond,
-                    emb=emb_uncond,  # refactor later
-                    image_rotary_emb=image_rotary_emb,
-                    # image_rotary_emb=None,
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=emb,
                 )
-                ###########################
-                # hidden_states_uncond, encoder_hidden_states_uncond = (
-                #     self.norm_out.norm(hidden_states_uncond),
-                #     self.norm_out.norm(encoder_hidden_states_uncond),
-                # )
-                ###########################
-
-        hidden_states_cond = self.norm_out(hidden_states_cond, emb_cond)  # corresponds to final_layer_input in Megatron
-        hidden_states_uncond = self.norm_out(hidden_states_uncond, emb_uncond)  # corresponds to final_layer_input in Megatron
-        hidden_states_cond = self.proj_out(hidden_states_cond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
-        hidden_states_uncond = self.proj_out(hidden_states_uncond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+
+        hidden_states = self.norm_out(hidden_states, emb)
+        hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)
 
         # unpatchify
         patch_size = self.config.patch_size
         height = height // patch_size
         width = width // patch_size
 
-        hidden_states_cond = hidden_states_cond.reshape(
-            shape=(hidden_states_cond.shape[0], height, width, self.out_channels, patch_size, patch_size)
-        )
-        hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
-        output_cond = hidden_states_cond.reshape(
-            shape=(hidden_states_cond.shape[0], self.out_channels, height * patch_size, width * patch_size)
-        )
-
-        hidden_states_uncond = hidden_states_uncond.reshape(
-            shape=(hidden_states_uncond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+        hidden_states = hidden_states.reshape(
+            shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
         )
-        hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
-        output_uncond = hidden_states_uncond.reshape(
-            shape=(hidden_states_uncond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+        hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
         )
 
         if not return_dict:
-            return (output_cond, output_uncond)
+            return (output,)
 
-        return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
+        return Transformer2DModelOutput(sample=output)
```
