1717
1818import torch
1919import torch .nn as nn
20+ import torch .nn .functional as F
2021
2122from ...configuration_utils import ConfigMixin , register_to_config
2223from ...models .attention import FeedForward
2324from ...models .attention_processor import (
2425 Attention ,
2526 AttentionProcessor ,
26- CogVideoXAttnProcessor2_0 ,
27+ CogView4AttnProcessor ,
2728)
2829from ...models .modeling_utils import ModelMixin
2930from ...models .normalization import AdaLayerNormContinuous
3233from ..modeling_outputs import Transformer2DModelOutput
3334from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
3435
36+
3537logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
3638
3739
@@ -60,6 +62,8 @@ def __init__(
6062 super ().__init__ ()
6163
6264 self .norm1 = CogView3PlusAdaLayerNormZeroTextImage (embedding_dim = time_embed_dim , dim = dim )
65+ self .adaln = self .norm1 .linear
66+ self .layernorm = nn .LayerNorm (dim , elementwise_affine = False , eps = 1e-5 )
6367
6468 self .attn1 = Attention (
6569 query_dim = dim ,
@@ -69,66 +73,109 @@ def __init__(
6973 bias = True ,
7074 qk_norm = "layer_norm" ,
7175 elementwise_affine = False ,
72- eps = 1e-6 ,
73- processor = CogVideoXAttnProcessor2_0 (),
76+ eps = 1e-5 ,
77+ processor = CogView4AttnProcessor (),
7478 )
7579
76- self .norm2 = nn .LayerNorm (dim , elementwise_affine = False , eps = 1e-5 )
77- self .norm2_context = nn .LayerNorm (dim , elementwise_affine = False , eps = 1e-5 )
78-
7980 self .ff = FeedForward (dim = dim , dim_out = dim , activation_fn = "gelu-approximate" )
8081
def multi_modulate(self, hidden_states, encoder_hidden_states, factors) -> "tuple[torch.Tensor, torch.Tensor]":
    """Apply adaLN shift/scale modulation to the image and text streams.

    Args:
        hidden_states: image-token states, shape ``(batch, img_seq, h)``.
        encoder_hidden_states: text-token states, shape ``(batch, txt_seq, h)``.
        factors: pair ``(shift, scale)`` of conditioning tensors, each with
            last dimension ``h`` and a leading layout that flattens to
            ``(2 * n, h)`` — first half for the image stream, second half for
            the text stream (assumes the hidden/text factor ordering produced
            by the block's adaLN projection — TODO confirm for batch > 1).

    Returns:
        Tuple of the modulated ``(hidden_states, encoder_hidden_states)``.
    """
    hidden_dim = factors[0].shape[-1]
    shift_factor = factors[0].reshape(-1, hidden_dim)
    scale_factor = factors[1].reshape(-1, hidden_dim)

    # First half of the flattened factors modulates the image stream,
    # second half the text stream.
    shift_hidden, shift_encoder = shift_factor.chunk(2, dim=0)
    scale_hidden, scale_encoder = scale_factor.chunk(2, dim=0)

    # x <- shift + x * (1 + scale), fused via addcmul.
    hidden_states = torch.addcmul(shift_hidden, hidden_states, 1 + scale_hidden)
    encoder_hidden_states = torch.addcmul(shift_encoder, encoder_hidden_states, 1 + scale_encoder)

    return hidden_states, encoder_hidden_states
95+
def multi_gate(self, hidden_states, encoder_hidden_states, factor):
    """Scale the image and text streams by their per-stream gate factors.

    Args:
        hidden_states: image-token states, shape ``(batch, img_seq, h)``.
        encoder_hidden_states: text-token states, shape ``(batch, txt_seq, h)``.
        factor: gate tensor with last dimension ``h`` that flattens to
            ``(2 * n, h)`` — first half gates the image stream, second half
            the text stream.

    Returns:
        Tuple of the gated ``(hidden_states, encoder_hidden_states)``.
    """
    hidden_dim = hidden_states.shape[-1]
    gate_factor = factor.reshape(-1, hidden_dim)
    gate_hidden, gate_encoder = gate_factor.chunk(2, dim=0)
    hidden_states = gate_hidden * hidden_states
    encoder_hidden_states = gate_encoder * encoder_hidden_states
    return hidden_states, encoder_hidden_states
103+
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    time_embedding: torch.Tensor = None,
    image_rotary_emb: torch.Tensor = None,
    **kwargs,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Run one transformer block over the joint (text, image) token sequence.

    The text and image tokens are concatenated as ``[text | image]``, passed
    through a modulated-attention sub-block and a modulated-feed-forward
    sub-block (each with its own residual connection), then split back apart.

    Args:
        hidden_states: image-token states, shape ``(batch, img_seq, dim)``.
        encoder_hidden_states: text-token states, shape ``(batch, txt_seq, dim)``.
        time_embedding: per-sample conditioning embedding fed to the adaLN
            projection. Required; defaults to ``None`` only for signature
            compatibility.
        image_rotary_emb: rotary position embedding forwarded to attention.

    Returns:
        Tuple ``(hidden_states, encoder_hidden_states)`` — the updated image
        and text streams, in the same order as the inputs.

    Raises:
        ValueError: if ``time_embedding`` is not provided.
    """
    # Input validation via exception, not `assert` (asserts vanish under -O).
    if time_embedding is None:
        raise ValueError("`time_embedding` must be provided to CogView4 transformer blocks.")

    encoder_hidden_states_len = encoder_hidden_states.shape[1]
    # Process both modalities as one joint sequence: [text | image].
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    residual = hidden_states

    # Project the time embedding into 6 factors x (shift/scale or gate pairs),
    # laid out as [factor, pair, batch, dim] after the permute.
    layernorm_factor = (
        self.adaln(time_embedding)
        .view(time_embedding.shape[0], 6, 2, hidden_states.shape[-1])
        .permute(1, 2, 0, 3)
        .contiguous()
    )

    # 1. Attention sub-block: norm -> modulate -> attention -> gate -> residual.
    hidden_states = self.layernorm(hidden_states)
    hidden_states, encoder_hidden_states = self.multi_modulate(
        hidden_states=hidden_states[:, encoder_hidden_states_len:],
        encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
        factors=(layernorm_factor[0], layernorm_factor[1]),
    )
    hidden_states, encoder_hidden_states = self.attn1(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states, encoder_hidden_states = self.multi_gate(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        factor=layernorm_factor[2],
    )
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    hidden_states = hidden_states + residual

    # 2. Feed-forward sub-block: norm -> modulate -> FF -> gate -> residual.
    residual = hidden_states
    hidden_states = self.layernorm(hidden_states)
    hidden_states, encoder_hidden_states = self.multi_modulate(
        hidden_states=hidden_states[:, encoder_hidden_states_len:],
        encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
        factors=(layernorm_factor[3], layernorm_factor[4]),
    )
    hidden_states = self.ff(hidden_states)
    encoder_hidden_states = self.ff(encoder_hidden_states)
    hidden_states, encoder_hidden_states = self.multi_gate(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        factor=layernorm_factor[5],
    )
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    hidden_states = hidden_states + residual

    # Split the joint sequence back into its two streams. The concatenation
    # order throughout this method is [text | image], so the leading
    # `encoder_hidden_states_len` tokens are the text stream.
    # NOTE(fix): the previous code returned these two slices swapped,
    # handing the text tokens back as `hidden_states` and vice versa.
    encoder_hidden_states = hidden_states[:, :encoder_hidden_states_len]
    hidden_states = hidden_states[:, encoder_hidden_states_len:]
    return hidden_states, encoder_hidden_states
131177
178+
132179class CogView4Transformer2DModel (ModelMixin , ConfigMixin ):
133180 r"""
134181 Args:
@@ -335,7 +382,8 @@ def get_rope_embedding(self, height, width, target_h, target_w, device):
335382 freqs = torch .cat ([freqs , freqs ], dim = - 1 ) # [height, width, dim]
336383 freqs = freqs .reshape (height * width , - 1 )
337384
338- return freqs .cos (), freqs .sin ()
385+ return freqs
386+ # return freqs.cos(), freqs.sin()
339387
340388 def forward (
341389 self ,
@@ -391,28 +439,31 @@ def forward(
391439 image_rotary_emb = self .get_rope_embedding (
392440 patch_height , patch_width , target_h = patch_height , target_w = patch_width , device = hidden_states .device
393441 )
442+ # image_rotary_emb = torch.load("/home/lhy/code/cogview/rotary_pos_emb.pt")
443+ # image_rotary_emb = image_rotary_emb[16:16+4096, 0, 0, :]
394444
445+ ######################
395446 # 2. Conditional embeddings
396- temb = self .time_condition_embed (timestep , original_size , target_size , crop_coords , hidden_states .dtype )
397- temb_cond , temb_uncond = temb .chunk (2 )
398- hidden_states , prompt_embeds , negative_prompt_embeds = self .patch_embed (
399- hidden_states , prompt_embeds , negative_prompt_embeds
400- )
401- hidden_states_cond , hidden_states_uncond = hidden_states .chunk (2 )
447+ # temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
448+ # temb_cond, temb_uncond = temb.chunk(2)
449+ # hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
450+ # hidden_states, prompt_embeds, negative_prompt_embeds
451+ # )
452+ # hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
453+ # encoder_hidden_states_cond = prompt_embeds
454+ # encoder_hidden_states_uncond = negative_prompt_embeds
455+
456+ prompt_embeds = torch .load ("/home/lhy/code/cogview/cp_condition_0_16.pt" )[None , ::]
457+ negative_prompt_embeds = torch .load ("/home/lhy/code/cogview/cp_condition_16_32.pt" )[None , ::]
458+
459+ hidden_states_cond = torch .load ("/home/lhy/code/cogview/cp_vision_input_0_4096.pt" )[None , ::]
460+ hidden_states_uncond = torch .load ("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt" )[None , ::]
461+
462+ temb_cond = torch .load ("/home/lhy/code/cogview/time_embedding_0_1.pt" )[None , ::]
463+ temb_uncond = torch .load ("/home/lhy/code/cogview/time_embedding_1_2.pt" )[None , ::]
464+
402465 encoder_hidden_states_cond = prompt_embeds
403466 encoder_hidden_states_uncond = negative_prompt_embeds
404-
405- ######################
406- # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
407- # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
408- # prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
409- # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
410- #
411- # hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
412- # hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
413- #
414- # emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
415- # emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
416467 ######################
417468
418469 for index_block , block in enumerate (self .transformer_blocks ):
@@ -423,13 +474,13 @@ def forward(
423474 hidden_states_cond , encoder_hidden_states_cond = block (
424475 hidden_states = hidden_states_cond ,
425476 encoder_hidden_states = encoder_hidden_states_cond ,
426- emb = temb_cond ,
477+ time_embedding = temb_cond ,
427478 image_rotary_emb = image_rotary_emb ,
428479 )
429480 hidden_states_uncond , encoder_hidden_states_uncond = block (
430481 hidden_states = hidden_states_uncond ,
431482 encoder_hidden_states = encoder_hidden_states_uncond ,
432- emb = temb_uncond ,
483+ time_embedding = temb_uncond ,
433484 image_rotary_emb = image_rotary_emb ,
434485 )
435486
0 commit comments