
Commit b439f4c

make style
1 parent f35850c commit b439f4c

File tree: 4 files changed, +39 -45 lines changed


scripts/convert_cogview3_to_diffusers.py

Lines changed: 13 additions & 20 deletions
@@ -6,26 +6,19 @@
 
 Example usage:
 python scripts/convert_cogview3_to_diffusers.py \
-    --original_state_dict_repo_id "THUDM/cogview3-sat" \
-    --filename "cogview3.pt" \
-    --transformer \
-    --output_path "./cogview3_diffusers" \
-    --dtype "bf16"
-
-Alternatively, if you have a local checkpoint:
-python scripts/convert_cogview3_to_diffusers.py \
-    --checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
-    --transformer \
+    --transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
+    --vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
     --output_path "/raid/yiyi/cogview3_diffusers" \
     --dtype "bf16"
 
 Arguments:
-    --original_state_dict_repo_id: The Hugging Face repo ID containing the original checkpoint.
-    --filename: The filename of the checkpoint in the repo (default: "flux.safetensors").
-    --checkpoint_path: Path to a local checkpoint file (alternative to repo_id and filename).
-    --transformer: Flag to convert the transformer model.
+    --transformer_checkpoint_path: Path to Transformer state dict.
+    --vae_checkpoint_path: Path to VAE state dict.
     --output_path: The path to save the converted model.
-    --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32").
+    --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
+    --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
+    --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
+
 Default is "bf16" because CogView3 uses bfloat16 for Training.
 
 Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
@@ -73,11 +66,11 @@ def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
 
     new_state_dict = {}
 
-    # Convert pos_embed
-    new_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
-    new_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
-    new_state_dict["pos_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
-    new_state_dict["pos_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
+    new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
+    new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
 
     # Convert time_condition_embed
     new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
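Note: the hunk above only renames the converted keys from `pos_embed.*` to `patch_embed.*`; the source keys in the SAT checkpoint keep the `mixins.patch_embed.` prefix. A minimal sketch of the same pop-and-rename pattern as a standalone helper (the helper name is hypothetical and not part of the commit):

# Hypothetical helper, not part of the commit: illustrates the pop-and-rename
# pattern used in convert_cogview3_transformer_checkpoint_to_diffusers.
def rename_patch_embed_keys(original_state_dict: dict, new_state_dict: dict) -> None:
    for suffix in ("proj.weight", "proj.bias", "text_proj.weight", "text_proj.bias"):
        # "mixins.patch_embed.<suffix>" in the SAT checkpoint maps to
        # "patch_embed.<suffix>" in the diffusers state dict.
        new_state_dict[f"patch_embed.{suffix}"] = original_state_dict.pop(f"mixins.patch_embed.{suffix}")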

src/diffusers/models/embeddings.py

Lines changed: 4 additions & 6 deletions
@@ -469,10 +469,10 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape
-
+
         if height % self.patch_size != 0 or width % self.patch_size != 0:
             raise ValueError("Height and width must be divisible by patch size")
-
+
         height = height // self.patch_size
         width = width // self.patch_size
         hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
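The `view` on the last context line is the first step of patchifying the latents. A minimal shape sketch, assuming a standard permute-and-flatten layout for the remaining steps (the diff only shows the initial `view`, so the exact order used in the module is an assumption):

# Minimal shape sketch (assumption: the permute/flatten order is illustrative;
# only the initial `view` appears in this diff).
import torch

batch_size, channel, height, width, patch_size = 2, 16, 64, 64, 2
hidden_states = torch.randn(batch_size, channel, height, width)
height, width = height // patch_size, width // patch_size
hidden_states = hidden_states.view(batch_size, channel, height, patch_size, width, patch_size)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2)
print(hidden_states.shape)  # torch.Size([2, 1024, 64]) == (B, H/p * W/p, C * p * p)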
@@ -1156,11 +1156,9 @@ def forward(
         original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
         crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
         target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
-
+
         # (B, 3 * condition_dim)
-        condition_proj = torch.cat(
-            [original_size_proj, crop_coords_proj, target_size_proj], dim=1
-        )
+        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
 
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
         condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
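For the pooled condition above, each `(height, width)` pair is flattened, projected per scalar, and reshaped back to one row per batch element. A minimal shape sketch, with the sinusoidal `condition_proj` stubbed out by a random projection (an assumption; only the shapes matter here):

# Shape sketch (assumption: condition_proj is stubbed; only shapes matter).
import torch

batch_size, condition_dim = 2, 256
original_size = torch.tensor([[1024.0, 1024.0], [768.0, 1344.0]])  # (B, 2) as (height, width)

def condition_proj(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the per-scalar sinusoidal projection of size condition_dim.
    return torch.randn(x.shape[0], condition_dim)

original_size_proj = condition_proj(original_size.flatten()).view(original_size.size(0), -1)
print(original_size_proj.shape)  # torch.Size([2, 512]) == (B, 2 * condition_dim)
# Concatenating the original_size, crop_coords and target_size projections along
# dim=1 gives (B, 3 * 2 * condition_dim) == (B, 1536), the pooled projection size.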

src/diffusers/models/transformers/transformer_cogview3plus.py

Lines changed: 21 additions & 16 deletions
@@ -140,20 +140,22 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
         condition_dim (`int`, defaults to `256`):
-            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, crop_coords).
+            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+            crop_coords).
         pooled_projection_dim (`int`, defaults to `1536`):
-            The overall pooled dimension by concatenating SDXL-style resolution conditions. As 3 additional conditions are
-            used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 * condition_dim`,
-            we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep embeddings will be projected
-            to this dimension as well.
-            TODO(yiyi): Do we need this parameter based on the above explanation?
+            The overall pooled dimension by concatenating SDXL-style resolution conditions. As 3 additional conditions
+            are used (original_size, target_size, crop_coords), and each is a sinusoidal condition of dimension `2 *
+            condition_dim`, we get the pooled projection dimension as `2 * condition_dim * 3 => 1536`. The timestep
+            embeddings will be projected to this dimension as well. TODO(yiyi): Do we need this parameter based on the
+            above explanation?
         pos_embed_max_size (`int`, defaults to `128`):
-            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added to input
-            patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 means that the maximum
-            supported height and width for image generation is `128 * vae_scale_factor * patch_size => 128 * 8 * 2 => 2048`.
+            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+            to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+            means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+            patch_size => 128 * 8 * 2 => 2048`.
         sample_size (`int`, defaults to `128`):
-            The base resolution of input latents. If height/width is not provided during generation, this value is used to determine
-            the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+            The base resolution of input latents. If height/width is not provided during generation, this value is used
+            to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
     """
 
     _supports_gradient_checkpointing = True
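The rewrapped docstring keeps the same arithmetic; spelled out as a quick check using the documented defaults:

# Quick check of the defaults quoted in the docstring above.
condition_dim = 256
pooled_projection_dim = 2 * condition_dim * 3  # 1536
pos_embed_max_size, vae_scale_factor, patch_size = 128, 8, 2
max_generation_resolution = pos_embed_max_size * vae_scale_factor * patch_size  # 2048
sample_size = 128
default_resolution = sample_size * vae_scale_factor  # 1024
print(pooled_projection_dim, max_generation_resolution, default_resolution)  # 1536 2048 1024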
@@ -336,16 +338,19 @@ def forward(
             hidden_states (`torch.Tensor`):
                 Input `hidden_states` of shape `(batch size, channel, height, width)`.
             encoder_hidden_states (`torch.Tensor`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts)
-                of shape `(batch_size, sequence_len, text_embed_dim)`
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
+                `(batch_size, sequence_len, text_embed_dim)`
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
             original_size (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+                CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
             target_size (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+                CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
             crop_coords (`torch.Tensor`):
-                CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+                CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
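As a rough illustration of the SDXL-style micro-conditioning described above, the three tensors usually carry the original image size, the target size, and the top-left crop offset, one row per batch element (a sketch under those assumptions; the pipeline may build them differently):

# Sketch (assumption): typical micro-conditioning values for a 1024x1024 generation.
import torch

batch_size, height, width = 2, 1024, 1024
original_size = torch.tensor([[height, width]] * batch_size, dtype=torch.float32)
target_size = torch.tensor([[height, width]] * batch_size, dtype=torch.float32)
crop_coords = torch.tensor([[0, 0]] * batch_size, dtype=torch.float32)
print(original_size.shape, target_size.shape, crop_coords.shape)  # three (2, 2) tensors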

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -145,9 +145,7 @@
         "CogVideoXImageToVideoPipeline",
         "CogVideoXVideoToVideoPipeline",
     ]
-    _import_structure["cogview3"] = [
-        "CogView3PlusPipeline"
-    ]
+    _import_structure["cogview3"] = ["CogView3PlusPipeline"]
     _import_structure["controlnet"].extend(
         [
             "BlipDiffusionControlNetPipeline",
