Commit de274f3

[cogview4][WIP]: update final normalization in CogView4 transformer
Refactored the final normalization layer in the CogView4 transformer to use a separate LayerNorm followed by an AdaLN modulation (reusing the linear projection from norm_out) instead of the combined AdaLayerNormContinuous module. This is intended to match the original implementation but still needs verification against the reference implementation.
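
For context, a minimal sketch of the two paths, with made-up shapes (the names x, temb, inner_dim, and time_embed_dim below are illustrative, not the module's actual code). Note that diffusers' AdaLayerNormContinuous applies a SiLU before its linear projection and chunks the result as (scale, shift), while the split path here reuses norm_out.linear with no SiLU and the chunk order reversed to (shift, scale); that difference is part of what needs verifying:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative dimensions only; the real values come from the transformer config.
inner_dim, time_embed_dim = 64, 32
x = torch.randn(2, 16, inner_dim)      # token sequence (batch, seq, dim)
temb = torch.randn(2, time_embed_dim)  # conditioning (time) embedding

linear = nn.Linear(time_embed_dim, 2 * inner_dim)
norm = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-5)

# Combined path (AdaLayerNormContinuous-style): SiLU -> linear -> chunk as (scale, shift).
scale, shift = linear(F.silu(temb)).chunk(2, dim=-1)
out_combined = norm(x) * (1 + scale[:, None, :]) + shift[:, None, :]

# Split path (this commit): plain LayerNorm, then modulation from the same linear,
# but without the SiLU and with the chunk halves taken as (shift, scale).
shift2, scale2 = linear(temb).chunk(2, dim=-1)
out_split = norm(x) * (1 + scale2[:, None, :]) + shift2[:, None, :]

# With identical weights the two paths generally disagree, both because of the
# missing SiLU and because the chunk halves are swapped.
print(torch.allclose(out_combined, out_split))  # False in general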
1 parent 6a3a07f commit de274f3

1 file changed: +29 −9 lines

src/diffusers/models/transformers/transformer_cogview4.py

@@ -170,8 +170,8 @@ def forward(
 
         ##############################################################
         hidden_states, encoder_hidden_states = (
-            hidden_states[:, :encoder_hidden_states_len],
             hidden_states[:, encoder_hidden_states_len:],
+            hidden_states[:, :encoder_hidden_states_len],
         )
         return hidden_states, encoder_hidden_states
 
@@ -240,6 +240,8 @@ def __init__(
             embed_dim=self.config.attention_head_dim, max_h=self.max_h, max_w=self.max_w, rotary_base=10000
         )
 
+        self.layernorm = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-5)
+
         self.patch_embed = CogView4PatchEmbed(
             in_channels=in_channels,
             hidden_size=self.inner_dim,
@@ -267,11 +269,15 @@ def __init__(
             ]
         )
 
+        ######################################
         self.norm_out = AdaLayerNormContinuous(
             embedding_dim=self.inner_dim,
             conditioning_embedding_dim=time_embed_dim,
             elementwise_affine=False,
         )
+        self.adaln_final = self.norm_out.linear
+        ######################################
+
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
@@ -484,14 +490,28 @@ def forward(
                 image_rotary_emb=image_rotary_emb,
             )
 
-        hidden_states_cond, encoder_hidden_states_cond = (
-            self.norm_out(hidden_states_cond, temb_cond),
-            self.norm_out(encoder_hidden_states_cond, temb_cond),
-        )
-        hidden_states_uncond, encoder_hidden_states_uncond = (
-            self.norm_out(hidden_states_uncond, temb_uncond),
-            self.norm_out(encoder_hidden_states_uncond, temb_uncond),
-        )
+        #################################################
+        # hidden_states_cond, encoder_hidden_states_cond = (
+        #     self.norm_out(hidden_states_cond, temb_cond),
+        #     self.norm_out(encoder_hidden_states_cond, temb_cond),
+        # )
+        # hidden_states_uncond, encoder_hidden_states_uncond = (
+        #     self.norm_out(hidden_states_uncond, temb_uncond),
+        #     self.norm_out(encoder_hidden_states_uncond, temb_uncond),
+        # )
+
+        hidden_states_cond = self.layernorm(hidden_states_cond)
+        hidden_states_uncond = self.layernorm(hidden_states_uncond)
+        encoder_hidden_states_cond = self.layernorm(encoder_hidden_states_cond)
+        encoder_hidden_states_uncond = self.layernorm(encoder_hidden_states_uncond)
+
+        shift_cond, scale_cond = self.adaln_final(temb_cond).chunk(2, dim=-1)
+        shift_uncond, scale_uncond = self.adaln_final(temb_uncond).chunk(2, dim=-1)
+
+        hidden_states_cond = hidden_states_cond * (1 + scale_cond) + shift_cond
+        hidden_states_uncond = hidden_states_uncond * (1 + scale_uncond) + shift_uncond
+        #################################################
+
         hidden_states_cond = self.proj_out(hidden_states_cond)
         hidden_states_uncond = self.proj_out(hidden_states_uncond)
 
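Separately, the first hunk flips which slice of the packed token sequence is handed back as hidden_states versus encoder_hidden_states. A toy sketch of the unpacking it now performs, assuming (as the slicing implies) that text tokens precede image tokens in the joint sequence; all shapes are made up:

import torch

# Made-up shapes; the text-first layout is what the slicing in the first hunk implies.
batch, text_len, image_len, dim = 2, 8, 24, 64
encoder_hidden_states_len = text_len
joint = torch.randn(batch, text_len + image_len, dim)

# After this commit, image tokens come back as hidden_states and text tokens
# as encoder_hidden_states; before it, the two slices were returned swapped.
hidden_states = joint[:, encoder_hidden_states_len:]          # (2, 24, 64) image tokens
encoder_hidden_states = joint[:, :encoder_hidden_states_len]  # (2, 8, 64) text tokens
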
0 commit comments
