@@ -812,6 +812,75 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens

        return (hidden_states + pos_embed).to(hidden_states.dtype)

+class CogView4PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+        pos_embed_max_size: int = 128,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+        self.patch_size = patch_size
+        self.text_hidden_size = text_hidden_size
+        self.pos_embed_max_size = pos_embed_max_size
+        # Linear projection for image patches
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+
+        # Linear projection for text embeddings
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+        # TODO: this needs to be changed to a rotary embedding (RotaryEmbed)
+        pos_embed = get_2d_sincos_pos_embed(
+            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+        )
+        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+
+        if height % self.patch_size != 0 or width % self.patch_size != 0:
+            raise ValueError("Height and width must be divisible by patch size")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
+        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
+
+        # Project the patches
+        hidden_states = self.proj(hidden_states)
+        prompt_encoder_hidden_states = []
+        negative_prompt_encoder_hidden_states = []
+
+        for i in range(0, batch_size, 2):
+            prompt_embeds = encoder_hidden_states[i, :, :]  # [seq_len, hidden_size]
+            negative_embeds = encoder_hidden_states[i + 1, :, :]  # [seq_len, hidden_size]
+            mask = negative_embeds.abs().sum(dim=-1) > 0
+            seq_len_neg = mask.sum().item()  # number of non-zero (unpadded) tokens
+            negative_embeds_valid = negative_embeds[:seq_len_neg, :]  # [seq_len_neg, hidden_size]
+            prompt_encoder_hidden_states.append(prompt_embeds)
+            negative_prompt_encoder_hidden_states.append(negative_embeds_valid)
+        prompt_encoder_hidden_states = torch.stack(prompt_encoder_hidden_states, dim=0)
+        negative_prompt_encoder_hidden_states = torch.stack(negative_prompt_encoder_hidden_states, dim=0)
+        prompt_text_length = prompt_encoder_hidden_states.shape[1]
+        negative_prompt_text_length = negative_prompt_encoder_hidden_states.shape[1]
+        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
+        prompt_text_pos_embed = torch.zeros(
+            (prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        negative_prompt_text_pos_embed = torch.zeros(
+            (negative_prompt_text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
+        )
+        prompt_pos_embed = torch.cat([prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
+        negative_prompt_pos_embed = torch.cat([negative_prompt_text_pos_embed, image_pos_embed], dim=0)[None, ...]
+        # TODO: assemble a complete pos_embed here and also concatenate the RoPE embed
+        pos_embed = torch.cat([prompt_pos_embed, negative_prompt_pos_embed], dim=0)
+        hidden_states = hidden_states + pos_embed.to(hidden_states.dtype)
+        return hidden_states
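For reference, a minimal standalone sketch (not part of this diff) of what the view/permute/reshape sequence in `forward` does: it is the standard non-overlapping patchify, equivalent to `torch.nn.functional.unfold` with stride equal to the kernel size. The tensor sizes below are arbitrary; only `patch_size` matches the module default.

```python
import torch
import torch.nn.functional as F

# Arbitrary example sizes; patch_size matches the module default of 2.
batch_size, channel, height, width, patch_size = 1, 16, 8, 8, 2
x = torch.randn(batch_size, channel, height, width)

h, w = height // patch_size, width // patch_size
# Same view/permute/reshape as in CogView4PatchEmbed.forward above.
patches = (
    x.view(batch_size, channel, h, patch_size, w, patch_size)
    .permute(0, 2, 4, 1, 3, 5)
    .contiguous()
    .view(batch_size, h * w, channel * patch_size * patch_size)
)

# unfold flattens each block channel-first in the same order, so the two agree.
unfolded = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # [B, C*p*p, h*w]
assert torch.allclose(patches, unfolded.transpose(1, 2))
print(patches.shape)  # torch.Size([1, 16, 64])
```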
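The prompt/negative loop assumes the batch interleaves (prompt, negative prompt) pairs along dim 0, with the shorter negative prompt zero-padded at the end to the prompt's sequence length. A hedged sketch of that layout and of the `abs().sum()` mask that recovers the unpadded negative length; the token counts here are made up for illustration.

```python
import torch

text_hidden_size = 4096
prompt = torch.randn(32, text_hidden_size)        # 32 prompt tokens
negative = torch.randn(7, text_hidden_size)       # 7 negative-prompt tokens
# Zero-pad the negative prompt at the end so both rows share one sequence length.
negative_padded = torch.cat([negative, torch.zeros(32 - 7, text_hidden_size)], dim=0)

# Interleaved (prompt, negative) pair, as iterated by range(0, batch_size, 2) above.
encoder_hidden_states = torch.stack([prompt, negative_padded], dim=0)  # [2, 32, 4096]

# Rows whose features are all zero are padding; the mask recovers the valid length.
mask = encoder_hidden_states[1].abs().sum(dim=-1) > 0
seq_len_neg = mask.sum().item()
print(seq_len_neg)  # 7
```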

def get_3d_rotary_pos_embed(
    embed_dim,