Commit 21dd890

remove pooled_projection_dim as a parameter
1 parent 4ac4e52 commit 21dd890

1 file changed (+17, -10 lines)

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 17 additions & 10 deletions
@@ -37,8 +37,18 @@
 
 
 class CogView3PlusTransformerBlock(nn.Module):
-    """
-    Updated CogView3 Transformer Block to align with AdalnAttentionMixin style, simplified with qk_ln always True.
+    r"""
+    Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
+
+    Args:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        time_embed_dim (`int`):
+            The number of channels in timestep embedding.
     """
 
     def __init__(
@@ -145,12 +155,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         condition_dim (`int`, defaults to `256`):
             The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
             crop_coords).
-        pooled_projection_dim (`int`, defaults to `1536`):
-            The overall pooled dimension by concatenating SDXL-style resolution conditions. As 3 additional conditions
-            are used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 *
-            condition_dim`, we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep
-            embeddings will be projected to this dimension as well. TODO(yiyi): Do we need this parameter based on the
-            above explanation?
         pos_embed_max_size (`int`, defaults to `128`):
             The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
             to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
@@ -175,14 +179,17 @@ def __init__(
         text_embed_dim: int = 4096,
         time_embed_dim: int = 512,
         condition_dim: int = 256,
-        pooled_projection_dim: int = 1536,
         pos_embed_max_size: int = 128,
         sample_size: int = 128,
     ):
         super().__init__()
         self.out_channels = out_channels
         self.inner_dim = num_attention_heads * attention_head_dim
 
+        # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+        # Each of these are sincos embeddings of shape 2 * condition_dim
+        self.pooled_projection_dim = 3 * 2 * condition_dim
+
         self.patch_embed = CogView3PlusPatchEmbed(
             in_channels=in_channels,
             hidden_size=self.inner_dim,
@@ -194,7 +201,7 @@ def __init__(
         self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
             embedding_dim=time_embed_dim,
             condition_dim=condition_dim,
-            pooled_projection_dim=pooled_projection_dim,
+            pooled_projection_dim=self.pooled_projection_dim,
             timesteps_dim=self.inner_dim,
         )
 
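
The removed docstring already spells out why `pooled_projection_dim` is redundant: the three SDXL-style conditions (original_size, target_size, crop_coords) are each sincos embeddings of width `2 * condition_dim`, so concatenating them gives `3 * 2 * condition_dim`. A minimal Python sketch of that arithmetic follows; `pooled_projection_dim_for` is an illustrative helper, not part of the diffusers API.

# Sketch of the derivation the commit hard-codes as `3 * 2 * condition_dim`.
# `pooled_projection_dim_for` is a hypothetical helper for illustration only.

NUM_SDXL_CONDITIONS = 3  # original_size, target_size, crop_coords


def pooled_projection_dim_for(condition_dim: int) -> int:
    # Each condition is embedded as a sincos vector of width 2 * condition_dim,
    # and the three embeddings are concatenated before projection.
    return NUM_SDXL_CONDITIONS * 2 * condition_dim


# With the default condition_dim=256 this reproduces the old default of 1536,
# so dropping the explicit constructor argument does not change tensor shapes.
assert pooled_projection_dim_for(256) == 1536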
