
Commit 347dd17

use control format
1 parent: cbfeb0b

2 files changed: +15 -11 lines

scripts/convert_cogview4_to_diffusers_megatron.py

Lines changed: 4 additions & 4 deletions
@@ -162,14 +162,14 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
     Returns:
         dict: The converted state dictionary compatible with Diffusers.
     """
-    ckpt = torch.load(ckpt_path, map_location="cpu")
+    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
     mega = ckpt["model"]
 
     new_state_dict = {}
 
     # Patch Embedding
     new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
-        hidden_size, 128 if args.control else 64, 64
+        hidden_size, 128 if args.control else 64
     )
     new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
     new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
@@ -260,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
     Returns:
         dict: The converted VAE state dictionary compatible with Diffusers.
     """
-    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+    original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
 
@@ -294,7 +294,7 @@ def main(args):
     )
     transformer = CogView4Transformer2DModel(
         patch_size=2,
-        in_channels=16,
+        in_channels=32 if args.control else 16,
         num_layers=args.num_layers,
         attention_head_dim=args.attention_head_dim,
         num_attention_heads=args.num_heads,
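The `in_channels=32 if args.control else 16` change matches the reshape fix above: the control variant presumably feeds the transformer a channel-wise concatenation of the noisy latents and the VAE-encoded control image, doubling the input channels from 16 to 32. A sketch under that assumption (tensor names are illustrative):

```python
# Sketch of the control input format, assuming (as the converter's channel
# math suggests) that the control pipeline concatenates VAE-encoded control
# latents with the noisy image latents along the channel axis.
import torch

batch, h, w = 1, 64, 64                         # latent-space size, illustrative
latents = torch.randn(batch, 16, h, w)          # noisy image latents
control_latents = torch.randn(batch, 16, h, w)  # encoded control image

model_input = torch.cat([latents, control_latents], dim=1)
assert model_input.shape[1] == 32               # hence in_channels=32 with --control

# With patch_size=2, each patch token flattens patch_size**2 * in_channels
# values: 2*2*32 = 128 with control, 2*2*16 = 64 without -- matching the
# `128 if args.control else 64` reshape of encoder_expand_linear.weight.
```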

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 11 additions & 7 deletions
@@ -46,15 +46,18 @@
         >>> from diffusers import CogView4Pipeline
 
         >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
-        >>> pipe.to("cuda")
-
-        >>> prompt = "A photo of an astronaut riding a horse on mars"
-        >>> image = pipe(prompt).images[0]
-        >>> image.save("output.png")
+        >>> control_image = load_image(
+        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        ... )
+        >>> prompt = "A bird in space"
+        >>> image = pipe(
+        ...     prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5
+        ... ).images[0]
+        >>> image.save("cogview4-control.png")
         ```
 """
 
-
+# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift
 def calculate_shift(
     image_seq_len,
     base_seq_len: int = 256,
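The new doctest exercises the control path but leans on names from the surrounding doc context; note it still shows `from diffusers import CogView4Pipeline` even though it instantiates `CogView4ControlPipeline`. A self-contained version might look like this (moving the pipeline to GPU, which the doctest no longer shows, is optional but typical):

```python
# Self-contained version of the new docstring example. The doctest assumes
# `torch` and `load_image` are already in scope; here they are imported
# explicitly.
import torch
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image

pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
pipe.to("cuda")

control_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
prompt = "A bird in space"
image = pipe(
    prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5
).images[0]
image.save("cogview4-control.png")
```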
@@ -175,6 +178,7 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
+    # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds
     def _get_glm_embeds(
         self,
         prompt: Union[str, List[str]] = None,
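The two new `# Copied from diffusers.pipelines.cogview4.pipeline_cogview4...` comments are diffusers' copy-consistency markers: the repository's `make fix-copies` tooling uses them to keep duplicated helpers like `calculate_shift` and `_get_glm_embeds` mechanically in sync with their source definitions in the base CogView4 pipeline.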
@@ -341,7 +345,7 @@ def prepare_image(
             # image batch size is the same as prompt batch size
             repeat_by = num_images_per_prompt
 
-        image = image.repeat_interleave(repeat_by, dim=0)
+        image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by)
 
         image = image.to(device=device, dtype=dtype)

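Passing `output_size` to `repeat_interleave` tells PyTorch the result's length along `dim` up front instead of letting it be inferred; per the PyTorch docs this avoids a stream synchronization when `repeats` is a tensor, and it keeps the output shape statically known, which is friendlier to `torch.compile` and export. Here `repeat_by` is a plain int, so the change is presumably about the latter. A small sketch with dummy shapes:

```python
# output_size pre-declares the result's length along dim so PyTorch does not
# have to infer it from the repeats argument.
import torch

image = torch.randn(2, 3, 64, 64)   # (batch, channels, H, W) dummy control batch
repeat_by = 4                       # e.g. num_images_per_prompt

out = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by)
assert out.shape == (8, 3, 64, 64)
# Equivalent to image.repeat_interleave(repeat_by, dim=0); only the output
# size is supplied up front.
```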