3737
3838
3939class CogView3PlusTransformerBlock (nn .Module ):
40- """
41- Updated CogView3 Transformer Block to align with AdalnAttentionMixin style, simplified with qk_ln always True.
40+ r"""
41+ Transformer block used in the [CogView3](https://github.com/THUDM/CogView3) model.
42+
43+ Args:
44+ dim (`int`):
45+ The number of channels in the input and output.
46+ num_attention_heads (`int`):
47+ The number of heads to use for multi-head attention.
48+ attention_head_dim (`int`):
49+ The number of channels in each head.
50+ time_embed_dim (`int`):
51+ The number of channels in timestep embedding.
4252 """
4353
4454 def __init__ (
@@ -145,12 +155,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
145155 condition_dim (`int`, defaults to `256`):
146156 The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
147157 crop_coords).
148- pooled_projection_dim (`int`, defaults to `1536`):
149- The overall pooled dimension by concatenating SDXL-style resolution conditions. As 3 additional conditions
150- are used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 *
151- condition_dim`, we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep
152- embeddings will be projected to this dimension as well. TODO(yiyi): Do we need this parameter based on the
153- above explanation?
154158 pos_embed_max_size (`int`, defaults to `128`):
155159 The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
156160 to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
@@ -175,14 +179,17 @@ def __init__(
175179 text_embed_dim : int = 4096 ,
176180 time_embed_dim : int = 512 ,
177181 condition_dim : int = 256 ,
178- pooled_projection_dim : int = 1536 ,
179182 pos_embed_max_size : int = 128 ,
180183 sample_size : int = 128 ,
181184 ):
182185 super ().__init__ ()
183186 self .out_channels = out_channels
184187 self .inner_dim = num_attention_heads * attention_head_dim
185188
189+ # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
190+ # Each of these is a sincos embedding of shape 2 * condition_dim
191+ self .pooled_projection_dim = 3 * 2 * condition_dim
192+
186193 self .patch_embed = CogView3PlusPatchEmbed (
187194 in_channels = in_channels ,
188195 hidden_size = self .inner_dim ,
@@ -194,7 +201,7 @@ def __init__(
194201 self .time_condition_embed = CogView3CombinedTimestepSizeEmbeddings (
195202 embedding_dim = time_embed_dim ,
196203 condition_dim = condition_dim ,
197- pooled_projection_dim = pooled_projection_dim ,
204+ pooled_projection_dim = self . pooled_projection_dim ,
198205 timesteps_dim = self .inner_dim ,
199206 )
200207
0 commit comments