@@ -714,6 +714,114 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return freqs_cos, freqs_sin


class CogView3PlusPosEmbed(nn.Module):
    def __init__(
        self,
        max_height: int = 128,
        max_width: int = 128,
        hidden_size: int = 2560,
        text_length: int = 0,
        block_size: int = 16,
    ):
        super().__init__()
        self.max_height = max_height
        self.max_width = max_width
        self.hidden_size = hidden_size
        self.text_length = text_length
        self.block_size = block_size

        # Initialize the positional embedding as a non-trainable parameter
        self.image_pos_embedding = nn.Parameter(
            torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
        )
        # Reinitialize the positional embedding using a sin-cos function
        self.reinit()

    def forward(self, target_size: List[Tuple[int, int]]) -> torch.Tensor:
        ret = []
        for h, w in target_size:
            # Scale height and width according to the block size
            h, w = h // self.block_size, w // self.block_size

            # Reshape the image positional embedding for the target size
            image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)

            # Combine the text positional embedding and image positional embedding
            pos_embed = torch.cat(
                [
                    torch.zeros(
                        (self.text_length, self.hidden_size),
                        dtype=image_pos_embed.dtype,
                        device=image_pos_embed.device,
                    ),
                    image_pos_embed,
                ],
                dim=0,
            )

            ret.append(pos_embed[None, ...])  # Add a batch dimension

        return torch.cat(ret, dim=0)  # Concatenate along the batch dimension

    def reinit(self):
        # Initialize the positional embedding using a 2D sin-cos function; this assumes a
        # module-level `get_2d_sincos_pos_embed(embed_dim, height, width)` helper is
        # available elsewhere in this file.
        pos_embed_np = get_2d_sincos_pos_embed(self.hidden_size, self.max_height, self.max_width)
        pos_embed = torch.from_numpy(pos_embed_np).float()
        # Reshape to match the (max_height, max_width, hidden_size) layout of the parameter
        pos_embed = pos_embed.reshape(self.max_height, self.max_width, self.hidden_size)
        self.image_pos_embedding.data.copy_(pos_embed)


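# Minimal sketch of the kind of 2D sin-cos table builder that `reinit` above relies on.
# The helper names below are hypothetical; the real `get_2d_sincos_pos_embed` is assumed
# to live elsewhere in this file, and only its general behaviour is illustrated here:
# a (height * width, embed_dim) table in which half of the channels encode the row index
# and the other half the column index.
import numpy as np  # may already be imported at the top of the file


def _sketch_1d_sincos(embed_dim: int, positions: np.ndarray) -> np.ndarray:
    # Half of the channels carry sin, the other half cos, over log-spaced frequencies.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega = 1.0 / 10000 ** (omega / (embed_dim / 2))
    angles = np.outer(positions.reshape(-1).astype(np.float64), omega)  # (P, embed_dim // 2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (P, embed_dim)


def _sketch_2d_sincos(embed_dim: int, height: int, width: int) -> np.ndarray:
    # Flattened row-major so it matches `image_pos_embedding[:h, :w].reshape(h * w, -1)`.
    grid_h, grid_w = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    emb_h = _sketch_1d_sincos(embed_dim // 2, grid_h)  # (height * width, embed_dim // 2)
    emb_w = _sketch_1d_sincos(embed_dim // 2, grid_w)  # (height * width, embed_dim // 2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (height * width, embed_dim)

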
class CogView3PlusImagePatchEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int = 128,
        hidden_size: int = 128,
        patch_size: int = 2,
        text_hidden_size: int = 4096,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.text_hidden_size = text_hidden_size

        # Linear projection for image patches
        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)

        # Linear projection for text embeddings
        self.text_proj = nn.Linear(text_hidden_size, hidden_size)

    def forward(self, images: torch.Tensor, encoder_outputs: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Rearrange the images into flattened patches, equivalent to:
        # rearrange(images, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=self.patch_size, p2=self.patch_size)
        b, c, h, w = images.shape
        p1, p2 = self.patch_size, self.patch_size
        assert h % p1 == 0 and w % p2 == 0, "Height and width must be divisible by patch size"

        images = images.view(b, c, h // p1, p1, w // p2, p2)
        patches_images = images.permute(0, 2, 4, 1, 3, 5).contiguous()
        patches_images = patches_images.view(b, (h // p1) * (w // p2), c * p1 * p2)

        # Project the patches
        image_emb = self.proj(patches_images)

        # If text embeddings are provided, project and concatenate them
        if self.text_hidden_size is not None and encoder_outputs is not None:
            text_emb = self.text_proj(encoder_outputs)
            emb = torch.cat([text_emb, image_emb], dim=1)
        else:
            emb = image_emb

        return emb

    def reinit(self, parent_model=None):
        # Reinitialize the projection weights
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.constant_(self.proj.bias, 0)
        if self.text_hidden_size is not None:
            nn.init.xavier_uniform_(self.text_proj.weight)
            nn.init.constant_(self.text_proj.bias, 0)


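# Hypothetical shape walk-through for the patch embedding above; the batch size, latent
# channel count and text length below are illustrative values, not fixed by the model.
def _example_patch_embedding_shapes() -> torch.Size:
    patch_embed = CogView3PlusImagePatchEmbedding(
        in_channels=16, hidden_size=2560, patch_size=2, text_hidden_size=4096
    )
    latents = torch.randn(2, 16, 64, 64)     # (B, C, H, W) latent grid
    text_states = torch.randn(2, 224, 4096)  # (B, text_len, text_hidden_size)
    tokens = patch_embed(latents, text_states)
    # A 64x64 latent with patch_size=2 yields (64 // 2) * (64 // 2) = 1024 image tokens,
    # so `tokens` has shape (2, 224 + 1024, 2560).
    return tokens.shape

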
class TimestepEmbedding(nn.Module):
    def __init__(
        self,