Commit c1f7a80

image latents preparation
1 parent c238fe2 commit c1f7a80

File tree

2 files changed: +68 -185 lines

2 files changed

+68
-185
lines changed

scripts/convert_cogvideox_to_diffusers.py

Lines changed: 17 additions & 20 deletions
@@ -4,12 +4,13 @@
 import torch
 from transformers import T5EncoderModel, T5Tokenizer

-from diffusers import (AutoencoderKLCogVideoX,
-                       CogVideoXDDIMScheduler,
-                       CogVideoXPipeline,
-                       CogVideoXImageToVideoPipeline,
-                       CogVideoXTransformer3DModel
-                       )
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXDDIMScheduler,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXPipeline,
+    CogVideoXTransformer3DModel,
+)


 def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
@@ -95,7 +96,7 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "freqs_cos": remove_keys_inplace,
     "position_embedding": remove_keys_inplace,
     # TODO zRzRzRzRzRzRzR: really need to remove?
-    "pos_embedding": remove_keys_inplace
+    "pos_embedding": remove_keys_inplace,
 }

 VAE_KEYS_RENAME_DICT = {
@@ -134,12 +135,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:


 def convert_transformer(
-        ckpt_path: str,
-        num_layers: int,
-        num_attention_heads: int,
-        use_rotary_positional_embeddings: bool,
-        i2v: bool,
-        dtype: torch.dtype,
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    i2v: bool,
+    dtype: torch.dtype,
 ):
     PREFIX_KEY = "model.diffusion_model."

@@ -152,7 +153,7 @@ def convert_transformer(
     ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY):]
+        new_key = key[len(PREFIX_KEY) :]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
@@ -268,15 +269,11 @@ def get_args():
             image_encoder=vae,
             vae=vae,
             transformer=transformer,
-            scheduler=scheduler
+            scheduler=scheduler,
         )
     else:
         pipe = CogVideoXPipeline(
-            tokenizer=tokenizer,
-            text_encoder=text_encoder,
-            vae=vae,
-            transformer=transformer,
-            scheduler=scheduler
+            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )

     if args.fp16:
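
The renaming loop in convert_transformer above follows a simple pattern: strip the model.diffusion_model. prefix from every checkpoint key, then run the result through a rename table. Below is a minimal, self-contained sketch of that pattern; the toy state dict, the single rename entry, and the helper name are illustrative stand-ins rather than the real CogVideoX mappings.

```python
# Minimal sketch of the prefix-strip + rename-table pattern used by convert_transformer.
# The toy state dict and the single rename entry are illustrative, not the real CogVideoX
# mappings; rename_key_inplace stands in for the script's update_state_dict_inplace helper.
from typing import Any, Dict

PREFIX_KEY = "model.diffusion_model."
RENAME_DICT = {"attention.query_key_value": "attn1.to_qkv"}  # hypothetical entry


def rename_key_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


original_state_dict = {"model.diffusion_model.blocks.0.attention.query_key_value.weight": 0.0}

for key in list(original_state_dict.keys()):
    new_key = key[len(PREFIX_KEY) :]  # strip the "model.diffusion_model." prefix
    for replace_key, rename_key in RENAME_DICT.items():  # then apply the rename table
        new_key = new_key.replace(replace_key, rename_key)
    rename_key_inplace(original_state_dict, key, new_key)

print(original_state_dict)  # {'blocks.0.attn1.to_qkv.weight': 0.0}
```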

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 51 additions & 165 deletions
@@ -22,7 +22,6 @@
 from transformers import T5EncoderModel, T5Tokenizer

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PipelineImageInput
 from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
 from ...models.embeddings import get_3d_rotary_pos_embed
 from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -39,113 +38,6 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
-    h, w = input.shape[-2:]
-    factors = (h / size[0], w / size[1])
-
-    # First, we have to determine sigma
-    # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
-    sigmas = (
-        max((factors[0] - 1.0) / 2.0, 0.001),
-        max((factors[1] - 1.0) / 2.0, 0.001),
-    )
-
-    # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
-    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
-    # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
-    ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
-
-    # Make sure it is odd
-    if (ks[0] % 2) == 0:
-        ks = ks[0] + 1, ks[1]
-
-    if (ks[1] % 2) == 0:
-        ks = ks[0], ks[1] + 1
-
-    input = _gaussian_blur2d(input, ks, sigmas)
-
-    output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
-    return output
-
-
-def _gaussian_blur2d(input, kernel_size, sigma):
-    if isinstance(sigma, tuple):
-        sigma = torch.tensor([sigma], dtype=input.dtype)
-    else:
-        sigma = sigma.to(dtype=input.dtype)
-
-    ky, kx = int(kernel_size[0]), int(kernel_size[1])
-    bs = sigma.shape[0]
-    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
-    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
-    out_x = _filter2d(input, kernel_x[..., None, :])
-    out = _filter2d(out_x, kernel_y[..., None])
-
-    return out
-
-
-def _compute_padding(kernel_size):
-    """Compute padding tuple."""
-    # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
-    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
-    if len(kernel_size) < 2:
-        raise AssertionError(kernel_size)
-    computed = [k - 1 for k in kernel_size]
-
-    # for even kernels we need to do asymmetric padding :(
-    out_padding = 2 * len(kernel_size) * [0]
-
-    for i in range(len(kernel_size)):
-        computed_tmp = computed[-(i + 1)]
-
-        pad_front = computed_tmp // 2
-        pad_rear = computed_tmp - pad_front
-
-        out_padding[2 * i + 0] = pad_front
-        out_padding[2 * i + 1] = pad_rear
-
-    return out_padding
-
-
-def _filter2d(input, kernel):
-    # prepare kernel
-    b, c, h, w = input.shape
-    tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
-
-    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
-
-    height, width = tmp_kernel.shape[-2:]
-
-    padding_shape: List[int] = _compute_padding([height, width])
-    input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
-
-    # kernel and input tensor reshape to align element-wise or batch-wise params
-    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
-    input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
-
-    # convolve the tensor with the kernel.
-    output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
-
-    out = output.view(b, c, h, w)
-    return out
-
-
-def _gaussian(window_size: int, sigma):
-    if isinstance(sigma, float):
-        sigma = torch.tensor([[sigma]])
-
-    batch_size = sigma.shape[0]
-
-    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
-
-    if window_size % 2 == 0:
-        x = x + 0.5
-
-    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
-
-    return gauss / gauss.sum(-1, keepdim=True)
-
-
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -285,7 +177,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
     """

     _optional_components = []
-    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->vae"

     _callback_tensor_inputs = [
         "latents",
@@ -297,7 +189,6 @@ def __init__(
         self,
         tokenizer: T5Tokenizer,
         text_encoder: T5EncoderModel,
-        image_encoder: AutoencoderKLCogVideoX,
         vae: AutoencoderKLCogVideoX,
         transformer: CogVideoXTransformer3DModel,
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
@@ -307,7 +198,6 @@ def __init__(
         self.register_modules(
             tokenizer=tokenizer,
             text_encoder=text_encoder,
-            image_encoder=image_encoder,
             vae=vae,
             transformer=transformer,
             scheduler=scheduler,
@@ -321,45 +211,6 @@ def __init__(

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

-    def _encode_image(
-        self,
-        image: PipelineImageInput,
-        device: Union[str, torch.device],
-        num_videos_per_prompt: int,
-        do_classifier_free_guidance: bool,
-    ) -> torch.Tensor:
-        dtype = next(self.image_encoder.parameters()).dtype
-
-        if not isinstance(image, torch.Tensor):
-            image = self.video_processor.pil_to_numpy(image)
-            image = self.video_processor.numpy_to_pt(image)
-
-            # We normalize the image before resizing to match with the original implementation.
-            # Then we unnormalize it after resizing.
-            image = image * 2.0 - 1.0
-            image = _resize_with_antialiasing(image, (224, 224))
-            image = (image + 1.0) / 2.0
-
-        # encode image using VAE
-        image = image.to(device=device, dtype=dtype)
-        image_embeddings = self.image_encoder(image).image_embeds
-        image_embeddings = image_embeddings.unsqueeze(1)
-
-        # duplicate image embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = image_embeddings.shape
-        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
-        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
-
-        if do_classifier_free_guidance:
-            negative_image_embeddings = torch.zeros_like(image_embeddings)
-
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
-
-        return image_embeddings
-
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
         self,
@@ -486,23 +337,65 @@ def encode_prompt(
         return prompt_embeds, negative_prompt_embeds

     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+        self,
+        image: Optional[torch.Tensor] = None,
+        batch_size: int = 1,
+        num_channels_latents: int = 16,
+        num_frames: int = 13,
+        height: int = 60,
+        width: int = 90,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+        generator: Optional[torch.Generator] = None,
+        latents: Optional[torch.Tensor] = None,
     ):
+        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         shape = (
             batch_size,
-            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+            num_frames,
             num_channels_latents,
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
+
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            assert image.ndim == 4
+            image = image.unsqueeze(2)  # [B, C, F, H, W]
+
+            if isinstance(generator, list):
+                if len(generator) != batch_size:
+                    raise ValueError(
+                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                    )
+
+                init_latents = [
+                    retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+                ]
+            else:
+                init_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
+
+            init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+            init_latents = self.vae.config.scaling_factor * init_latents
+
+            padding_shape = (
+                batch_size,
+                num_frames - 1,
+                num_channels_latents,
+                height // self.vae_scale_factor_spatial,
+                width // self.vae_scale_factor_spatial,
+            )
+            latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
+            init_latents = torch.cat([init_latents, latent_padding], dim=1)
+
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch.cat([noise, init_latents], dim=2)
         else:
             latents = latents.to(device)
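
The reworked prepare_latents above encodes the conditioning image with the VAE, scales it by vae.config.scaling_factor, zero-pads it along the frame axis so it spans every latent frame, and concatenates it with Gaussian noise along the channel axis. Below is a shape-only sketch with dummy tensors standing in for the VAE output; the sizes are illustrative, not tied to any real checkpoint.

```python
# Shape-only sketch of the latent preparation above; dummy tensors replace the VAE encode step
# and the sizes are illustrative.
import torch

batch_size, latent_frames, latent_channels, latent_height, latent_width = 1, 13, 16, 60, 90

# Stand-in for the scaled VAE latent of the conditioning image: a single latent frame.
image_latents = torch.randn(batch_size, 1, latent_channels, latent_height, latent_width)  # [B, 1, C, H, W]

# Zero-pad along the frame axis so the conditioning tensor covers every latent frame.
padding = torch.zeros(batch_size, latent_frames - 1, latent_channels, latent_height, latent_width)
image_latents = torch.cat([image_latents, padding], dim=1)  # [B, F, C, H, W]

# Pure noise for the denoising loop, in the same [B, F, C, H, W] layout.
noise = torch.randn(batch_size, latent_frames, latent_channels, latent_height, latent_width)

# Channel-wise concatenation of noise and image conditioning, as in prepare_latents above.
latents = torch.cat([noise, image_latents], dim=2)
print(latents.shape)  # torch.Size([1, 13, 32, 60, 90])
```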

@@ -811,17 +704,7 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

-        # 3. Encode input prompt and image prompt
-        image_embeddings = self._encode_image(
-            image=image,
-            device=device,
-            num_videos_per_prompt=num_videos_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-        )
-        image = self.video_processor.preprocess(image, height=height, width=width).to(device)
-        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
-        image = image + noise_aug_strength * noise
-
+        # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -837,12 +720,15 @@ def __call__(

         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
         self._num_timesteps = len(timesteps)

         # 5. Prepare latents
+        image = self.video_processor.preprocess(image, height=height, width=width).to(device)
+        image = image.unsqueeze(2)  # [B, C, F, H, W]
+
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
+            image,
             batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
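
Taken together, the two files route the conditioning image through VAE latents instead of a separate image encoder. Below is a hedged usage sketch, assuming a locally converted checkpoint produced by the conversion script's i2v branch; the directory name, input image path, prompt, and generation settings are placeholders, not values from this commit.

```python
# Usage sketch only; the checkpoint directory, image path, and settings are placeholders.
import torch

from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "./cogvideox-i2v-converted",  # hypothetical output of convert_cogvideox_to_diffusers.py
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # matches the updated text_encoder->transformer->vae offload order

image = load_image("first_frame.png")  # local conditioning image; path is illustrative
video = pipe(image=image, prompt="a scenic timelapse", num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```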
