|
27 | 27 | from tqdm import tqdm |
28 | 28 | from transformers import GlmModel, PreTrainedTokenizerFast |
29 | 29 |
|
30 | | -from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler |
| 30 | +from diffusers import ( |
| 31 | + AutoencoderKL, |
| 32 | + CogView4Pipeline, |
| 33 | + CogView4ControlPipeline, |
| 34 | + CogView4Transformer2DModel, |
| 35 | + FlowMatchEulerDiscreteScheduler, |
| 36 | +) |
31 | 37 | from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint |
32 | 38 |
|
33 | 39 |
|
|
112 | 118 | default=128, |
113 | 119 | help="Maximum size for positional embeddings.", |
114 | 120 | ) |
| 121 | +parser.add_argument( |
| 122 | + "--control", |
| 123 | + action="store_true", |
| 124 | + default=False, |
| 125 | + help="Whether to use control model.", |
| 126 | +) |
115 | 127 |
|
116 | 128 | args = parser.parse_args() |
117 | 129 |
|
@@ -156,7 +168,9 @@ def convert_megatron_transformer_checkpoint_to_diffusers( |
156 | 168 | new_state_dict = {} |
157 | 169 |
|
158 | 170 | # Patch Embedding |
159 | | - new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64) |
| 171 | + new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape( |
| 172 | +        hidden_size, 128 if args.control else 64
| 173 | + ) |
160 | 174 | new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"] |
161 | 175 | new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"] |
162 | 176 | new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"] |
@@ -340,13 +354,22 @@ def main(args): |
340 | 354 | ) |
341 | 355 |
|
342 | 356 | # Create the pipeline |
343 | | - pipe = CogView4Pipeline( |
344 | | - tokenizer=tokenizer, |
345 | | - text_encoder=text_encoder, |
346 | | - vae=vae, |
347 | | - transformer=transformer, |
348 | | - scheduler=scheduler, |
349 | | - ) |
| 357 | + if args.control: |
| 358 | + pipe = CogView4ControlPipeline( |
| 359 | + tokenizer=tokenizer, |
| 360 | + text_encoder=text_encoder, |
| 361 | + vae=vae, |
| 362 | + transformer=transformer, |
| 363 | + scheduler=scheduler, |
| 364 | + ) |
| 365 | + else: |
| 366 | + pipe = CogView4Pipeline( |
| 367 | + tokenizer=tokenizer, |
| 368 | + text_encoder=text_encoder, |
| 369 | + vae=vae, |
| 370 | + transformer=transformer, |
| 371 | + scheduler=scheduler, |
| 372 | + ) |
350 | 373 |
|
351 | 374 | # Save the converted pipeline |
352 | 375 | pipe.save_pretrained( |
|
0 commit comments