@@ -170,8 +170,8 @@ def forward(
170170
171171 ##############################################################
172172 hidden_states , encoder_hidden_states = (
173- hidden_states [:, :encoder_hidden_states_len ],
174173 hidden_states [:, encoder_hidden_states_len :],
174+ hidden_states [:, :encoder_hidden_states_len ],
175175 )
176176 return hidden_states , encoder_hidden_states
177177
@@ -240,6 +240,8 @@ def __init__(
240240 embed_dim = self .config .attention_head_dim , max_h = self .max_h , max_w = self .max_w , rotary_base = 10000
241241 )
242242
243+ self .layernorm = nn .LayerNorm (self .inner_dim , elementwise_affine = False , eps = 1e-5 )
244+
243245 self .patch_embed = CogView4PatchEmbed (
244246 in_channels = in_channels ,
245247 hidden_size = self .inner_dim ,
@@ -267,11 +269,15 @@ def __init__(
267269 ]
268270 )
269271
272+ ######################################
270273 self .norm_out = AdaLayerNormContinuous (
271274 embedding_dim = self .inner_dim ,
272275 conditioning_embedding_dim = time_embed_dim ,
273276 elementwise_affine = False ,
274277 )
278+ self .adaln_final = self .norm_out .linear
279+ ######################################
280+
275281 self .proj_out = nn .Linear (self .inner_dim , patch_size * patch_size * self .out_channels , bias = True )
276282
277283 self .gradient_checkpointing = False
@@ -484,14 +490,28 @@ def forward(
484490 image_rotary_emb = image_rotary_emb ,
485491 )
486492
487- hidden_states_cond , encoder_hidden_states_cond = (
488- self .norm_out (hidden_states_cond , temb_cond ),
489- self .norm_out (encoder_hidden_states_cond , temb_cond ),
490- )
491- hidden_states_uncond , encoder_hidden_states_uncond = (
492- self .norm_out (hidden_states_uncond , temb_uncond ),
493- self .norm_out (encoder_hidden_states_uncond , temb_uncond ),
494- )
493+ #################################################
494+ # hidden_states_cond, encoder_hidden_states_cond = (
495+ # self.norm_out(hidden_states_cond, temb_cond),
496+ # self.norm_out(encoder_hidden_states_cond, temb_cond),
497+ # )
498+ # hidden_states_uncond, encoder_hidden_states_uncond = (
499+ # self.norm_out(hidden_states_uncond, temb_uncond),
500+ # self.norm_out(encoder_hidden_states_uncond, temb_uncond),
501+ # )
502+
503+ hidden_states_cond = self .layernorm (hidden_states_cond )
504+ hidden_states_uncond = self .layernorm (hidden_states_uncond )
505+ encoder_hidden_states_cond = self .layernorm (encoder_hidden_states_cond )
506+ encoder_hidden_states_uncond = self .layernorm (encoder_hidden_states_uncond )
507+
508+ shift_cond , scale_cond = self .adaln_final (temb_cond ).chunk (2 , dim = - 1 )
509+ shift_uncond , scale_uncond = self .adaln_final (temb_uncond ).chunk (2 , dim = - 1 )
510+
511+ hidden_states_cond = hidden_states_cond * (1 + scale_cond ) + shift_cond
512+ hidden_states_uncond = hidden_states_uncond * (1 + scale_uncond ) + shift_uncond
513+ #################################################
514+
495515 hidden_states_cond = self .proj_out (hidden_states_cond )
496516 hidden_states_uncond = self .proj_out (hidden_states_uncond )
497517