@@ -176,6 +176,12 @@ def forward(
         return hidden_states, encoder_hidden_states
 
 
+def swap_scale_shift(weight, dim):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
     r"""
     Args:
@@ -276,7 +282,10 @@ def __init__(
             elementwise_affine=False,
         )
         self.adaln_final = self.norm_out.linear
-        ######################################
+        # with torch.no_grad():
+        # w = self.norm_out.linear.weight.data.clone()
+        # w_swapped = swap_scale_shift(w, dim=0)
+        # self.adaln_final.weight.data.copy_(w_swapped)
 
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
@@ -445,6 +454,7 @@ def forward(
         image_rotary_emb = self.get_rope_embedding(
             patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
         )
+        ## TODO: @Oleehy Remove it after debugging
         # image_rotary_emb = torch.load("/home/lhy/code/cogview/rotary_pos_emb.pt")
         # image_rotary_emb = image_rotary_emb[16:16+4096, 0, 0, :]
 
@@ -457,6 +467,7 @@ def forward(
         )
         hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
 
+        # Todo: @Oleehy Remove it after debugging
         # prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
         # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
         #
@@ -488,27 +499,15 @@ def forward(
                 image_rotary_emb=image_rotary_emb,
             )
 
-        #################################################
-        # hidden_states_cond, encoder_hidden_states_cond = (
-        # self.norm_out(hidden_states_cond, temb_cond),
-        # self.norm_out(encoder_hidden_states_cond, temb_cond),
-        # )
-        # hidden_states_uncond, encoder_hidden_states_uncond = (
-        # self.norm_out(hidden_states_uncond, temb_uncond),
-        # self.norm_out(encoder_hidden_states_uncond, temb_uncond),
-        # )
-
-        hidden_states_cond = self.layernorm(hidden_states_cond)
-        hidden_states_uncond = self.layernorm(hidden_states_uncond)
-        encoder_hidden_states_cond = self.layernorm(encoder_hidden_states_cond)
-        encoder_hidden_states_uncond = self.layernorm(encoder_hidden_states_uncond)
-
-        shift_cond, scale_cond = self.adaln_final(temb_cond).chunk(2, dim=-1)
-        shift_uncond, scale_uncond = self.adaln_final(temb_uncond).chunk(2, dim=-1)
-
-        hidden_states_cond = hidden_states_cond * (1 + scale_cond) + shift_cond
-        hidden_states_uncond = hidden_states_uncond * (1 + scale_uncond) + shift_uncond
-        #################################################
+        # Todo: @Oleehy Check if this is the right implementation
+        hidden_states_cond, encoder_hidden_states_cond = (
+            self.norm_out(hidden_states_cond, temb_cond),
+            self.norm_out(encoder_hidden_states_cond, temb_cond),
+        )
+        hidden_states_uncond, encoder_hidden_states_uncond = (
+            self.norm_out(hidden_states_uncond, temb_uncond),
+            self.norm_out(encoder_hidden_states_uncond, temb_uncond),
+        )
 
         hidden_states_cond = self.proj_out(hidden_states_cond)
         hidden_states_uncond = self.proj_out(hidden_states_uncond)
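For reference, a minimal standalone sketch (not part of the diff) of what the new swap_scale_shift helper does: it splits a fused [shift; scale] weight in half along dim 0 and re-concatenates it as [scale; shift], which is the reordering the commented-out __init__ snippet above would apply to self.norm_out.linear's AdaLN weight. The tensor shapes below are illustrative assumptions, not the model's real dimensions.

import torch

def swap_scale_shift(weight, dim):
    # Mirrors the helper added in the diff: `dim` is accepted but the split
    # is done along dim 0, matching how the fused weight is laid out.
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

# Illustrative shapes only (assumed for the example).
inner_dim, cond_dim = 4, 3
fused = torch.arange(2 * inner_dim * cond_dim, dtype=torch.float32).reshape(2 * inner_dim, cond_dim)
swapped = swap_scale_shift(fused, dim=0)

# The two halves trade places.
assert torch.equal(swapped[:inner_dim], fused[inner_dim:])
assert torch.equal(swapped[inner_dim:], fused[:inner_dim])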