1+ import inspect
2+ import math
13import os
2- from typing import Any , List , Tuple , Callable , Optional , Union , Dict
4+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
35
46import torch
57import torch .distributed
6- import inspect
78from diffusers import CogVideoXPipeline
9+ from diffusers .callbacks import MultiPipelineCallbacks , PipelineCallback
810from diffusers .pipelines .cogvideo .pipeline_cogvideox import (
911 CogVideoXPipelineOutput ,
1012 retrieve_timesteps ,
1113)
1214from diffusers .schedulers import CogVideoXDPMScheduler
13- from diffusers .callbacks import MultiPipelineCallbacks , PipelineCallback
14-
15- import math
1615
1716from xfuser .config import EngineConfig
18-
1917from xfuser .core .distributed import (
18+ get_cfg_group ,
19+ get_classifier_free_guidance_world_size ,
2020 get_pipeline_parallel_world_size ,
21- get_sequence_parallel_world_size ,
21+ get_runtime_state ,
2222 get_sequence_parallel_rank ,
23- get_classifier_free_guidance_world_size ,
24- get_cfg_group ,
23+ get_sequence_parallel_world_size ,
2524 get_sp_group ,
26- get_runtime_state ,
2725 is_dp_last_group ,
2826)
29-
3027from xfuser .model_executor .pipelines import xFuserPipelineBaseWrapper
28+
3129from .register import xFuserPipelineWrapperRegister
3230
3331
3432@xFuserPipelineWrapperRegister .register (CogVideoXPipeline )
3533class xFuserCogVideoXPipeline (xFuserPipelineBaseWrapper ):
36-
3734 @classmethod
3835 def from_pretrained (
3936 cls ,
4037 pretrained_model_name_or_path : Optional [Union [str , os .PathLike ]],
4138 engine_config : EngineConfig ,
4239 ** kwargs ,
4340 ):
44- pipeline = CogVideoXPipeline .from_pretrained (
45- pretrained_model_name_or_path , ** kwargs
46- )
41+ pipeline = CogVideoXPipeline .from_pretrained (pretrained_model_name_or_path , ** kwargs )
4742 return cls (pipeline , engine_config )
4843
4944 @torch .no_grad ()
@@ -74,6 +69,7 @@ def __call__(
7469 ] = None ,
7570 callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
7671 max_sequence_length : int = 226 ,
72+ ** kwargs ,
7773 ) -> Union [CogVideoXPipelineOutput , Tuple ]:
7874 """
7975 Function invoked when calling the pipeline for generation.
@@ -213,9 +209,7 @@ def __call__(
213209 max_sequence_length = max_sequence_length ,
214210 device = device ,
215211 )
216- prompt_embeds = self ._process_cfg_split_batch (
217- negative_prompt_embeds , prompt_embeds
218- )
212+ prompt_embeds = self ._process_cfg_split_batch (negative_prompt_embeds , prompt_embeds )
219213
220214 # 4. Prepare timesteps
221215 timesteps , num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , timesteps )
@@ -249,9 +243,7 @@ def __call__(
249243
250244 # 7. Create rotary embeds if required
251245 image_rotary_emb = (
252- self ._prepare_rotary_positional_embeddings (
253- height , width , latents .size (1 ), device
254- )
246+ self ._prepare_rotary_positional_embeddings (height , width , latents .size (1 ), device )
255247 if self .transformer .config .use_rotary_positional_embeddings
256248 else None
257249 )
@@ -261,8 +253,7 @@ def __call__(
261253
262254 p_t = self .transformer .config .patch_size_t or 1
263255 latents , prompt_embeds , image_rotary_emb = self ._init_sync_pipeline (
264- latents , prompt_embeds , image_rotary_emb ,
265- (latents .size (1 ) + p_t - 1 ) // p_t
256+ latents , prompt_embeds , image_rotary_emb , (latents .size (1 ) + p_t - 1 ) // p_t
266257 )
267258 with self .progress_bar (total = num_inference_steps ) as progress_bar :
268259 # for DPM-solver++
@@ -272,9 +263,7 @@ def __call__(
272263 continue
273264
274265 if do_classifier_free_guidance :
275- latent_model_input = torch .cat (
276- [latents ] * (2 // get_classifier_free_guidance_world_size ())
277- )
266+ latent_model_input = torch .cat ([latents ] * (2 // get_classifier_free_guidance_world_size ()))
278267
279268 latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
280269
@@ -289,6 +278,7 @@ def __call__(
289278 image_rotary_emb = image_rotary_emb ,
290279 attention_kwargs = attention_kwargs ,
291280 return_dict = False ,
281+ ** kwargs ,
292282 )[0 ]
293283 noise_pred = noise_pred .float ()
294284
@@ -304,9 +294,7 @@ def __call__(
304294 noise_pred_uncond , noise_pred_text = get_cfg_group ().all_gather (
305295 noise_pred , separate_tensors = True
306296 )
307- noise_pred = noise_pred_uncond + self .guidance_scale * (
308- noise_pred_text - noise_pred_uncond
309- )
297+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
310298
311299 # compute the previous noisy sample x_t -> x_t-1
312300 if not isinstance (self .scheduler .module , CogVideoXDPMScheduler ):
@@ -334,9 +322,7 @@ def __call__(
334322 prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
335323 negative_prompt_embeds = callback_outputs .pop ("negative_prompt_embeds" , negative_prompt_embeds )
336324
337- if i == len (timesteps ) - 1 or (
338- (i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0
339- ):
325+ if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
340326 progress_bar .update ()
341327
342328 if get_sequence_parallel_world_size () > 1 :
@@ -369,12 +355,14 @@ def _init_sync_pipeline(
369355 latents_frames : Optional [int ] = None ,
370356 ):
371357 latents = super ()._init_video_sync_pipeline (latents )
372-
358+
373359 if get_runtime_state ().split_text_embed_in_sp :
374360 if prompt_embeds .shape [- 2 ] % get_sequence_parallel_world_size () == 0 :
375- prompt_embeds = torch .chunk (prompt_embeds , get_sequence_parallel_world_size (), dim = - 2 )[get_sequence_parallel_rank ()]
361+ prompt_embeds = torch .chunk (prompt_embeds , get_sequence_parallel_world_size (), dim = - 2 )[
362+ get_sequence_parallel_rank ()
363+ ]
376364 else :
377- get_runtime_state ().split_text_embed_in_sp = False
365+ get_runtime_state ().split_text_embed_in_sp = False
378366
379367 if image_rotary_emb is not None :
380368 assert latents_frames is not None
@@ -383,9 +371,7 @@ def _init_sync_pipeline(
383371 torch .cat (
384372 [
385373 image_rotary_emb [0 ]
386- .reshape (latents_frames , - 1 , d )[
387- :, start_token_idx :end_token_idx
388- ]
374+ .reshape (latents_frames , - 1 , d )[:, start_token_idx :end_token_idx ]
389375 .reshape (- 1 , d )
390376 for start_token_idx , end_token_idx in get_runtime_state ().pp_patches_token_start_end_idx_global
391377 ],
@@ -394,9 +380,7 @@ def _init_sync_pipeline(
394380 torch .cat (
395381 [
396382 image_rotary_emb [1 ]
397- .reshape (latents_frames , - 1 , d )[
398- :, start_token_idx :end_token_idx
399- ]
383+ .reshape (latents_frames , - 1 , d )[:, start_token_idx :end_token_idx ]
400384 .reshape (- 1 , d )
401385 for start_token_idx , end_token_idx in get_runtime_state ().pp_patches_token_start_end_idx_global
402386 ],
@@ -405,7 +389,6 @@ def _init_sync_pipeline(
405389 )
406390 return latents , prompt_embeds , image_rotary_emb
407391
408-
409392 def prepare_extra_step_kwargs (self , generator , eta ):
410393 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
411394 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
0 commit comments