@@ -84,6 +84,7 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)

@@ -103,7 +104,7 @@ def forward(

         # attention
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
         )

         hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
@@ -191,14 +192,15 @@ def __init__(
         # Each of these are sincos embeddings of shape 2 * condition_dim
         self.pooled_projection_dim = 3 * 2 * condition_dim

-        # self.patch_embed = CogView3PlusPatchEmbed(
-        #     in_channels=in_channels,
-        #     hidden_size=self.inner_dim,
-        #     patch_size=patch_size,
-        #     text_hidden_size=text_embed_dim,
-        #     pos_embed_max_size=pos_embed_max_size,
-        # )
-        # TODO: compatibility adaptation
+        self.max_h = 256
+        self.max_w = 256
+        self.rope = self.prepare_rope(
+            embed_dim=self.config.attention_head_dim,
+            max_h=self.max_h,
+            max_w=self.max_w,
+            rotary_base=10000
+        )
+
         self.patch_embed = CogView4PatchEmbed(
             in_channels=in_channels,
             hidden_size=self.inner_dim,
@@ -300,10 +302,55 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

+    @staticmethod
+    def prepare_rope(embed_dim, max_h, max_w, rotary_base):
+        dim_h = embed_dim // 2
+        dim_w = embed_dim // 2
+        h_inv_freq = 1.0 / (
+            rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+        )
+        w_inv_freq = 1.0 / (
+            rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+        )
+        h_seq = torch.arange(max_h, dtype=h_inv_freq.dtype)
+        w_seq = torch.arange(max_w, dtype=w_inv_freq.dtype)
+        freqs_h = torch.outer(h_seq, h_inv_freq)
+        freqs_w = torch.outer(w_seq, w_inv_freq)
+        return (freqs_h, freqs_w)
+
+    def get_rope_embedding(self, height, width, target_h, target_w, device):
+        # Get pre-computed frequencies
+        freqs_h, freqs_w = self.rope
+
+        h_idx = torch.arange(height)
+        w_idx = torch.arange(width)
+        inner_h_idx = (h_idx * self.max_h) // target_h
+        inner_w_idx = (w_idx * self.max_w) // target_w
+
+        freqs_h = freqs_h[inner_h_idx].to(device)
+        freqs_w = freqs_w[inner_w_idx].to(device)
+
+        # Create position matrices for height and width
+        # [height, 1, dim//4] and [1, width, dim//4]
+        freqs_h = freqs_h.unsqueeze(1)
+        freqs_w = freqs_w.unsqueeze(0)
+        # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+        freqs_h = freqs_h.expand(height, width, -1)
+        freqs_w = freqs_w.expand(height, width, -1)
+
+        # Concatenate along last dimension to get [height, width, dim//2]
+        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+
+        freqs = torch.cat([freqs, freqs], dim=-1)  # [height, width, dim]
+        freqs = freqs.reshape(height * width, -1)
+
+        return freqs.cos(), freqs.sin()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
+        prompt_embeds: torch.Tensor,
+        negative_prompt_embeds: torch.Tensor | None,
         timestep: torch.LongTensor,
         original_size: torch.Tensor,
         target_size: torch.Tensor,
@@ -338,16 +385,27 @@ def forward(
            `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
                The denoised latents using provided inputs as conditioning.
        """
-        height, width = hidden_states.shape[-2:]
-        text_seq_length = encoder_hidden_states.shape[1]
+        batch_size, channel, height, width = hidden_states.shape
+        patch_height, patch_width = height // self.config.patch_size, width // self.config.patch_size
+        do_cfg = negative_prompt_embeds is not None

-        hidden_states = self.patch_embed(
-            hidden_states, encoder_hidden_states
-        )  # takes care of adding positional embeddings too.
-        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+        if do_cfg:
+            assert batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0], "batch size mismatch in CFG mode"
+        else:
+            assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
+
+        hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
+            hidden_states, prompt_embeds, negative_prompt_embeds
+        )

-        encoder_hidden_states = hidden_states[:, :text_seq_length]
-        hidden_states = hidden_states[:, text_seq_length:]
+        encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
+
+        # prepare image_rotary_emb
+        image_rotary_emb = self.get_rope_embedding(
+            patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
+        )
+
+        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -363,17 +421,19 @@ def custom_forward(*inputs):
                     create_custom_forward(block),
                     hidden_states,
                     encoder_hidden_states,
-                    emb,
+                    emb=emb,
+                    image_rotary_emb=image_rotary_emb,
                     **ckpt_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     emb=emb,
+                    image_rotary_emb=image_rotary_emb,
                 )

-        hidden_states = self.norm_out(hidden_states, emb)
+        hidden_states = self.norm_out(hidden_states, emb)  # result corresponds to final_layer_input in Megatron
         hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)

         # unpatchify
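
As a sanity check on the new rotary-embedding path, here is a minimal standalone sketch of what prepare_rope and get_rope_embedding compute, for the simple case target_h == height and target_w == width. The helper names build_rope_tables / lookup_rope and the example sizes (head_dim=64, a 32x48 latent grid) are illustrative assumptions, not part of the commit.

import torch


def build_rope_tables(head_dim, max_h, max_w, rotary_base=10000.0):
    # Mirrors prepare_rope: half of the head dim encodes the height axis, half the width axis.
    dim_h = dim_w = head_dim // 2
    h_inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32) / dim_h))
    w_inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32) / dim_w))
    # Outer product of positions and inverse frequencies: [max_h, dim_h // 2] and [max_w, dim_w // 2].
    freqs_h = torch.outer(torch.arange(max_h, dtype=torch.float32), h_inv_freq)
    freqs_w = torch.outer(torch.arange(max_w, dtype=torch.float32), w_inv_freq)
    return freqs_h, freqs_w


def lookup_rope(freqs_h, freqs_w, height, width, max_h, max_w):
    # Mirrors get_rope_embedding with target_h == height, target_w == width:
    # the token grid is rescaled onto the precomputed [0, max_h) / [0, max_w) index range.
    h_idx = (torch.arange(height) * max_h) // height
    w_idx = (torch.arange(width) * max_w) // width
    fh = freqs_h[h_idx].unsqueeze(1).expand(height, width, -1)  # [H, W, dim // 4]
    fw = freqs_w[w_idx].unsqueeze(0).expand(height, width, -1)  # [H, W, dim // 4]
    freqs = torch.cat([fh, fw], dim=-1)                         # [H, W, dim // 2]
    freqs = torch.cat([freqs, freqs], dim=-1)                   # [H, W, dim]
    freqs = freqs.reshape(height * width, -1)                   # [H * W, dim]
    return freqs.cos(), freqs.sin()


freqs_h, freqs_w = build_rope_tables(head_dim=64, max_h=256, max_w=256)
cos, sin = lookup_rope(freqs_h, freqs_w, height=32, width=48, max_h=256, max_w=256)
print(cos.shape, sin.shape)  # torch.Size([1536, 64]) torch.Size([1536, 64])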
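The attention processor that consumes image_rotary_emb is not shown in this diff; a generic rotate-half application, consistent with the duplicated-frequency layout produced by get_rope_embedding, would look roughly like the sketch below. apply_rotary and the placeholder shapes are assumptions for illustration, not the commit's code.

import torch


def apply_rotary(x, cos, sin):
    # x: [batch, heads, seq, head_dim]; cos/sin: [seq, head_dim] as returned by get_rope_embedding.
    # Rotate-half convention: channel i pairs with channel i + head_dim // 2 and both share the same angle.
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * cos + rotated * sin


angles = torch.randn(1536, 64)  # placeholder; in the model these come from get_rope_embedding
cos, sin = angles.cos(), angles.sin()
q = torch.randn(2, 8, 1536, 64)
print(apply_rotary(q, cos, sin).shape)  # torch.Size([2, 8, 1536, 64])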