 from ...utils import logging
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
-from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -81,6 +80,53 @@ def forward(
         return hidden_states, prompt_embeds, negative_prompt_embeds
 
 
+class CogView4AdaLayerNormZero(nn.Module):
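+    r"""
+    Adaptive LayerNorm for CogView4: the timestep embedding is projected to ``12 * dim`` values,
+    giving shift/scale/gate triplets for the attention and feedforward stages of both the image
+    (`hidden_states`) and text (`encoder_hidden_states`) streams.
+    """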
+    def __init__(self, embedding_dim: int, dim: int) -> None:
+        super().__init__()
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+
+    def forward(
+        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        norm_hidden_states = self.norm(hidden_states)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+
+        emb = self.linear(temb)
+        (
+            shift_msa,
+            c_shift_msa,
+            scale_msa,
+            c_scale_msa,
+            gate_msa,
+            c_gate_msa,
+            shift_mlp,
+            c_shift_mlp,
+            scale_mlp,
+            c_scale_mlp,
+            gate_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+
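+        # Pre-attention modulation: scale/shift the normalized streams here; the gates and the
+        # MLP shift/scale/gate values are returned for the transformer block to apply later.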
+        hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+        encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        )
+
+
 class CogView4AttnProcessor:
     """
     Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
@@ -89,7 +135,7 @@ class CogView4AttnProcessor:
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
@@ -153,10 +199,8 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
-        self.adaln = self.norm1.linear
-        self.layernorm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
-
+        # 1. Attention
+        self.norm1 = CogView4AdaLayerNormZero(time_embed_dim, dim)
         self.attn1 = Attention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -169,97 +213,52 @@ def __init__(
             processor=CogView4AttnProcessor(),
         )
 
+        # 2. Feedforward
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-    def multi_modulate(self, hidden_states, encoder_hidden_states, factors) -> torch.Tensor:
-        _, _, h = factors[0].shape
-        shift_factor, scale_factor = factors[0].view(-1, h), factors[1].view(-1, h)
-
-        shift_factor_hidden_states, shift_factor_encoder_hidden_states = shift_factor.chunk(2, dim=0)
-        scale_factor_hidden_states, scale_factor_encoder_hidden_states = scale_factor.chunk(2, dim=0)
-        shift_factor_hidden_states = shift_factor_hidden_states.unsqueeze(1)
-        scale_factor_hidden_states = scale_factor_hidden_states.unsqueeze(1)
-        hidden_states = torch.addcmul(shift_factor_hidden_states, hidden_states, (1 + scale_factor_hidden_states))
-
-        shift_factor_encoder_hidden_states = shift_factor_encoder_hidden_states.unsqueeze(1)
-        scale_factor_encoder_hidden_states = scale_factor_encoder_hidden_states.unsqueeze(1)
-        encoder_hidden_states = torch.addcmul(
-            shift_factor_encoder_hidden_states, encoder_hidden_states, (1 + scale_factor_encoder_hidden_states)
-        )
-
-        return hidden_states, encoder_hidden_states
-
-    def multi_gate(self, hidden_states, encoder_hidden_states, factor):
-        _, _, hidden_dim = hidden_states.shape
-        gate_factor = factor.view(-1, hidden_dim)
-        gate_factor_hidden_states, gate_factor_encoder_hidden_states = gate_factor.chunk(2, dim=0)
-        gate_factor_hidden_states = gate_factor_hidden_states.unsqueeze(1)
-        gate_factor_encoder_hidden_states = gate_factor_encoder_hidden_states.unsqueeze(1)
-        hidden_states = gate_factor_hidden_states * hidden_states
-        encoder_hidden_states = gate_factor_encoder_hidden_states * encoder_hidden_states
-
-        return hidden_states, encoder_hidden_states
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        batch_size, encoder_hidden_states_len, hidden_dim = encoder_hidden_states.shape
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-        residual = hidden_states
-        layernorm_factor = (
-            self.adaln(temb)
-            .view(
-                temb.shape[0],
-                6,
-                2,
-                hidden_states.shape[-1],
-            )
-            .permute(1, 2, 0, 3)
-            .contiguous()
-        )
-        hidden_states = self.layernorm(hidden_states)
-        hidden_states, encoder_hidden_states = self.multi_modulate(
-            hidden_states=hidden_states[:, encoder_hidden_states_len:],
-            encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
-            factors=(layernorm_factor[0], layernorm_factor[1]),
-        )
-        hidden_states, encoder_hidden_states = self.attn1(
-            hidden_states=hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
+        # 1. Timestep conditioning
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            norm_encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = self.norm1(hidden_states, encoder_hidden_states, temb)
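+        # norm1 modulates the attention inputs; the gates and MLP shift/scale values are applied below.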
+
+        # 2. Attention
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
-        hidden_states, encoder_hidden_states = self.multi_gate(
-            hidden_states=hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            factor=layernorm_factor[2],
-        )
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-        hidden_states += residual
-
-        residual = hidden_states
-        hidden_states = self.layernorm(hidden_states)
-        hidden_states, encoder_hidden_states = self.multi_modulate(
-            hidden_states=hidden_states[:, encoder_hidden_states_len:],
-            encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
-            factors=(layernorm_factor[3], layernorm_factor[4]),
-        )
-        hidden_states = self.ff(hidden_states)
-        encoder_hidden_states = self.ff(encoder_hidden_states)
-        hidden_states, encoder_hidden_states = self.multi_gate(
-            hidden_states=hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            factor=layernorm_factor[5],
-        )
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-        hidden_states += residual
-        hidden_states, encoder_hidden_states = (
-            hidden_states[:, encoder_hidden_states_len:],
-            hidden_states[:, :encoder_hidden_states_len],
-        )
+        hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
+
+        # 3. Feedforward
+        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
+            1 + c_scale_mlp.unsqueeze(1)
+        ) + c_shift_mlp.unsqueeze(1)
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output_context = self.ff(norm_encoder_hidden_states)
+        hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
+
         return hidden_states, encoder_hidden_states
 
 