Skip to content

Commit 7ffecbc

Browse files
Update pipeline_cogview4_control.py
1 parent 25f4e4b commit 7ffecbc

File tree

1 file changed: +5 additions, −11 deletions

src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
         >>> import torch
         >>> from diffusers import CogView4Pipeline

-        >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+        >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")

         >>> prompt = "A photo of an astronaut riding a horse on mars"
@@ -60,17 +60,11 @@ def calculate_shift(
     base_seq_len: int = 256,
     base_shift: float = 0.25,
     max_shift: float = 0.75,
-):
-    # m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
-    # b = base_shift - m * base_seq_len
-    # mu = image_seq_len * m + b
-    # return mu
-
+) -> float:
     m = (image_seq_len / base_seq_len) ** 0.5
     mu = m * max_shift + base_shift
     return mu

-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -224,6 +218,7 @@ def _get_glm_embeds(
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
         return prompt_embeds

+    # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -627,16 +622,15 @@ def __call__(
             if timesteps is None
            else np.array(timesteps)
        )
-        timesteps = timesteps.astype(np.int64)
+        timesteps = timesteps.astype(np.int64).astype(np.float32)
        sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("base_shift", 0.25),
            self.scheduler.config.get("max_shift", 0.75),
        )
-        _, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
-        timesteps = torch.from_numpy(timesteps).to(device)
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu)

        # Denoising loop
        transformer_dtype = self.transformer.dtype

0 commit comments

Comments
 (0)