Skip to content

Commit 8abca19

Browse files
add control convert
1 parent c774f45 commit 8abca19

File tree

1 file changed

+32
-9
lines changed

1 file changed

+32
-9
lines changed

scripts/convert_cogview4_to_diffusers_megatron.py

Lines changed: 32 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,13 @@
2727
from tqdm import tqdm
2828
from transformers import GlmModel, PreTrainedTokenizerFast
2929

30-
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
30+
from diffusers import (
31+
AutoencoderKL,
32+
CogView4Pipeline,
33+
CogView4ControlPipeline,
34+
CogView4Transformer2DModel,
35+
FlowMatchEulerDiscreteScheduler,
36+
)
3137
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
3238

3339

@@ -112,6 +118,12 @@
112118
default=128,
113119
help="Maximum size for positional embeddings.",
114120
)
121+
parser.add_argument(
122+
"--control",
123+
action="store_true",
124+
default=False,
125+
help="Whether to use control model.",
126+
)
115127

116128
args = parser.parse_args()
117129

@@ -156,7 +168,9 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
156168
new_state_dict = {}
157169

158170
# Patch Embedding
159-
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
171+
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
172+
hidden_size, 128 if args.control else 64, 64
173+
)
160174
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
161175
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
162176
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
@@ -340,13 +354,22 @@ def main(args):
340354
)
341355

342356
# Create the pipeline
343-
pipe = CogView4Pipeline(
344-
tokenizer=tokenizer,
345-
text_encoder=text_encoder,
346-
vae=vae,
347-
transformer=transformer,
348-
scheduler=scheduler,
349-
)
357+
if args.control:
358+
pipe = CogView4ControlPipeline(
359+
tokenizer=tokenizer,
360+
text_encoder=text_encoder,
361+
vae=vae,
362+
transformer=transformer,
363+
scheduler=scheduler,
364+
)
365+
else:
366+
pipe = CogView4Pipeline(
367+
tokenizer=tokenizer,
368+
text_encoder=text_encoder,
369+
vae=vae,
370+
transformer=transformer,
371+
scheduler=scheduler,
372+
)
350373

351374
# Save the converted pipeline
352375
pipe.save_pretrained(

0 commit comments

Comments (0)