4545 >>> import torch
4646 >>> from diffusers import CogView4Pipeline
4747
48- >>> pipe = CogView4Pipeline .from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
48+ >>> pipe = CogView4ControlPipeline .from_pretrained("THUDM/CogView4-6B-Control ", torch_dtype=torch.bfloat16)
4949 >>> pipe.to("cuda")
5050
5151 >>> prompt = "A photo of an astronaut riding a horse on mars"
@@ -60,17 +60,11 @@ def calculate_shift(
6060 base_seq_len : int = 256 ,
6161 base_shift : float = 0.25 ,
6262 max_shift : float = 0.75 ,
63- ):
64- # m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
65- # b = base_shift - m * base_seq_len
66- # mu = image_seq_len * m + b
67- # return mu
68-
63+ ) -> float :
6964 m = (image_seq_len / base_seq_len ) ** 0.5
7065 mu = m * max_shift + base_shift
7166 return mu
7267
73-
7468# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
7569def retrieve_timesteps (
7670 scheduler ,
@@ -224,6 +218,7 @@ def _get_glm_embeds(
224218 prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
225219 return prompt_embeds
226220
221+ # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt
227222 def encode_prompt (
228223 self ,
229224 prompt : Union [str , List [str ]],
@@ -627,16 +622,15 @@ def __call__(
627622 if timesteps is None
628623 else np .array (timesteps )
629624 )
630- timesteps = timesteps .astype (np .int64 )
625+ timesteps = timesteps .astype (np .int64 ). astype ( np . float32 )
631626 sigmas = timesteps / self .scheduler .config .num_train_timesteps if sigmas is None else sigmas
632627 mu = calculate_shift (
633628 image_seq_len ,
634629 self .scheduler .config .get ("base_image_seq_len" , 256 ),
635630 self .scheduler .config .get ("base_shift" , 0.25 ),
636631 self .scheduler .config .get ("max_shift" , 0.75 ),
637632 )
638- _ , num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , sigmas = sigmas , mu = mu )
639- timesteps = torch .from_numpy (timesteps ).to (device )
633+ timesteps , num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , timesteps , sigmas = sigmas , mu = mu )
640634
641635 # Denoising loop
642636 transformer_dtype = self .transformer .dtype
0 commit comments