Skip to content

Commit df83bf2

Browse files
1
1 parent 2cbdf35 commit df83bf2

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

examples/cogview4-control/train_control_cogview4.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ def main(args):
804804
cogview4_transformer.patch_embed.proj = new_linear
805805

806806
assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0)
807+
cogview4_transformer.register_to_config(in_channels=cogview4_transformer.config.in_channels * 2, out_channels=initial_input_channels)
807808

808809
if args.only_target_transformer_blocks:
809810
cogview4_transformer.patch_embed.proj.requires_grad_(True)

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def __call__(
579579
)
580580

581581
# Prepare latents
582-
latent_channels = self.transformer.config.in_channels
582+
latent_channels = self.transformer.config.in_channels // 2
583583

584584
control_image = self.prepare_image(
585585
image=control_image,

0 commit comments

Comments
 (0)