|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import gc |
15 | 16 | import inspect |
16 | 17 | import unittest |
17 | 18 |
|
|
21 | 22 | from transformers import AutoTokenizer, T5EncoderModel |
22 | 23 |
|
23 | 24 | from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler |
| 25 | +from diffusers.utils import load_image |
24 | 26 | from diffusers.utils.testing_utils import ( |
25 | 27 | enable_full_determinism, |
| 28 | + numpy_cosine_similarity_distance, |
| 29 | + require_torch_gpu, |
| 30 | + slow, |
26 | 31 | torch_device, |
27 | 32 | ) |
28 | 33 |
|
@@ -321,3 +326,48 @@ def test_fused_qkv_projections(self): |
321 | 326 | assert np.allclose( |
322 | 327 | original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 |
323 | 328 | ), "Original outputs should match when fused QKV projections are disabled." |
| 329 | + |
| 330 | + |
@unittest.skip("The model 'THUDM/CogVideoX-5b-I2V' is not public yet.")
@slow
@require_torch_gpu
class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
    """Integration test that runs ``CogVideoXImageToVideoPipeline`` end to end.

    The whole class is currently skipped (see the ``unittest.skip`` reason
    above) and is additionally gated behind the ``slow`` and
    ``require_torch_gpu`` markers.
    """

    # Text prompt shared by the tests in this class.
    prompt = "A painting of a squirrel eating a burger."

    def setUp(self):
        """Collect garbage and empty the CUDA cache before each test."""
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        """Collect garbage and empty the CUDA cache after each test."""
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_cogvideox(self):
        """Generate a short clip from one image and compare it to a reference.

        Loads the ``THUDM/CogVideoX-5b-I2V`` checkpoint in bfloat16 with model
        CPU offload enabled, then runs 2 inference steps for 16 frames at
        480x720 using a CPU generator seeded with 0.
        """
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload()

        prompt = self.prompt
        image = load_image(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        )

        videos = pipe(
            image=image,
            prompt=prompt,
            height=480,
            width=720,
            num_frames=16,
            generator=generator,
            num_inference_steps=2,
            output_type="pt",
        ).frames

        video = videos[0]
        # NOTE(review): the reference below is a freshly sampled random tensor,
        # not a recorded pipeline output, and it keeps a leading dimension of 1
        # that `videos[0]` has dropped. This looks like a placeholder; replace
        # it with real expected values before un-skipping the class.
        expected_video = torch.randn(1, 16, 480, 720, 3).numpy()

        max_diff = numpy_cosine_similarity_distance(video, expected_video)
        # Report the measured distance; interpolating the full video tensor
        # would flood the failure output without showing the offending value.
        assert max_diff < 1e-3, f"Max diff is too high. got {max_diff}"
0 commit comments