
Commit f1ccdd2

Update transformer_cogview4.py
1 parent ebbaa5b · commit f1ccdd2

1 file changed: +21 -22 lines changed


src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 21 additions & 22 deletions
@@ -176,6 +176,12 @@ def forward(
         return hidden_states, encoder_hidden_states


+def swap_scale_shift(weight, dim):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
     r"""
     Args:
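
For reference, the new module-level helper simply exchanges the two halves of a packed modulation weight along dim 0 (the dim argument is accepted but unused; the split is hard-coded). A minimal sketch with a toy tensor, illustrative only:

import torch

def swap_scale_shift(weight, dim):
    # Split the packed tensor into two halves along dim 0 and swap their order.
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

w = torch.arange(4.0)               # packed as [shift_0, shift_1, scale_0, scale_1]
print(swap_scale_shift(w, dim=0))   # tensor([2., 3., 0., 1.])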
@@ -276,7 +282,10 @@ def __init__(
             elementwise_affine=False,
         )
         self.adaln_final = self.norm_out.linear
-        ######################################
+        # with torch.no_grad():
+        #     w = self.norm_out.linear.weight.data.clone()
+        #     w_swapped = swap_scale_shift(w, dim=0)
+        #     self.adaln_final.weight.data.copy_(w_swapped)

         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
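
The commented-out block above sketches how that helper would remap the pretrained AdaLN output projection in place. A runnable sketch of the same remap under stated assumptions (embedding_dim and conditioning_dim are hypothetical placeholders, not the model's real config):

import torch
import torch.nn as nn

def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

embedding_dim, conditioning_dim = 8, 4                     # hypothetical sizes
linear = nn.Linear(conditioning_dim, 2 * embedding_dim)    # stands in for norm_out.linear

with torch.no_grad():
    w = linear.weight.data.clone()                         # rows packed as [shift; scale]
    linear.weight.data.copy_(swap_scale_shift(w, dim=0))   # rows now packed as [scale; shift]

Note that the commented code only touches the weight; a full remap would presumably need to reorder the bias the same way.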

@@ -445,6 +454,7 @@ def forward(
         image_rotary_emb = self.get_rope_embedding(
             patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
         )
+        ## TODO: @Oleehy Remove it after debugging
         # image_rotary_emb = torch.load("/home/lhy/code/cogview/rotary_pos_emb.pt")
         # image_rotary_emb = image_rotary_emb[16:16+4096, 0, 0, :]

@@ -457,6 +467,7 @@ def forward(
         )
         hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)

+        # Todo: @Oleehy Remove it after debugging
         # prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
         # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
         #
@@ -488,27 +499,15 @@ def forward(
             image_rotary_emb=image_rotary_emb,
         )

-        #################################################
-        # hidden_states_cond, encoder_hidden_states_cond = (
-        #     self.norm_out(hidden_states_cond, temb_cond),
-        #     self.norm_out(encoder_hidden_states_cond, temb_cond),
-        # )
-        # hidden_states_uncond, encoder_hidden_states_uncond = (
-        #     self.norm_out(hidden_states_uncond, temb_uncond),
-        #     self.norm_out(encoder_hidden_states_uncond, temb_uncond),
-        # )
-
-        hidden_states_cond = self.layernorm(hidden_states_cond)
-        hidden_states_uncond = self.layernorm(hidden_states_uncond)
-        encoder_hidden_states_cond = self.layernorm(encoder_hidden_states_cond)
-        encoder_hidden_states_uncond = self.layernorm(encoder_hidden_states_uncond)
-
-        shift_cond, scale_cond = self.adaln_final(temb_cond).chunk(2, dim=-1)
-        shift_uncond, scale_uncond = self.adaln_final(temb_uncond).chunk(2, dim=-1)
-
-        hidden_states_cond = hidden_states_cond * (1 + scale_cond) + shift_cond
-        hidden_states_uncond = hidden_states_uncond * (1 + scale_uncond) + shift_uncond
-        #################################################
+        # Todo: @Oleehy Check if this is the right implementation
+        hidden_states_cond, encoder_hidden_states_cond = (
+            self.norm_out(hidden_states_cond, temb_cond),
+            self.norm_out(encoder_hidden_states_cond, temb_cond),
+        )
+        hidden_states_uncond, encoder_hidden_states_uncond = (
+            self.norm_out(hidden_states_uncond, temb_uncond),
+            self.norm_out(encoder_hidden_states_uncond, temb_uncond),
+        )

         hidden_states_cond = self.proj_out(hidden_states_cond)
         hidden_states_uncond = self.proj_out(hidden_states_uncond)
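
The removed block spelled the final step out by hand: layer-normalize all four streams, project temb through self.adaln_final, unpack (shift, scale), and modulate only the two image streams. The replacement delegates the whole normalize-then-modulate step to self.norm_out and also runs the encoder streams through it, which is what the Todo asks to verify. A toy sketch of the ordering question at stake, assuming a LayerNorm plus a linear modulation projection (activation details omitted); if norm_out is diffusers' AdaLayerNormContinuous, it unpacks (scale, shift), the reverse of the removed code, which is exactly the mismatch swap_scale_shift is meant to absorb:

import torch
import torch.nn as nn

batch, seq, dim, cond = 2, 16, 8, 8                      # toy shapes, illustrative only
x = torch.randn(batch, seq, dim)
temb = torch.randn(batch, cond)

layernorm = nn.LayerNorm(dim, elementwise_affine=False)
adaln = nn.Linear(cond, 2 * dim)

# Removed path: unpack (shift, scale) from the projection.
shift, scale = adaln(temb).chunk(2, dim=-1)
out_manual = layernorm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

# AdaLayerNormContinuous-style path: identical math, but unpack (scale, shift).
scale2, shift2 = adaln(temb).chunk(2, dim=-1)
out_fused = layernorm(x) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)

# The two only agree if the projection's output halves are swapped first,
# e.g. by the swap_scale_shift remap commented out in __init__.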
