
Commit c238fe2

make style
1 parent a56c510 commit c238fe2

4 files changed: +95, -105 lines


src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -255,8 +255,8 @@
 "BlipDiffusionControlNetPipeline",
 "BlipDiffusionPipeline",
 "CLIPImageProjection",
-"CogVideoXPipeline",
 "CogVideoXImageToVideoPipeline",
+"CogVideoXPipeline",
 "CogVideoXVideoToVideoPipeline",
 "CycleDiffusionPipeline",
 "FluxControlNetPipeline",
@@ -704,8 +704,8 @@
 AudioLDMPipeline,
 AuraFlowPipeline,
 CLIPImageProjection,
-CogVideoXPipeline,
 CogVideoXImageToVideoPipeline,
+CogVideoXPipeline,
 CogVideoXVideoToVideoPipeline,
 CycleDiffusionPipeline,
 FluxControlNetPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@
 _import_structure["cogvideo"] = [
 "CogVideoXPipeline",
 "CogVideoXImageToVideoPipeline",
-"CogVideoXVideoToVideoPipeline"
+"CogVideoXVideoToVideoPipeline",
 ]
 _import_structure["controlnet"].extend(
 [
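The `_import_structure` mapping touched here feeds diffusers' lazy-import machinery, so pipeline classes are only imported when first accessed. A minimal, self-contained sketch of that general pattern follows, assuming a simplified package layout; diffusers itself routes this through a `_LazyModule` helper, so the code below is illustrative rather than the library's actual implementation.

import importlib

# Map of submodule -> public names it provides (entries mirror the hunk above).
_import_structure = {
    "cogvideo": [
        "CogVideoXPipeline",
        "CogVideoXImageToVideoPipeline",
        "CogVideoXVideoToVideoPipeline",
    ],
}


def __getattr__(name):
    # PEP 562 module-level __getattr__: resolve a public name lazily on first
    # access instead of importing every pipeline module at package import time.
    for submodule, names in _import_structure.items():
        if name in names:
            module = importlib.import_module(f".{submodule}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

With this in a package __init__.py, `from package import CogVideoXPipeline` only imports the `cogvideo` submodule on first use, which keeps top-level import time low as the pipeline list grows.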

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 90 additions & 102 deletions (most hunks below re-indent or re-wrap continuation lines, so several removed/added pairs differ only in whitespace, which this plain-text capture does not preserve)
@@ -14,16 +14,15 @@
 # limitations under the License.
 
 import inspect
-
-import PIL
 import math
 from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
 import torch
-from PIL import Image
 from transformers import T5EncoderModel, T5Tokenizer
 
-from ...image_processor import PipelineImageInput
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -36,6 +35,7 @@
 from ...video_processor import VideoProcessor
 from .pipeline_output import CogVideoXPipelineOutput
 
+
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 
@@ -157,14 +157,12 @@ def _gaussian(window_size: int, sigma):
 >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
 >>> pipe.to("cuda")
 >>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
-
+
 >>> image = load_image(
 ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
 ... )
 >>> image = image.resize((720, 480))
->>> video = pipe(
-... image=image, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50
-... ).frames[0]
+>>> video = pipe(image=image, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50).frames[0]
 >>> export_to_video(frames, "output.mp4", fps=8)
 ```
 """
@@ -191,12 +189,12 @@ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
-scheduler,
-num_inference_steps: Optional[int] = None,
-device: Optional[Union[str, torch.device]] = None,
-timesteps: Optional[List[int]] = None,
-sigmas: Optional[List[float]] = None,
-**kwargs,
+scheduler,
+num_inference_steps: Optional[int] = None,
+device: Optional[Union[str, torch.device]] = None,
+timesteps: Optional[List[int]] = None,
+sigmas: Optional[List[float]] = None,
+**kwargs,
 ):
 """
 Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
@@ -251,7 +249,7 @@ def retrieve_timesteps(
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
 ):
 if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
 return encoder_output.latent_dist.sample(generator)
@@ -296,13 +294,13 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
 ]
 
 def __init__(
-self,
-tokenizer: T5Tokenizer,
-text_encoder: T5EncoderModel,
-image_encoder: AutoencoderKLCogVideoX,
-vae: AutoencoderKLCogVideoX,
-transformer: CogVideoXTransformer3DModel,
-scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+self,
+tokenizer: T5Tokenizer,
+text_encoder: T5EncoderModel,
+image_encoder: AutoencoderKLCogVideoX,
+vae: AutoencoderKLCogVideoX,
+transformer: CogVideoXTransformer3DModel,
+scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
 ):
 super().__init__()
 
@@ -312,7 +310,7 @@ def __init__(
 image_encoder=image_encoder,
 vae=vae,
 transformer=transformer,
-scheduler=scheduler
+scheduler=scheduler,
 )
 self.vae_scale_factor_spatial = (
 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -324,11 +322,11 @@ def __init__(
 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
 def _encode_image(
-self,
-image: PipelineImageInput,
-device: Union[str, torch.device],
-num_videos_per_prompt: int,
-do_classifier_free_guidance: bool,
+self,
+image: PipelineImageInput,
+device: Union[str, torch.device],
+num_videos_per_prompt: int,
+do_classifier_free_guidance: bool,
 ) -> torch.Tensor:
 dtype = next(self.image_encoder.parameters()).dtype
 
@@ -342,7 +340,6 @@ def _encode_image(
 image = _resize_with_antialiasing(image, (224, 224))
 image = (image + 1.0) / 2.0
 
-
 # encode image using VAE
 image = image.to(device=device, dtype=dtype)
 image_embeddings = self.image_encoder(image).image_embeds
@@ -365,12 +362,12 @@ def _encode_image(
 
 # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
 def _get_t5_prompt_embeds(
-self,
-prompt: Union[str, List[str]] = None,
-num_videos_per_prompt: int = 1,
-max_sequence_length: int = 226,
-device: Optional[torch.device] = None,
-dtype: Optional[torch.dtype] = None,
+self,
+prompt: Union[str, List[str]] = None,
+num_videos_per_prompt: int = 1,
+max_sequence_length: int = 226,
+device: Optional[torch.device] = None,
+dtype: Optional[torch.dtype] = None,
 ):
 device = device or self._execution_device
 dtype = dtype or self.text_encoder.dtype
@@ -390,7 +387,7 @@ def _get_t5_prompt_embeds(
 untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1: -1])
+removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
 logger.warning(
 "The following part of your input was truncated because `max_sequence_length` is set to "
 f" {max_sequence_length} tokens: {removed_text}"
@@ -408,16 +405,16 @@ def _get_t5_prompt_embeds(
 
 # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
 def encode_prompt(
-self,
-prompt: Union[str, List[str]],
-negative_prompt: Optional[Union[str, List[str]]] = None,
-do_classifier_free_guidance: bool = True,
-num_videos_per_prompt: int = 1,
-prompt_embeds: Optional[torch.Tensor] = None,
-negative_prompt_embeds: Optional[torch.Tensor] = None,
-max_sequence_length: int = 226,
-device: Optional[torch.device] = None,
-dtype: Optional[torch.dtype] = None,
+self,
+prompt: Union[str, List[str]],
+negative_prompt: Optional[Union[str, List[str]]] = None,
+do_classifier_free_guidance: bool = True,
+num_videos_per_prompt: int = 1,
+prompt_embeds: Optional[torch.Tensor] = None,
+negative_prompt_embeds: Optional[torch.Tensor] = None,
+max_sequence_length: int = 226,
+device: Optional[torch.device] = None,
+dtype: Optional[torch.dtype] = None,
 ):
 r"""
 Encodes the prompt into text encoder hidden states.
@@ -489,15 +486,7 @@ def encode_prompt(
 return prompt_embeds, negative_prompt_embeds
 
 def prepare_latents(
-self,
-batch_size,
-num_channels_latents,
-num_frames,
-height, width,
-dtype,
-device,
-generator,
-latents=None
+self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
 ):
 shape = (
 batch_size,
@@ -535,7 +524,7 @@ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
 init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
 t_start = max(num_inference_steps - init_timestep, 0)
-timesteps = timesteps[t_start * self.scheduler.order:]
+timesteps = timesteps[t_start * self.scheduler.order :]
 
 return timesteps, num_inference_steps - t_start
 
@@ -558,17 +547,17 @@ def prepare_extra_step_kwargs(self, generator, eta):
 return extra_step_kwargs
 
 def check_inputs(
-self,
-prompt,
-height,
-width,
-strength,
-negative_prompt,
-callback_on_step_end_tensor_inputs,
-video=None,
-latents=None,
-prompt_embeds=None,
-negative_prompt_embeds=None,
+self,
+prompt,
+height,
+width,
+strength,
+negative_prompt,
+callback_on_step_end_tensor_inputs,
+video=None,
+latents=None,
+prompt_embeds=None,
+negative_prompt_embeds=None,
 ):
 if height % 8 != 0 or width % 8 != 0:
 raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -577,7 +566,7 @@ def check_inputs(
 raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
 if callback_on_step_end_tensor_inputs is not None and not all(
-k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
 ):
 raise ValueError(
 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -634,11 +623,11 @@ def unfuse_qkv_projections(self) -> None:
 
 # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
 def _prepare_rotary_positional_embeddings(
-self,
-height: int,
-width: int,
-num_frames: int,
-device: torch.device,
+self,
+height: int,
+width: int,
+num_frames: int,
+device: torch.device,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@@ -674,32 +663,32 @@ def interrupt(self):
 @torch.no_grad()
 @replace_example_docstring(EXAMPLE_DOC_STRING)
 def __call__(
-self,
-image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
-prompt: Optional[Union[str, List[str]]] = None,
-negative_prompt: Optional[Union[str, List[str]]] = None,
-height: int = 480,
-width: int = 720,
-num_frames: int = 49,
-num_inference_steps: int = 50,
-timesteps: Optional[List[int]] = None,
-strength: float = 0.8,
-guidance_scale: float = 6,
-use_dynamic_cfg: bool = False,
-num_videos_per_prompt: int = 1,
-eta: float = 0.0,
-noise_aug_strength: float = 0.02,
-generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-latents: Optional[torch.FloatTensor] = None,
-prompt_embeds: Optional[torch.FloatTensor] = None,
-negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-output_type: str = "pil",
-return_dict: bool = True,
-callback_on_step_end: Optional[
-Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
-] = None,
-callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-max_sequence_length: int = 226,
+self,
+image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
+prompt: Optional[Union[str, List[str]]] = None,
+negative_prompt: Optional[Union[str, List[str]]] = None,
+height: int = 480,
+width: int = 720,
+num_frames: int = 49,
+num_inference_steps: int = 50,
+timesteps: Optional[List[int]] = None,
+strength: float = 0.8,
+guidance_scale: float = 6,
+use_dynamic_cfg: bool = False,
+num_videos_per_prompt: int = 1,
+eta: float = 0.0,
+noise_aug_strength: float = 0.02,
+generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+latents: Optional[torch.FloatTensor] = None,
+prompt_embeds: Optional[torch.FloatTensor] = None,
+negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+output_type: str = "pil",
+return_dict: bool = True,
+callback_on_step_end: Optional[
+Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+] = None,
+callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+max_sequence_length: int = 226,
 ) -> Union[CogVideoXPipelineOutput, Tuple]:
 """
 Function invoked when calling the pipeline for generation.
@@ -827,7 +816,7 @@ def __call__(
 image=image,
 device=device,
 num_videos_per_prompt=num_videos_per_prompt,
-do_classifier_free_guidance=do_classifier_free_guidance
+do_classifier_free_guidance=do_classifier_free_guidance,
 )
 image = self.video_processor.preprocess(image, height=height, width=width).to(device)
 noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
@@ -904,8 +893,7 @@ def __call__(
 # perform guidance
 if use_dynamic_cfg:
 self._guidance_scale = 1 + guidance_scale * (
-(1 - math.cos(
-math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
 )
 if do_classifier_free_guidance:
 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
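The final hunk above only re-wraps the dynamic classifier-free guidance expression onto a single line; the formula itself is unchanged. For readers skimming the diff, here is the same cosine-based schedule pulled out into a standalone helper. The function name and the example values are illustrative, not part of the pipeline's API.

import math


def dynamic_guidance_scale(guidance_scale: float, t: float, num_inference_steps: int) -> float:
    # Mirrors the expression in the hunk above. Since (1 - cos(x)) / 2 stays in
    # [0, 1], the effective scale stays between 1 and 1 + guidance_scale.
    return 1 + guidance_scale * (
        (1 - math.cos(math.pi * ((num_inference_steps - t) / num_inference_steps) ** 5.0)) / 2
    )


# Illustrative call with the docstring defaults (guidance_scale=6, 50 inference steps).
print(dynamic_guidance_scale(6.0, t=25, num_inference_steps=50))

In the pipeline code shown above, this value is assigned to `self._guidance_scale` at each denoising step when `use_dynamic_cfg` is enabled.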

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 2 additions & 0 deletions
@@ -286,6 +286,7 @@ def from_config(cls, *args, **kwargs):
 def from_pretrained(cls, *args, **kwargs):
 requires_backends(cls, ["torch", "transformers"])
 
+
 class CogVideoXImageToVideoPipeline(metaclass=DummyObject):
 _backends = ["torch", "transformers"]
 
@@ -300,6 +301,7 @@ def from_config(cls, *args, **kwargs):
 def from_pretrained(cls, *args, **kwargs):
 requires_backends(cls, ["torch", "transformers"])
 
+
 class CogVideoXVideoToVideoPipeline(metaclass=DummyObject):
 _backends = ["torch", "transformers"]
 
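The added blank lines simply give the auto-generated dummy classes the standard two-blank-line separation. For context, these dummies stand in for the real pipelines when `torch` or `transformers` is not installed and raise a helpful error on use. A rough, self-contained sketch of that pattern, with simplified names and behavior (the real `DummyObject` and `requires_backends` live in `diffusers.utils` and differ in detail):

def requires_backends(obj, backends):
    # Simplified stand-in: the real helper checks availability first; here we
    # assume the backends are missing and always raise.
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class DummyObject(type):
    # Metaclass: attribute access on the dummy class (from_pretrained,
    # from_config, ...) funnels into the backend check and raises.
    def __getattr__(cls, name):
        requires_backends(cls, cls._backends)


class FakeVideoPipeline(metaclass=DummyObject):  # hypothetical example class
    _backends = ["torch", "transformers"]


# FakeVideoPipeline.from_pretrained  ->  ImportError listing the missing backends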