 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import is_torch_version, logging
-from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed, CogView4PatchEmbed
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..normalization import CogView3PlusAdaLayerNormZeroTextImage

@@ -84,7 +84,6 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         emb: torch.Tensor,
-        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)

@@ -104,7 +103,7 @@ def forward(

         # attention
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **kwargs
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )

         hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
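Note on the gating in this hunk: the attention output is scaled by a per-sample gate before being added back to the residual stream (the AdaLN-Zero pattern), and since such gates are typically zero-initialized, the block starts out close to an identity mapping. A minimal sketch with made-up shapes, not the module's actual attributes:

```python
import torch

# hypothetical shapes: batch=2, sequence=16, hidden dim=64
hidden_states = torch.randn(2, 16, 64)
attn_hidden_states = torch.randn(2, 16, 64)
gate_msa = torch.zeros(2, 64)  # zero gate -> attention contributes nothing at init

# the gate is broadcast over the sequence dimension, just like gate_msa.unsqueeze(1) above
out = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
assert torch.allclose(out, hidden_states)  # with a zero gate the residual passes through unchanged
```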
@@ -167,7 +166,8 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
     """

     _supports_gradient_checkpointing = True
-    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed", "CogView4PlusPatchEmbed"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]

     @register_to_config
     def __init__(
@@ -192,16 +192,7 @@ def __init__(
         # Each of these are sincos embeddings of shape 2 * condition_dim
         self.pooled_projection_dim = 3 * 2 * condition_dim

-        self.max_h = 256
-        self.max_w = 256
-        self.rope = self.prepare_rope(
-            embed_dim=self.config.attention_head_dim,
-            max_h=self.max_h,
-            max_w=self.max_w,
-            rotary_base=10000
-        )
-
-        self.patch_embed = CogView4PatchEmbed(
+        self.patch_embed = CogView3PlusPatchEmbed(
             in_channels=in_channels,
             hidden_size=self.inner_dim,
             patch_size=patch_size,
@@ -232,8 +223,7 @@ def __init__(
             embedding_dim=self.inner_dim,
             conditioning_embedding_dim=time_embed_dim,
             elementwise_affine=False,
-            # eps=1e-6,
-            eps=1e-5,
+            eps=1e-6,
         )
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

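For context, `norm_out` is a continuous adaptive layer norm: the time/size conditioning embedding is projected to a per-channel scale and shift that modulate a LayerNorm with no learnable affine parameters (hence `elementwise_affine=False` and the explicit `eps`). A rough sketch of that pattern, assuming the usual scale/shift formulation; this is illustrative, not the actual `AdaLayerNormContinuous` implementation:

```python
import torch
import torch.nn as nn


class AdaLayerNormContinuousSketch(nn.Module):
    """Illustrative continuous AdaLN: condition a LayerNorm with an external embedding."""

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        # projects the conditioning embedding to a per-channel (scale, shift) pair
        self.linear = nn.Linear(conditioning_embedding_dim, 2 * embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=eps)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(self.silu(conditioning)).chunk(2, dim=1)
        # broadcast (batch, dim) modulation over the sequence dimension of x: (batch, seq, dim)
        return self.norm(x) * (1 + scale[:, None, :]) + shift[:, None, :]
```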
@@ -303,55 +293,10 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

-    @staticmethod
-    def prepare_rope(embed_dim, max_h, max_w, rotary_base):
-        dim_h = embed_dim // 2
-        dim_w = embed_dim // 2
-        h_inv_freq = 1.0 / (
-            rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
-        )
-        w_inv_freq = 1.0 / (
-            rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
-        )
-        h_seq = torch.arange(max_h, dtype=h_inv_freq.dtype)
-        w_seq = torch.arange(max_w, dtype=w_inv_freq.dtype)
-        freqs_h = torch.outer(h_seq, h_inv_freq)
-        freqs_w = torch.outer(w_seq, w_inv_freq)
-        return (freqs_h, freqs_w)
-
-    def get_rope_embedding(self, height, width, target_h, target_w, device):
-        # Get pre-computed frequencies
-        freqs_h, freqs_w = self.rope
-
-        h_idx = torch.arange(height)
-        w_idx = torch.arange(width)
-        inner_h_idx = (h_idx * self.max_h) // target_h
-        inner_w_idx = (w_idx * self.max_w) // target_w
-
-        freqs_h = freqs_h[inner_h_idx].to(device)
-        freqs_w = freqs_w[inner_w_idx].to(device)
-
-        # Create position matrices for height and width
-        # [height, 1, dim//4] and [1, width, dim//4]
-        freqs_h = freqs_h.unsqueeze(1)
-        freqs_w = freqs_w.unsqueeze(0)
-        # Broadcast freqs_h and freqs_w to [height, width, dim//4]
-        freqs_h = freqs_h.expand(height, width, -1)
-        freqs_w = freqs_w.expand(height, width, -1)
-
-        # Concatenate along last dimension to get [height, width, dim//2]
-        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
-
-        freqs = torch.cat([freqs, freqs], dim=-1)  # [height, width, dim]
-        freqs = freqs.reshape(height * width, -1)
-
-        return freqs.cos(), freqs.sin()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
-        prompt_embeds: torch.Tensor,
-        negative_prompt_embeds: torch.Tensor | None,
+        encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         original_size: torch.Tensor,
         target_size: torch.Tensor,
@@ -386,103 +331,58 @@ def forward(
             `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
                 The denoised latents using provided inputs as conditioning.
         """
-        batch_size, channel, height, width = hidden_states.shape
-        patch_height, patch_width = height // self.config.patch_size, width // self.config.patch_size
-        do_cfg = negative_prompt_embeds is not None
-
-        if do_cfg:
-            assert batch_size == prompt_embeds.shape[0] + negative_prompt_embeds.shape[0], "batch size mismatch in CFG mode"
-        else:
-            assert batch_size == prompt_embeds.shape[0], "batch size mismatch in non-CFG mode"
+        height, width = hidden_states.shape[-2:]
+        text_seq_length = encoder_hidden_states.shape[1]

-        hidden_states, prompt_embeds, negative_prompt_embeds = self.patch_embed(
-            hidden_states, prompt_embeds, negative_prompt_embeds
-        )
+        hidden_states = self.patch_embed(
+            hidden_states, encoder_hidden_states
+        )  # takes care of adding positional embeddings too.
         emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

-        hidden_states_cond, hidden_states_uncond = hidden_states.chunk(2)
-        emb_cond, emb_uncond = emb.chunk(2)
-
-        # prepare image_rotary__emb
-        image_rotary_emb = self.get_rope_embedding(
-            patch_height, patch_width, target_h=patch_height, target_w=patch_width, device=hidden_states.device
-        )
-
-        ######################
-        # prompt_embeds = torch.load("/home/lhy/code/cogview/c_condition_embedding.pt")
-        # negative_prompt_embeds = torch.load("/home/lhy/code/cogview/uc_condition_embedding.pt")
-        prompt_embeds = torch.load("/home/lhy/code/cogview/cp_condition_0_16.pt")[None, ::]
-        negative_prompt_embeds = torch.load("/home/lhy/code/cogview/cp_uncondition_16_32.pt")[None, ::]
+        encoder_hidden_states = hidden_states[:, :text_seq_length]
+        hidden_states = hidden_states[:, text_seq_length:]

-        hidden_states_cond = torch.load("/home/lhy/code/cogview/cp_vision_input_0_4096.pt")
-        hidden_states_uncond = torch.load("/home/lhy/code/cogview/cp_vision_input_4096:8192.pt")
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

-        emb_cond = torch.load("/home/lhy/code/cogview/time_embedding_0_1.pt")
-        emb_uncond = torch.load("/home/lhy/code/cogview/time_embedding_1_2.pt")
-        ######################
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)

-        encoder_hidden_states_cond = prompt_embeds
-        encoder_hidden_states_uncond = negative_prompt_embeds
+                    return custom_forward

-        for index_block, block in enumerate(self.transformer_blocks):
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                ...
-            else:
-                hidden_states_cond, encoder_hidden_states_cond = block(
-                    hidden_states=hidden_states_cond,
-                    encoder_hidden_states=encoder_hidden_states_cond,
-                    emb=emb_cond,  # refactor later
-                    image_rotary_emb=image_rotary_emb,
-                    # image_rotary_emb=None,
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    emb,
+                    **ckpt_kwargs,
                 )
-                ###########################
-                # hidden_states_cond, encoder_hidden_states_cond = (
-                #     self.norm_out.norm(hidden_states_cond),
-                #     self.norm_out.norm(encoder_hidden_states_cond),
-                # )
-                ###########################
-
-                hidden_states_uncond, encoder_hidden_states_uncond = block(
-                    hidden_states=hidden_states_uncond,
-                    encoder_hidden_states=encoder_hidden_states_uncond,
-                    emb=emb_uncond,  # refactor later
-                    image_rotary_emb=image_rotary_emb,
-                    # image_rotary_emb=None,
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=emb,
                 )
-                ###########################
-                # hidden_states_uncond, encoder_hidden_states_uncond = (
-                #     self.norm_out.norm(hidden_states_uncond),
-                #     self.norm_out.norm(encoder_hidden_states_uncond),
-                # )
-                ###########################
-
-        hidden_states_cond = self.norm_out(hidden_states_cond, emb_cond)  # corresponds to final_layer_input in Megatron
-        hidden_states_uncond = self.norm_out(hidden_states_uncond, emb_uncond)  # corresponds to final_layer_input in Megatron
-        hidden_states_cond = self.proj_out(hidden_states_cond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
-        hidden_states_uncond = self.proj_out(hidden_states_uncond)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+
+        hidden_states = self.norm_out(hidden_states, emb)
+        hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)

         # unpatchify
         patch_size = self.config.patch_size
         height = height // patch_size
         width = width // patch_size

-        hidden_states_cond = hidden_states_cond.reshape(
-            shape=(hidden_states_cond.shape[0], height, width, self.out_channels, patch_size, patch_size)
-        )
-        hidden_states_cond = torch.einsum("nhwcpq->nchpwq", hidden_states_cond)
-        output_cond = hidden_states_cond.reshape(
-            shape=(hidden_states_cond.shape[0], self.out_channels, height * patch_size, width * patch_size)
-        )
-
-        hidden_states_uncond = hidden_states_uncond.reshape(
-            shape=(hidden_states_uncond.shape[0], height, width, self.out_channels, patch_size, patch_size)
+        hidden_states = hidden_states.reshape(
+            shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
         )
-        hidden_states_uncond = torch.einsum("nhwcpq->nchpwq", hidden_states_uncond)
-        output_uncond = hidden_states_uncond.reshape(
-            shape=(hidden_states_uncond.shape[0], self.out_channels, height * patch_size, width * patch_size)
+        hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
         )

         if not return_dict:
-            return (output_cond, output_uncond)
+            return (output,)

-        return Transformer2DModelOutput(sample=output_cond), Transformer2DModelOutput(sample=output_uncond)
+        return Transformer2DModelOutput(sample=output)
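As a shape sanity check for the unpatchify step at the end of `forward`, the same reshape/einsum sequence can be run standalone with made-up sizes (values are random; only the shapes matter):

```python
import torch

# hypothetical sizes: batch=1, out_channels=4, patch_size=2, an 8x8 grid of latent patches
batch, out_channels, patch_size, height, width = 1, 4, 2, 8, 8
hidden_states = torch.randn(batch, height * width, patch_size * patch_size * out_channels)

# (B, H*W, p*p*C) -> (B, H, W, C, p, p) -> (B, C, H, p, W, p) -> (B, C, H*p, W*p)
hidden_states = hidden_states.reshape(batch, height, width, out_channels, patch_size, patch_size)
hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
output = hidden_states.reshape(batch, out_channels, height * patch_size, width * patch_size)
print(output.shape)  # torch.Size([1, 4, 16, 16])
```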