
Commit 0ab7260

[WIP][cogview4]: implement initial CogView4 pipeline
Implement the basic CogView4 pipeline structure with the following changes:
- Add CogView4 pipeline implementation
- Implement DDIM scheduler for CogView4
- Add CogView3Plus transformer architecture
- Update embedding models

Current limitations:
- CFG implementation uses padding for sequence length alignment
- Need to verify transformer inference alignment with Megatron

TODO:
- Consider separate forward passes for condition/uncondition instead of the padding approach (see the sketch below)
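
For illustration only, a minimal sketch of the two CFG strategies mentioned above: the current padded single batched pass versus the separate condition/uncondition passes listed under TODO. The `denoise` callable and its signature are hypothetical stand-ins, not the actual pipeline API.

    import torch
    import torch.nn.functional as F

    def cfg_single_pass_with_padding(denoise, latents, prompt_embeds, negative_prompt_embeds, t):
        # Pad the shorter text sequence so cond/uncond share one batched forward pass.
        max_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
        pad = lambda e: F.pad(e, (0, 0, 0, max_len - e.shape[1]))  # pad along the sequence dim
        text = torch.cat([pad(prompt_embeds), pad(negative_prompt_embeds)], dim=0)
        noise = denoise(torch.cat([latents, latents], dim=0), text, t)
        return noise.chunk(2, dim=0)  # (conditional, unconditional)

    def cfg_separate_passes(denoise, latents, prompt_embeds, negative_prompt_embeds, t):
        # No padding needed, at the cost of two forward passes.
        return denoise(latents, prompt_embeds, t), denoise(latents, negative_prompt_embeds, t)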
1 parent e6b8907 commit 0ab7260

4 files changed (+233, -187 lines changed)

src/diffusers/models/embeddings.py

Lines changed: 22 additions & 41 deletions
@@ -812,6 +812,7 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
 
         return (hidden_states + pos_embed).to(hidden_states.dtype)
 
+
 class CogView4PatchEmbed(nn.Module):
     def __init__(
         self,
@@ -832,55 +833,35 @@ def __init__(
 
         # Linear projection for text embeddings
         self.text_proj = nn.Linear(text_hidden_size, hidden_size)
-        # TODO: change this to RotaryEmbed
-        pos_embed = get_2d_sincos_pos_embed(
-            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
-        )
-        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
-        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
 
-    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, hidden_states: torch.Tensor, prompt_embeds: torch.Tensor, negative_prompt_embeds: torch.Tensor | None
+    ) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape
 
         if height % self.patch_size != 0 or width % self.patch_size != 0:
             raise ValueError("Height and width must be divisible by patch size")
 
-        height = height // self.patch_size
-        width = width // self.patch_size
-        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
-        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+        patch_height = height // self.patch_size
+        patch_width = width // self.patch_size
 
-        # Project the patches
-        hidden_states = self.proj(hidden_states)
-        prompt_encoder_hidden_states = []
-        negative_prompt_encoder_hidden_states = []
-
-        for i in range(0, batch_size, 2):
-            prompt_embeds = encoder_hidden_states[i, :, :]  # [seq_len, hidden_size]
-            negative_embeds = encoder_hidden_states[i + 1, :, :]  # [seq_len, hidden_size]
-            mask = negative_embeds.abs().sum(dim=-1) > 0
-            seq_len_neg = mask.sum().item()  # number of non-zero entries
-            negative_embeds_valid = negative_embeds[:seq_len_neg, :]  # [seq_len_neg, hidden_size]
-            prompt_encoder_hidden_states.append(prompt_embeds)
-            negative_prompt_encoder_hidden_states.append(negative_embeds_valid)
-        prompt_encoder_hidden_states = torch.stack(prompt_encoder_hidden_states, dim=0)
-        negative_prompt_encoder_hidden_states = torch.stack(negative_prompt_encoder_hidden_states, dim=0)
-        prompt_text_length = prompt_encoder_hidden_states.shape[1]
-        negative_prompt_text_length = negative_prompt_encoder_hidden_states.shape[1]
-        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
-        prompt_text_pos_embed = torch.zeros(
-            (prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        # b, c, h, w -> b, c, patch_height, patch_size, patch_width, patch_size
+        # -> b, patch_height, patch_width, c, patch_size, patch_size
+        # -> b, patch_height * patch_width, c * patch_size * patch_size
+        hidden_states = (
+            hidden_states.reshape(batch_size, channel, patch_height, self.patch_size, patch_width, self.patch_size)
+            .permute(0, 2, 4, 1, 3, 5)
+            .reshape(batch_size, patch_height * patch_width, channel * self.patch_size * self.patch_size)
         )
-        negative_prompt_text_pos_embed = torch.zeros(
-            (negative_prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
-        )
-        prompt_pos_embed = torch.cat([prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
-        negative_prompt_pos_embed = torch.cat([negative_prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
-        # TODO: assemble a complete pos_embed and concatenate the RoPE embed
-        pos_embed = torch.cat([prompt_pos_embed, negative_prompt_pos_embed], dim=0)
-        hidden_states = hidden_states + pos_embed.to(hidden_states.dtype)
-        return hidden_states
+
+        # project
+        hidden_states = self.proj(hidden_states)  # embed_dim: 64 -> 4096
+        prompt_embeds = self.text_proj(prompt_embeds)  # embed_dim: 4096 -> 4096
+        if negative_prompt_embeds is not None:
+            negative_prompt_embeds = self.text_proj(negative_prompt_embeds)  # embed_dim: 4096 -> 4096
+
+        return hidden_states, prompt_embeds, negative_prompt_embeds
+
 
 def get_3d_rotary_pos_embed(
     embed_dim,
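
As a shape check for the reshape-based patchify in the new CogView4PatchEmbed.forward above, a standalone sketch on dummy tensors; patch_size=2 and in_channels=16 are assumptions chosen so that channel * patch_size**2 == 64, matching the "# embed_dim: 64 -> 4096" comment in the diff.

    import torch

    batch_size, channel, height, width = 2, 16, 64, 64
    patch_size = 2  # assumption for illustration
    x = torch.randn(batch_size, channel, height, width)

    patch_height, patch_width = height // patch_size, width // patch_size
    patches = (
        x.reshape(batch_size, channel, patch_height, patch_size, patch_width, patch_size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(batch_size, patch_height * patch_width, channel * patch_size * patch_size)
    )
    print(patches.shape)  # torch.Size([2, 1024, 64]); self.proj then maps 64 -> hidden_size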

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 80 additions & 20 deletions
@@ -84,6 +84,7 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -103,7 +104,7 @@ def forward(
 
         # attention
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
         )
 
         hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
@@ -191,14 +192,15 @@ def __init__(
         # Each of these are sincos embeddings of shape 2 * condition_dim
         self.pooled_projection_dim = 3 * 2 * condition_dim
 
-        # self.patch_embed = CogView3PlusPatchEmbed(
-        #     in_channels=in_channels,
-        #     hidden_size=self.inner_dim,
-        #     patch_size=patch_size,
-        #     text_hidden_size=text_embed_dim,
-        #     pos_embed_max_size=pos_embed_max_size,
-        # )
-        # TODO: compatibility adaptation
+        self.max_h = 256
+        self.max_w = 256
+        self.rope = self.prepare_rope(
+            embed_dim=self.config.attention_head_dim,
+            max_h=self.max_h,
+            max_w=self.max_w,
+            rotary_base=10000
+        )
+
         self.patch_embed = CogView4PatchEmbed(
             in_channels=in_channels,
             hidden_size=self.inner_dim,
@@ -300,10 +302,55 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def prepare_rope(embed_dim, max_h, max_w, rotary_base):
+        dim_h = embed_dim // 2
+        dim_w = embed_dim // 2
+        h_inv_freq = 1.0 / (
+            rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+        )
+        w_inv_freq = 1.0 / (
+            rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+        )
+        h_seq = torch.arange(max_h, dtype=h_inv_freq.dtype)
+        w_seq = torch.arange(max_w, dtype=w_inv_freq.dtype)
+        freqs_h = torch.outer(h_seq, h_inv_freq)
+        freqs_w = torch.outer(w_seq, w_inv_freq)
+        return (freqs_h, freqs_w)
+
+    def get_rope_embedding(self, height, width, target_h, target_w, device):
+        # Get pre-computed frequencies
+        freqs_h, freqs_w = self.rope
+
+        h_idx = torch.arange(height)
+        w_idx = torch.arange(width)
+        inner_h_idx = (h_idx * self.max_h) // target_h
+        inner_w_idx = (w_idx * self.max_w) // target_w
+
+        freqs_h = freqs_h[inner_h_idx].to(device)
+        freqs_w = freqs_w[inner_w_idx].to(device)
+
+        # Create position matrices for height and width
+        # [height, 1, dim//4] and [1, width, dim//4]
+        freqs_h = freqs_h.unsqueeze(1)
+        freqs_w = freqs_w.unsqueeze(0)
+        # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+        freqs_h = freqs_h.expand(height, width, -1)
+        freqs_w = freqs_w.expand(height, width, -1)
+
+        # Concatenate along last dimension to get [height, width, dim//2]
+        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+
+        freqs = torch.cat([freqs, freqs], dim=-1)  # [height, width, dim]
+        freqs = freqs.reshape(height * width, -1)
+
+        return freqs.cos(), freqs.sin()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
+        prompt_embeds: torch.Tensor,
+        negative_prompt_embeds: torch.Tensor | None,
         timestep: torch.LongTensor,
         original_size: torch.Tensor,
         target_size: torch.Tensor,
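
A standalone sketch of the frequency tables built by prepare_rope and the cos/sin grids returned by get_rope_embedding above, to make the shapes explicit. An attention_head_dim of 64 and a 32x32 latent patch grid are assumptions for illustration only.

    import torch

    def prepare_rope(embed_dim, max_h, max_w, rotary_base=10000):
        # Frequency tables for the height and width axes, each covering dim // 4 frequencies.
        dim_h = dim_w = embed_dim // 2
        h_inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32) / dim_h))
        w_inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32) / dim_w))
        freqs_h = torch.outer(torch.arange(max_h, dtype=torch.float32), h_inv_freq)  # [max_h, dim // 4]
        freqs_w = torch.outer(torch.arange(max_w, dtype=torch.float32), w_inv_freq)  # [max_w, dim // 4]
        return freqs_h, freqs_w

    max_h = max_w = 256
    freqs_h, freqs_w = prepare_rope(embed_dim=64, max_h=max_h, max_w=max_w)
    print(freqs_h.shape, freqs_w.shape)  # torch.Size([256, 16]) torch.Size([256, 16])

    # Same index math as get_rope_embedding with target_h == height and target_w == width:
    # each axis is sub-sampled from the 256-entry table, broadcast to a grid, concatenated,
    # duplicated, and flattened to one row per image token.
    height = width = 32
    h_idx = (torch.arange(height) * max_h) // height
    w_idx = (torch.arange(width) * max_w) // width
    freqs = torch.cat(
        [
            freqs_h[h_idx].unsqueeze(1).expand(height, width, -1),
            freqs_w[w_idx].unsqueeze(0).expand(height, width, -1),
        ],
        dim=-1,
    )  # [32, 32, 32]
    freqs = torch.cat([freqs, freqs], dim=-1).reshape(height * width, -1)
    print(freqs.cos().shape, freqs.sin().shape)  # torch.Size([1024, 64]) torch.Size([1024, 64])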
@@ -338,16 +385,27 @@ def forward(
             `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
                 The denoised latents using provided inputs as conditioning.
         """
-        height, width = hidden_states.shape[-2:]
-        text_seq_length = encoder_hidden_states.shape[1]
+        batch_size, channel, height, width = hidden_states.shape
+        patch_height, patch_width = height // self.config.patch_size, width // self.config.patch_size
+        do_cfg = negative_prompt_embeds is not None
 
-        hidden_states = self.patch_embed(
-            hidden_states, encoder_hidden_states
-        )  # takes care of adding positional embeddings too.
-        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+        if do_cfg:
+            assert batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0], "batch size mismatch in CFG mode"
+        else:
+            assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
+
+        hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
+            hidden_states, prompt_embeds, negative_prompt_embeds
+        )
 
-        encoder_hidden_states = hidden_states[:, :text_seq_length]
-        hidden_states = hidden_states[:, text_seq_length:]
+        encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
+
+        # prepare image_rotary_emb
+        image_rotary_emb = self.get_rope_embedding(
+            patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
+        )
+
+        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
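
For orientation only, a hedged sketch of how a caller in CFG mode would line up the duplicated latent batch with the new prompt_embeds / negative_prompt_embeds arguments so the batch-size assertion above holds. The pipeline file is not among the hunks shown here, so the helper name, guidance value, and output unwrapping are illustrative assumptions, not the actual pipeline code.

    import torch

    def run_cfg_step(transformer, latents, prompt_embeds, negative_prompt_embeds, t,
                     original_size, target_size, crop_coords, guidance_scale=5.0):
        # Duplicate latents so the conditional and unconditional halves share one forward pass;
        # batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0] then holds.
        latent_model_input = torch.cat([latents, latents], dim=0)
        noise_pred = transformer(
            hidden_states=latent_model_input,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            timestep=t,
            original_size=original_size,
            target_size=target_size,
            crop_coords=crop_coords,
        )
        # Output unwrapping omitted; assumes a plain tensor with the conditional half first,
        # matching the torch.cat([prompt_embeds, negative_prompt_embeds]) ordering above.
        noise_cond, noise_uncond = noise_pred.chunk(2, dim=0)
        return noise_uncond + guidance_scale * (noise_cond - noise_uncond)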
@@ -363,17 +421,19 @@ def custom_forward(*inputs):
                     create_custom_forward(block),
                     hidden_states,
                     encoder_hidden_states,
-                    emb,
+                    emb=emb,
+                    image_rotary_emb=image_rotary_emb,
                     **ckpt_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     emb=emb,
+                    image_rotary_emb=image_rotary_emb,
                 )
 
-        hidden_states = self.norm_out(hidden_states, emb)
+        hidden_states = self.norm_out(hidden_states, emb)  # result corresponds to final_layer_input in Megatron
         hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)
 
         # unpatchify