Commit e044850

add fast tests
1 parent 1970f4f commit e044850

File tree: 1 file changed (+326 −0)

Lines changed: 326 additions & 0 deletions
@@ -0,0 +1,326 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
    to_np,
)


enable_full_determinism()


class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = CogVideoXImageToVideoPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = CogVideoXTransformer3DModel(
            # The product of num_attention_heads and attention_head_dim must be divisible by 16 for
            # the 3D positional embeddings. But, since we are using tiny-random-t5 here, we need the
            # internal dim of CogVideoXTransformer3DModel to be 32. The internal dim is the product
            # of num_attention_heads and attention_head_dim.
            num_attention_heads=4,
            attention_head_dim=8,
            in_channels=8,
            out_channels=4,
            time_embed_dim=2,
            text_embed_dim=32,  # Must match the hidden size of tiny-random-t5
            num_layers=1,
            sample_width=16,  # latent width: 2 -> final width: 16
            sample_height=16,  # latent height: 2 -> final height: 16
            sample_frames=9,  # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
            patch_size=2,
            temporal_compression_ratio=4,
            max_text_seq_length=16,
        )

        torch.manual_seed(0)
        vae = AutoencoderKLCogVideoX(
            in_channels=3,
            out_channels=3,
            down_block_types=(
                "CogVideoXDownBlock3D",
                "CogVideoXDownBlock3D",
                "CogVideoXDownBlock3D",
                "CogVideoXDownBlock3D",
            ),
            up_block_types=(
                "CogVideoXUpBlock3D",
                "CogVideoXUpBlock3D",
                "CogVideoXUpBlock3D",
                "CogVideoXUpBlock3D",
            ),
            block_out_channels=(8, 8, 8, 8),
            latent_channels=4,
            layers_per_block=1,
            norm_num_groups=2,
            temporal_compression_ratio=4,
        )

        torch.manual_seed(0)
        scheduler = DDIMScheduler()
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

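    # Illustrative sanity notes on the dummy config above (not part of the test logic):
    #   - attention inner dim = num_attention_heads * attention_head_dim = 4 * 8 = 32, which is
    #     divisible by 16 (as required for the 3D positional embeddings) and matches
    #     text_embed_dim=32, the hidden size of hf-internal-testing/tiny-random-t5.
    #   - temporal compression: latent frames = (sample_frames - 1) / 4 + 1 = (9 - 1) / 4 + 1 = 3.
    #   - spatially, the 16x16 sample maps to a 2x2 latent, per the inline comments above.
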
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        # Cannot reduce because convolution kernel becomes bigger than sample
        image_height = 16
        image_width = 16
        image = Image.new("RGB", (image_width, image_height))
        inputs = {
            "image": image,
            "prompt": "dance monkey",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "height": image_height,
            "width": image_width,
            "num_frames": 8,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (8, 3, 16, 16))
        expected_video = torch.randn(8, 3, 16, 16)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)

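    # Note on test_inference above: `expected_video` is random noise and the tolerance is 1e10, so
    # the numerical comparison cannot realistically fail; the meaningful check is the
    # (8, 3, 16, 16) shape assertion. The loose bound presumably acts as a placeholder until
    # reference values are added.
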
    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

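    # Note on test_callback_inputs above: `callback_inputs_change_tensor` zeroes the latents on
    # the final step, and the `output.abs().sum() < 1e10` bound is very loose, so the test mainly
    # verifies that callbacks are invoked and that tensor edits made inside them are accepted.
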
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.3):
        # Note(aryan): Investigate why this needs a bit higher tolerance
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_overlap_factor_height=1 / 12,
            tile_overlap_factor_width=1 / 12,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

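    # Note on test_vae_tiling above: `expected_diff_max=0.3` is far looser than the 1e-3 used by
    # the batching and attention-slicing tests, which is what the Note(aryan) comment refers to.
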
    @unittest.skip("xformers attention processor does not exist for CogVideoX")
    def test_xformers_attention_forwardGenerator_pass(self):
        pass

    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        frames = pipe(**inputs).frames  # [B, F, C, H, W]
        original_image_slice = frames[0, -2:, -1, -3:, -3:]

        pipe.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(
            pipe.transformer
        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        frames = pipe(**inputs).frames
        image_slice_fused = frames[0, -2:, -1, -3:, -3:]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        frames = pipe(**inputs).frames
        image_slice_disabled = frames[0, -2:, -1, -3:, -3:]

        assert np.allclose(
            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
        ), "Fusion of QKV projections shouldn't affect the outputs."
        assert np.allclose(
            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        assert np.allclose(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."

0 commit comments
