
Commit f826aec

[cogview4] implement CogView4 transformer block

Implement the CogView4 transformer block following the Megatron architecture:

- Add multi-modulate and multi-gate mechanisms for adaptive layer normalization
- Implement dual-stream attention with encoder-decoder structure
- Add feed-forward network with GELU activation
- Support rotary position embeddings for image tokens

The implementation follows the original CogView4 architecture while adapting it to work within the diffusers framework.
1 parent 3bd6d30 commit f826aec
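
For readers skimming the diff below, here is a minimal, self-contained sketch of the multi-modulate / multi-gate adaptive layer norm flow described in the commit message. It is illustrative only, not the diffusers module itself: the names (`adaln`, `img`, `txt`, `modulate`) and the toy dimensions are made up, the attention call is elided, and the 6 x 2 factor layout (shift/scale/gate for the attention and feed-forward sub-blocks, one set per image/text stream) mirrors the `view(batch, 6, 2, dim).permute(1, 2, 0, 3)` reshape added in the block's forward.

```python
import torch
import torch.nn as nn

batch, img_len, txt_len, dim = 2, 16, 4, 8

adaln = nn.Linear(dim, 6 * 2 * dim)   # stands in for norm1.linear in the real block
layernorm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)

time_embedding = torch.randn(batch, dim)
img = torch.randn(batch, img_len, dim)   # image-token stream
txt = torch.randn(batch, txt_len, dim)   # text-token (encoder) stream

# [6, 2, batch, dim]: axis 0 selects shift/scale/gate for the attention then FFN sub-blocks,
# axis 1 selects the stream (0 = image, 1 = text, following the chunk order used in the diff).
factors = adaln(time_embedding).view(batch, 6, 2, dim).permute(1, 2, 0, 3)


def modulate(x, shift, scale):
    # shift/scale are [batch, dim]; broadcast over the sequence axis
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


# Attention sub-block: layernorm -> multi-modulate -> (joint attention) -> multi-gate -> residual
res_img, res_txt = img, txt
img_n, txt_n = layernorm(img), layernorm(txt)
img_n = modulate(img_n, factors[0, 0], factors[1, 0])
txt_n = modulate(txt_n, factors[0, 1], factors[1, 1])
# ... dual-stream attention over torch.cat([txt_n, img_n], dim=1) would run here ...
img = res_img + factors[2, 0].unsqueeze(1) * img_n   # multi-gate, then residual
txt = res_txt + factors[2, 1].unsqueeze(1) * txt_n

# Feed-forward sub-block: identical pattern using factors[3], factors[4], factors[5],
# with a GELU feed-forward network in place of the attention.
```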

File tree

1 file changed: +116 -65 lines changed

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 116 additions & 65 deletions
@@ -17,13 +17,14 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.attention import FeedForward
 from ...models.attention_processor import (
     Attention,
     AttentionProcessor,
-    CogVideoXAttnProcessor2_0,
+    CogView4AttnProcessor,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
@@ -32,6 +33,7 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
 
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -60,6 +62,8 @@ def __init__(
         super().__init__()
 
         self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
+        self.adaln = self.norm1.linear
+        self.layernorm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
 
         self.attn1 = Attention(
             query_dim=dim,
@@ -69,66 +73,109 @@ def __init__(
             bias=True,
             qk_norm="layer_norm",
             elementwise_affine=False,
-            eps=1e-6,
-            processor=CogVideoXAttnProcessor2_0(),
+            eps=1e-5,
+            processor=CogView4AttnProcessor(),
         )
 
-        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
-        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
-
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
+    def multi_modulate(self, hidden_states, encoder_hidden_states, factors) -> torch.Tensor:
+        n_sample, n_type, h = factors[0].shape
+        shift_factor, scale_factor = factors[0].view(-1, h), factors[1].view(-1, h)
+
+        shift_factor_hidden_states, shift_factor_encoder_hidden_states = shift_factor.chunk(2, dim=0)
+        scale_factor_hidden_states, scale_factor_encoder_hidden_states = scale_factor.chunk(2, dim=0)
+
+        hidden_states = torch.addcmul(shift_factor_hidden_states, hidden_states, (1 + scale_factor_hidden_states))
+        encoder_hidden_states = torch.addcmul(
+            shift_factor_encoder_hidden_states, encoder_hidden_states, (1 + scale_factor_encoder_hidden_states)
+        )
+
+        return hidden_states, encoder_hidden_states
+
+    def multi_gate(self, hidden_states, encoder_hidden_states, factor):
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+        gate_factor = factor.view(-1, hidden_dim)
+        gate_factor_hidden_states, gate_factor_encoder_hidden_states = gate_factor.chunk(2, dim=0)
+        hidden_states = gate_factor_hidden_states * hidden_states
+        encoder_hidden_states = gate_factor_encoder_hidden_states * encoder_hidden_states
+        return hidden_states, encoder_hidden_states
+
     def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        emb: torch.Tensor,
+        time_embedding: torch.Tensor = None,
+        image_rotary_emb: torch.Tensor = None,
         **kwargs,
     ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
-
-        # norm & modulate
-        (
-            norm_hidden_states,
-            gate_msa,
-            shift_mlp,
-            scale_mlp,
-            gate_mlp,
-            norm_encoder_hidden_states,
-            c_gate_msa,
-            c_shift_mlp,
-            c_scale_mlp,
-            c_gate_mlp,
-        ) = self.norm1(hidden_states, encoder_hidden_states, emb)
-
-        # attention
-        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
-        )
-
-        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
-        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+        batch_size, encoder_hidden_states_len, hidden_dim = encoder_hidden_states.shape
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
-        # norm & modulate
-        norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        residual = hidden_states
 
-        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        # time_embedding embedding, [n_sample, h]
+        assert time_embedding is not None
 
-        # feed-forward
-        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
-        ff_output = self.ff(norm_hidden_states)
+        layernorm_factor = (
+            self.adaln(time_embedding)
+            .view(
+                time_embedding.shape[0],
+                6,
+                2,
+                hidden_states.shape[-1],
+            )
+            .permute(1, 2, 0, 3)
+            .contiguous()
+        )
 
-        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
-        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+        ##############################################################
+        # Optional Input Layer norm
+        hidden_states = self.layernorm(hidden_states)
+        hidden_states, encoder_hidden_states = self.multi_modulate(
+            hidden_states=hidden_states[:, encoder_hidden_states_len:],
+            encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
+            factors=(layernorm_factor[0], layernorm_factor[1]),
+        )
+        hidden_states, encoder_hidden_states = self.attn1(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+        )
+        hidden_states, encoder_hidden_states = self.multi_gate(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            factor=layernorm_factor[2],
+        )
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        hidden_states += residual
+
+        residual = hidden_states
+        ##############################################################
+        hidden_states = self.layernorm(hidden_states)
+        hidden_states, encoder_hidden_states = self.multi_modulate(
+            hidden_states=hidden_states[:, encoder_hidden_states_len:],
+            encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
+            factors=(layernorm_factor[3], layernorm_factor[4]),
+        )
+        hidden_states = self.ff(hidden_states)
+        encoder_hidden_states = self.ff(encoder_hidden_states)
+        hidden_states, encoder_hidden_states = self.multi_gate(
+            hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            factor=layernorm_factor[5],
+        )
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        hidden_states += residual
 
-        if hidden_states.dtype == torch.float16:
-            hidden_states = hidden_states.clip(-65504, 65504)
-        if encoder_hidden_states.dtype == torch.float16:
-            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+        ##############################################################
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :encoder_hidden_states_len],
+            hidden_states[:, encoder_hidden_states_len:],
+        )
         return hidden_states, encoder_hidden_states
 
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
     r"""
     Args:
@@ -335,7 +382,8 @@ def get_rope_embedding(self, height, width, target_h, target_w, device):
         freqs = torch.cat([freqs, freqs], dim=-1)  # [height, width, dim]
         freqs = freqs.reshape(height * width, -1)
 
-        return freqs.cos(), freqs.sin()
+        return freqs
+        # return freqs.cos(), freqs.sin()
 
     def forward(
         self,
@@ -391,28 +439,31 @@ def forward(
         image_rotary_emb = self.get_rope_embedding(
             patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
         )
+        # image_rotary_emb = torch.load("/home/lhy/code/cogview/rotary_pos_emb.pt")
+        # image_rotary_emb = image_rotary_emb[16:16+4096, 0, 0, :]
 
+        ######################
         # 2. Conditional embeddings
-        temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
-        temb_cond, temb_uncond = temb.chunk(2)
-        hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
-            hidden_states, prompt_embeds, negative_prompt_embeds
-        )
-        hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+        # temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+        # temb_cond, temb_uncond = temb.chunk(2)
+        # hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
+        #     hidden_states, prompt_embeds, negative_prompt_embeds
+        # )
+        # hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
+        # encoder_hidden_states_cond = prompt_embeds
+        # encoder_hidden_states_uncond = negative_prompt_embeds
+
+        prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
+        negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_16_32.pt")[None, ::]
+
+        hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")[None, ::]
+        hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")[None, ::]
+
+        temb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")[None, ::]
+        temb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")[None, ::]
+
         encoder_hidden_states_cond = prompt_embeds
         encoder_hidden_states_uncond = negative_prompt_embeds
-
-        ######################
-        # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
-        # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
-        # prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
-        # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
-        #
-        # hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
-        # hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
-        #
-        # emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
-        # emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
        ######################
 
         for index_block, block in enumerate(self.transformer_blocks):
@@ -423,13 +474,13 @@ def forward(
             hidden_states_cond, encoder_hidden_states_cond = block(
                 hidden_states=hidden_states_cond,
                 encoder_hidden_states=encoder_hidden_states_cond,
-                emb=temb_cond,
+                time_embedding=temb_cond,
                 image_rotary_emb=image_rotary_emb,
             )
             hidden_states_uncond, encoder_hidden_states_uncond = block(
                 hidden_states=hidden_states_uncond,
                 encoder_hidden_states=encoder_hidden_states_uncond,
-                emb=temb_uncond,
+                time_embedding=temb_uncond,
                 image_rotary_emb=image_rotary_emb,
             )
 
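The diff also changes `get_rope_embedding` to return the raw rotary angles instead of precomputed `(cos, sin)` pairs, leaving that step to the attention processor. Below is a hedged sketch, under assumed frequency conventions (the `theta` base, the per-axis split, and the helper name `rope_freqs_2d` are illustrative, not taken from the diff), of how 2D rotary angles for an image patch grid can be built so that the final `torch.cat([freqs, freqs], dim=-1)` and `reshape(height * width, -1)` steps match the lines shown above.

```python
import torch


def rope_freqs_2d(height: int, width: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Half of the rotary dimension is driven by the row index, half by the column index.
    dim_per_axis = dim // 2
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim_per_axis, 2).float() / dim_per_axis))

    rows = torch.arange(height).float()
    cols = torch.arange(width).float()
    freqs_h = torch.outer(rows, inv_freq)  # [height, dim_per_axis // 2]
    freqs_w = torch.outer(cols, inv_freq)  # [width,  dim_per_axis // 2]

    # Broadcast each axis over the full patch grid and join them.
    freqs = torch.cat(
        [
            freqs_h[:, None, :].expand(height, width, -1),
            freqs_w[None, :, :].expand(height, width, -1),
        ],
        dim=-1,
    )  # [height, width, dim // 2]

    freqs = torch.cat([freqs, freqs], dim=-1)  # [height, width, dim], as in the diff
    return freqs.reshape(height * width, -1)   # one angle vector per image patch


angles = rope_freqs_2d(height=4, width=4, dim=16)
cos, sin = angles.cos(), angles.sin()  # taken later, e.g. inside the attention processor
```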