Skip to content

Commit 4f89426

Browse files
committed
add slow test
1 parent 2d8dce9 commit 4f89426

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import math
1818
from typing import Callable, Dict, List, Optional, Tuple, Union
1919

20-
import torch
2120
import PIL
21+
import torch
2222
from transformers import T5EncoderModel, T5Tokenizer
2323

2424
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -452,7 +452,7 @@ def check_inputs(
452452
"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
453453
f" {type(image)}"
454454
)
455-
455+
456456
if height % 8 != 0 or width % 8 != 0:
457457
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
458458

tests/pipelines/cogvideo/test_cogvideox_image2video.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import gc
1516
import inspect
1617
import unittest
1718

@@ -21,8 +22,12 @@
2122
from transformers import AutoTokenizer, T5EncoderModel
2223

2324
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
25+
from diffusers.utils import load_image
2426
from diffusers.utils.testing_utils import (
2527
enable_full_determinism,
28+
numpy_cosine_similarity_distance,
29+
require_torch_gpu,
30+
slow,
2631
torch_device,
2732
)
2833

@@ -321,3 +326,48 @@ def test_fused_qkv_projections(self):
321326
assert np.allclose(
322327
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
323328
), "Original outputs should match when fused QKV projections are disabled."
329+
330+
331+
@unittest.skip("The model 'THUDM/CogVideoX-5b-I2V' is not public yet.")
332+
@slow
333+
@require_torch_gpu
334+
class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
335+
prompt = "A painting of a squirrel eating a burger."
336+
337+
def setUp(self):
338+
super().setUp()
339+
gc.collect()
340+
torch.cuda.empty_cache()
341+
342+
def tearDown(self):
343+
super().tearDown()
344+
gc.collect()
345+
torch.cuda.empty_cache()
346+
347+
def test_cogvideox(self):
348+
generator = torch.Generator("cpu").manual_seed(0)
349+
350+
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
351+
pipe.enable_model_cpu_offload()
352+
353+
prompt = self.prompt
354+
image = load_image(
355+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
356+
)
357+
358+
videos = pipe(
359+
image=image,
360+
prompt=prompt,
361+
height=480,
362+
width=720,
363+
num_frames=16,
364+
generator=generator,
365+
num_inference_steps=2,
366+
output_type="pt",
367+
).frames
368+
369+
video = videos[0]
370+
expected_video = torch.randn(1, 16, 480, 720, 3).numpy()
371+
372+
max_diff = numpy_cosine_similarity_distance(video, expected_video)
373+
assert max_diff < 1e-3, f"Max diff is too high. got {video}"

0 commit comments

Comments
 (0)