Skip to content

Commit a440a6e

Browse files
authored
Add kwargs to cogvideox pipeline __call__ (#501)
1 parent ac5ad29 commit a440a6e

File tree

1 file changed

+25
-42
lines changed

1 file changed

+25
-42
lines changed

xfuser/model_executor/pipelines/pipeline_cogvideox.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,44 @@
1+
import inspect
2+
import math
13
import os
2-
from typing import Any, List, Tuple, Callable, Optional, Union, Dict
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
35

46
import torch
57
import torch.distributed
6-
import inspect
78
from diffusers import CogVideoXPipeline
9+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
810
from diffusers.pipelines.cogvideo.pipeline_cogvideox import (
911
CogVideoXPipelineOutput,
1012
retrieve_timesteps,
1113
)
1214
from diffusers.schedulers import CogVideoXDPMScheduler
13-
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
14-
15-
import math
1615

1716
from xfuser.config import EngineConfig
18-
1917
from xfuser.core.distributed import (
18+
get_cfg_group,
19+
get_classifier_free_guidance_world_size,
2020
get_pipeline_parallel_world_size,
21-
get_sequence_parallel_world_size,
21+
get_runtime_state,
2222
get_sequence_parallel_rank,
23-
get_classifier_free_guidance_world_size,
24-
get_cfg_group,
23+
get_sequence_parallel_world_size,
2524
get_sp_group,
26-
get_runtime_state,
2725
is_dp_last_group,
2826
)
29-
3027
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
28+
3129
from .register import xFuserPipelineWrapperRegister
3230

3331

3432
@xFuserPipelineWrapperRegister.register(CogVideoXPipeline)
3533
class xFuserCogVideoXPipeline(xFuserPipelineBaseWrapper):
36-
3734
@classmethod
3835
def from_pretrained(
3936
cls,
4037
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
4138
engine_config: EngineConfig,
4239
**kwargs,
4340
):
44-
pipeline = CogVideoXPipeline.from_pretrained(
45-
pretrained_model_name_or_path, **kwargs
46-
)
41+
pipeline = CogVideoXPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
4742
return cls(pipeline, engine_config)
4843

4944
@torch.no_grad()
@@ -74,6 +69,7 @@ def __call__(
7469
] = None,
7570
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
7671
max_sequence_length: int = 226,
72+
**kwargs,
7773
) -> Union[CogVideoXPipelineOutput, Tuple]:
7874
"""
7975
Function invoked when calling the pipeline for generation.
@@ -213,9 +209,7 @@ def __call__(
213209
max_sequence_length=max_sequence_length,
214210
device=device,
215211
)
216-
prompt_embeds = self._process_cfg_split_batch(
217-
negative_prompt_embeds, prompt_embeds
218-
)
212+
prompt_embeds = self._process_cfg_split_batch(negative_prompt_embeds, prompt_embeds)
219213

220214
# 4. Prepare timesteps
221215
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -249,9 +243,7 @@ def __call__(
249243

250244
# 7. Create rotary embeds if required
251245
image_rotary_emb = (
252-
self._prepare_rotary_positional_embeddings(
253-
height, width, latents.size(1), device
254-
)
246+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
255247
if self.transformer.config.use_rotary_positional_embeddings
256248
else None
257249
)
@@ -261,8 +253,7 @@ def __call__(
261253

262254
p_t = self.transformer.config.patch_size_t or 1
263255
latents, prompt_embeds, image_rotary_emb = self._init_sync_pipeline(
264-
latents, prompt_embeds, image_rotary_emb,
265-
(latents.size(1) + p_t - 1) // p_t
256+
latents, prompt_embeds, image_rotary_emb, (latents.size(1) + p_t - 1) // p_t
266257
)
267258
with self.progress_bar(total=num_inference_steps) as progress_bar:
268259
# for DPM-solver++
@@ -272,9 +263,7 @@ def __call__(
272263
continue
273264

274265
if do_classifier_free_guidance:
275-
latent_model_input = torch.cat(
276-
[latents] * (2 // get_classifier_free_guidance_world_size())
277-
)
266+
latent_model_input = torch.cat([latents] * (2 // get_classifier_free_guidance_world_size()))
278267

279268
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
280269

@@ -289,6 +278,7 @@ def __call__(
289278
image_rotary_emb=image_rotary_emb,
290279
attention_kwargs=attention_kwargs,
291280
return_dict=False,
281+
**kwargs,
292282
)[0]
293283
noise_pred = noise_pred.float()
294284

@@ -304,9 +294,7 @@ def __call__(
304294
noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
305295
noise_pred, separate_tensors=True
306296
)
307-
noise_pred = noise_pred_uncond + self.guidance_scale * (
308-
noise_pred_text - noise_pred_uncond
309-
)
297+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
310298

311299
# compute the previous noisy sample x_t -> x_t-1
312300
if not isinstance(self.scheduler.module, CogVideoXDPMScheduler):
@@ -334,9 +322,7 @@ def __call__(
334322
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
335323
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
336324

337-
if i == len(timesteps) - 1 or (
338-
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
339-
):
325+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
340326
progress_bar.update()
341327

342328
if get_sequence_parallel_world_size() > 1:
@@ -369,12 +355,14 @@ def _init_sync_pipeline(
369355
latents_frames: Optional[int] = None,
370356
):
371357
latents = super()._init_video_sync_pipeline(latents)
372-
358+
373359
if get_runtime_state().split_text_embed_in_sp:
374360
if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
375-
prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
361+
prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[
362+
get_sequence_parallel_rank()
363+
]
376364
else:
377-
get_runtime_state().split_text_embed_in_sp = False
365+
get_runtime_state().split_text_embed_in_sp = False
378366

379367
if image_rotary_emb is not None:
380368
assert latents_frames is not None
@@ -383,9 +371,7 @@ def _init_sync_pipeline(
383371
torch.cat(
384372
[
385373
image_rotary_emb[0]
386-
.reshape(latents_frames, -1, d)[
387-
:, start_token_idx:end_token_idx
388-
]
374+
.reshape(latents_frames, -1, d)[:, start_token_idx:end_token_idx]
389375
.reshape(-1, d)
390376
for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
391377
],
@@ -394,9 +380,7 @@ def _init_sync_pipeline(
394380
torch.cat(
395381
[
396382
image_rotary_emb[1]
397-
.reshape(latents_frames, -1, d)[
398-
:, start_token_idx:end_token_idx
399-
]
383+
.reshape(latents_frames, -1, d)[:, start_token_idx:end_token_idx]
400384
.reshape(-1, d)
401385
for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
402386
],
@@ -405,7 +389,6 @@ def _init_sync_pipeline(
405389
)
406390
return latents, prompt_embeds, image_rotary_emb
407391

408-
409392
def prepare_extra_step_kwargs(self, generator, eta):
410393
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
411394
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

0 commit comments

Comments (0)