Skip to content

Commit c30ca7a

Browse files
change to channel 1
1 parent a97fca2 commit c30ca7a

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,9 +646,7 @@ def __call__(
646646
for i, t in enumerate(timesteps):
647647
if self.interrupt:
648648
continue
649-
650-
latent_model_input = torch.cat([latents, control_image], dim=2).to(transformer_dtype)
651-
# latent_model_input = latents.to(transformer_dtype)
649+
latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype)
652650

653651
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
654652
timestep = t.expand(latents.shape[0])

0 commit comments

Comments
 (0)