|
15 | 15 | HiDreamImageTransformer2DModel, |
16 | 16 | SD3Transformer2DModel, |
17 | 17 | StableDiffusion3Pipeline, |
| 18 | + WanTransformer3DModel, |
| 19 | + WanVACETransformer3DModel, |
18 | 20 | ) |
19 | 21 | from diffusers.utils import load_image |
20 | 22 | from diffusers.utils.testing_utils import ( |
@@ -577,3 +579,71 @@ def get_dummy_inputs(self): |
577 | 579 | ).to(torch_device, self.torch_dtype), |
578 | 580 | "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype), |
579 | 581 | } |
| 582 | + |
| 583 | + |
class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    """GGUF single-file loading tests for the Wan 2.1 text-to-video transformer."""

    ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
    torch_dtype = torch.bfloat16
    model_cls = WanTransformer3DModel
    expected_memory_use_in_gb = 9

    def get_dummy_inputs(self):
        """Build minimal dummy transformer inputs sized for the T2V checkpoint."""

        def _seeded_randn(shape):
            # A fresh seed-0 generator per tensor keeps each tensor
            # reproducible regardless of construction order.
            gen = torch.Generator("cpu").manual_seed(0)
            return torch.randn(shape, generator=gen).to(torch_device, self.torch_dtype)

        return {
            "hidden_states": _seeded_randn((1, 36, 2, 64, 64)),
            "encoder_hidden_states": _seeded_randn((1, 512, 4096)),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }
| 601 | + |
| 602 | + |
class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    """GGUF single-file loading tests for the Wan 2.1 image-to-video transformer."""

    ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
    torch_dtype = torch.bfloat16
    model_cls = WanTransformer3DModel
    expected_memory_use_in_gb = 9

    def get_dummy_inputs(self):
        """Build minimal dummy transformer inputs, including the image-conditioning embeddings."""

        def _seeded_randn(shape):
            # A fresh seed-0 generator per tensor keeps each tensor
            # reproducible regardless of construction order.
            gen = torch.Generator("cpu").manual_seed(0)
            return torch.randn(shape, generator=gen).to(torch_device, self.torch_dtype)

        return {
            "hidden_states": _seeded_randn((1, 36, 2, 64, 64)),
            "encoder_hidden_states": _seeded_randn((1, 512, 4096)),
            # I2V additionally conditions on image-encoder features.
            "encoder_hidden_states_image": _seeded_randn((1, 257, 1280)),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }
| 623 | + |
| 624 | + |
class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    """GGUF single-file loading tests for the Wan 2.1 VACE transformer."""

    ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
    torch_dtype = torch.bfloat16
    model_cls = WanVACETransformer3DModel
    expected_memory_use_in_gb = 9

    def get_dummy_inputs(self):
        """Build minimal dummy transformer inputs, including VACE control tensors."""

        def _seeded_randn(shape):
            # A fresh seed-0 generator per tensor keeps each tensor
            # reproducible regardless of construction order.
            gen = torch.Generator("cpu").manual_seed(0)
            return torch.randn(shape, generator=gen).to(torch_device, self.torch_dtype)

        return {
            "hidden_states": _seeded_randn((1, 16, 2, 64, 64)),
            "encoder_hidden_states": _seeded_randn((1, 512, 4096)),
            # VACE conditions on extra control streams plus a per-layer scale.
            "control_hidden_states": _seeded_randn((1, 96, 2, 64, 64)),
            "control_hidden_states_scale": _seeded_randn((8,)),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }
0 commit comments