@@ -714,68 +714,68 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         return freqs_cos, freqs_sin
 
 
-class CogView3PlusPosEmbed(nn.Module):
-    def __init__(
-        self,
-        max_height: int = 128,
-        max_width: int = 128,
-        hidden_size: int = 2560,
-        text_length: int = 0,
-        block_size: int = 16,
-    ):
-        super().__init__()
-        self.max_height = max_height
-        self.max_width = max_width
-        self.hidden_size = hidden_size
-        self.text_length = text_length
-        self.block_size = block_size
-
-        # Initialize the positional embedding as a non-trainable parameter
-        self.image_pos_embedding = nn.Parameter(
-            torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
-        )
-        # Reinitialize the positional embedding using a sin-cos function
-        self.reinit()
-
-    def forward(self, target_size: List[int]) -> torch.Tensor:
-        ret = []
-        for h, w in target_size:
-            # Scale height and width according to the block size
-            h, w = h // self.block_size, w // self.block_size
-
-            # Reshape the image positional embedding for the target size
-            image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)
-
-            # Combine the text positional embedding and image positional embedding
-            pos_embed = torch.cat(
-                [
-                    torch.zeros(
-                        (self.text_length, self.hidden_size),
-                        dtype=image_pos_embed.dtype,
-                        device=image_pos_embed.device,
-                    ),
-                    image_pos_embed,
-                ],
-                dim=0,
-            )
-
-            ret.append(pos_embed[None, ...])  # Add a batch dimension
-
-        return torch.cat(ret, dim=0)  # Concatenate along the batch dimension
-
-    def reinit(self):
-        # Initialize the positional embedding using the updated 2D sin-cos function
-        grid_size = (self.max_height, self.max_width)
-        pos_embed_np = get_2d_sincos_pos_embed(
-            embed_dim=self.hidden_size,
-            grid_size=grid_size,
-        )
-
-        # Reshape the positional embedding to the desired shape
-        pos_embed_np = pos_embed_np.reshape(self.max_height, self.max_width, self.hidden_size)
-
-        # Copy the positional embedding data
-        self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed_np).float())
+# class CogView3PlusPosEmbed(nn.Module):
+#     def __init__(
+#         self,
+#         max_height: int = 128,
+#         max_width: int = 128,
+#         hidden_size: int = 2560,
+#         text_length: int = 0,
+#         block_size: int = 16,
+#     ):
+#         super().__init__()
+#         self.max_height = max_height
+#         self.max_width = max_width
+#         self.hidden_size = hidden_size
+#         self.text_length = text_length
+#         self.block_size = block_size
+#
+#         # Initialize the positional embedding as a non-trainable parameter
+#         self.image_pos_embedding = nn.Parameter(
+#             torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
+#         )
+#         # Reinitialize the positional embedding using a sin-cos function
+#         self.reinit()
+#
+#     def forward(self, target_size: List[int]) -> torch.Tensor:
+#         ret = []
+#         for h, w in target_size:
+#             # Scale height and width according to the block size
+#             h, w = h // self.block_size, w // self.block_size
+#
+#             # Reshape the image positional embedding for the target size
+#             image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)
+#
+#             # Combine the text positional embedding and image positional embedding
+#             pos_embed = torch.cat(
+#                 [
+#                     torch.zeros(
+#                         (self.text_length, self.hidden_size),
+#                         dtype=image_pos_embed.dtype,
+#                         device=image_pos_embed.device,
+#                     ),
+#                     image_pos_embed,
+#                 ],
+#                 dim=0,
+#             )
+#
+#             ret.append(pos_embed[None, ...])  # Add a batch dimension
+#
+#         return torch.cat(ret, dim=0)  # Concatenate along the batch dimension
+#
+#     def reinit(self):
+#         # Initialize the positional embedding using the updated 2D sin-cos function
+#         grid_size = (self.max_height, self.max_width)
+#         pos_embed_np = get_2d_sincos_pos_embed(
+#             embed_dim=self.hidden_size,
+#             grid_size=grid_size,
+#         )
+#
+#         # Reshape the positional embedding to the desired shape
+#         pos_embed_np = pos_embed_np.reshape(self.max_height, self.max_width, self.hidden_size)
+#
+#         # Copy the positional embedding data
+#         self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed_np).float())
 
 
 class CogView3PlusImagePatchEmbedding(nn.Module):
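
As an aside, the commented-out CogView3PlusPosEmbed boils down to: precompute a fixed 2D sin-cos table over the maximum grid, slice and flatten it for each target latent size, and prepend zeros for the text tokens. A minimal, self-contained sketch of that idea (the helpers below are illustrative stand-ins, not the file's actual get_2d_sincos_pos_embed):

import torch


def sincos_1d(embed_dim, positions):
    # Half sin, half cos per position; embed_dim is assumed even.
    omega = 1.0 / 10000 ** (torch.arange(embed_dim // 2, dtype=torch.float64) / (embed_dim // 2))
    out = positions.double().reshape(-1, 1) * omega.reshape(1, -1)
    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)


def sincos_2d(embed_dim, height, width):
    # Half the channels encode the row index, half the column index (row-major order).
    rows = torch.arange(height).repeat_interleave(width)
    cols = torch.arange(width).repeat(height)
    emb = torch.cat([sincos_1d(embed_dim // 2, rows), sincos_1d(embed_dim // 2, cols)], dim=1)
    return emb.reshape(height, width, embed_dim).float()


table = sincos_2d(embed_dim=64, height=8, width=8)   # toy "max_height x max_width" table
h, w = 4, 6                                          # target size already divided by block_size
image_pos = table[:h, :w].reshape(h * w, -1)         # slice then flatten, as forward() did
pos_embed = torch.cat([torch.zeros(3, 64), image_pos], dim=0)  # 3 text tokens get zero embeddings
print(pos_embed.shape)  # torch.Size([27, 64])
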
@@ -809,8 +809,6 @@ def forward(self, images: torch.Tensor, encoder_outputs: torch.Tensor = None) ->
         images = images.view(b, c, h // p1, p1, w // p2, p2)
         patches_images = images.permute(0, 2, 4, 1, 3, 5).contiguous()
         patches_images = patches_images.view(b, (h // p1) * (w // p2), c * p1 * p2)
-
-        # Project the patches
         image_emb = self.proj(patches_images)
 
         # If text embeddings are provided, project and concatenate them
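
The reshape chain kept above (view → permute → view) performs non-overlapping patchification: (b, c, h, w) becomes (b, num_patches, c * p1 * p2). A quick standalone shape check, with an assumed patch size of 2 (the real values come from the module's config):

import torch

b, c, h, w = 2, 16, 8, 8   # toy latent batch
p1 = p2 = 2                # assumed patch size, for illustration only

images = torch.randn(b, c, h, w)
x = images.view(b, c, h // p1, p1, w // p2, p2)
x = x.permute(0, 2, 4, 1, 3, 5).contiguous()
patches = x.view(b, (h // p1) * (w // p2), c * p1 * p2)
print(patches.shape)  # torch.Size([2, 16, 64]): (batch, num_patches, patch_dim)
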
@@ -1135,6 +1133,27 @@ def forward(self, image_embeds: torch.Tensor):
         return self.norm(x)
 
 
+class CogView3CombineTimestepLabelEmbedding(nn.Module):
+    def __init__(self, time_embed_dim, label_embed_dim, in_channels=2560):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=1)
+        self.timestep_embedder = TimestepEmbedding(in_channels=in_channels, time_embed_dim=time_embed_dim)
+        self.label_embedder = nn.Sequential(
+            nn.Linear(label_embed_dim, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
+
+    def forward(self, timestep, class_labels, hidden_dtype=None):
+        t_proj = self.time_proj(timestep)
+        t_emb = self.timestep_embedder(t_proj.to(dtype=hidden_dtype))
+        label_emb = self.label_embedder(class_labels)
+        emb = t_emb + label_emb
+
+        return emb
+
+
 class CombinedTimestepLabelEmbeddings(nn.Module):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
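
A hedged usage sketch for the newly added CogView3CombineTimestepLabelEmbedding, assuming the class above is in scope; the widths below are illustrative, not the model's actual configuration:

import torch

embedder = CogView3CombineTimestepLabelEmbedding(
    time_embed_dim=512,    # illustrative
    label_embed_dim=1536,  # illustrative width of the pooled conditioning vector
    in_channels=256,       # sinusoidal projection width; the diff's default is 2560
)

timestep = torch.tensor([999, 500])   # one timestep per sample
class_labels = torch.randn(2, 1536)   # pooled label/text conditioning
emb = embedder(timestep, class_labels, hidden_dtype=torch.float32)
print(emb.shape)  # torch.Size([2, 512]): timestep and label embeddings summed element-wise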