| 
19 | 19 |     UNet2DConditionModel,  | 
20 | 20 | )  | 
21 | 21 | from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings  | 
22 |  | -from diffusers.utils.testing_utils import torch_device  | 
 | 22 | +from diffusers.utils.testing_utils import require_torch_gpu, torch_device  | 
23 | 23 | 
 
  | 
24 | 24 | 
 
  | 
25 | 25 | class IsSafetensorsCompatibleTests(unittest.TestCase):  | 
@@ -826,3 +826,104 @@ def test_video_to_video(self):  | 
826 | 826 |         with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):  | 
827 | 827 |             _ = pipe(**inputs)  | 
828 | 828 |             self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")  | 
 | 829 | + | 
 | 830 | + | 
 | 831 | +@require_torch_gpu  | 
 | 832 | +class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):  | 
 | 833 | +    expected_pipe_device = torch.device("cuda:0")  | 
 | 834 | +    expected_pipe_dtype = torch.float64  | 
 | 835 | + | 
 | 836 | +    def get_dummy_components_image_generation(self):  | 
 | 837 | +        cross_attention_dim = 8  | 
 | 838 | + | 
 | 839 | +        torch.manual_seed(0)  | 
 | 840 | +        unet = UNet2DConditionModel(  | 
 | 841 | +            block_out_channels=(4, 8),  | 
 | 842 | +            layers_per_block=1,  | 
 | 843 | +            sample_size=32,  | 
 | 844 | +            in_channels=4,  | 
 | 845 | +            out_channels=4,  | 
 | 846 | +            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),  | 
 | 847 | +            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),  | 
 | 848 | +            cross_attention_dim=cross_attention_dim,  | 
 | 849 | +            norm_num_groups=2,  | 
 | 850 | +        )  | 
 | 851 | +        scheduler = DDIMScheduler(  | 
 | 852 | +            beta_start=0.00085,  | 
 | 853 | +            beta_end=0.012,  | 
 | 854 | +            beta_schedule="scaled_linear",  | 
 | 855 | +            clip_sample=False,  | 
 | 856 | +            set_alpha_to_one=False,  | 
 | 857 | +        )  | 
 | 858 | +        torch.manual_seed(0)  | 
 | 859 | +        vae = AutoencoderKL(  | 
 | 860 | +            block_out_channels=[4, 8],  | 
 | 861 | +            in_channels=3,  | 
 | 862 | +            out_channels=3,  | 
 | 863 | +            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],  | 
 | 864 | +            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],  | 
 | 865 | +            latent_channels=4,  | 
 | 866 | +            norm_num_groups=2,  | 
 | 867 | +        )  | 
 | 868 | +        torch.manual_seed(0)  | 
 | 869 | +        text_encoder_config = CLIPTextConfig(  | 
 | 870 | +            bos_token_id=0,  | 
 | 871 | +            eos_token_id=2,  | 
 | 872 | +            hidden_size=cross_attention_dim,  | 
 | 873 | +            intermediate_size=16,  | 
 | 874 | +            layer_norm_eps=1e-05,  | 
 | 875 | +            num_attention_heads=2,  | 
 | 876 | +            num_hidden_layers=2,  | 
 | 877 | +            pad_token_id=1,  | 
 | 878 | +            vocab_size=1000,  | 
 | 879 | +        )  | 
 | 880 | +        text_encoder = CLIPTextModel(text_encoder_config)  | 
 | 881 | +        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")  | 
 | 882 | + | 
 | 883 | +        components = {  | 
 | 884 | +            "unet": unet,  | 
 | 885 | +            "scheduler": scheduler,  | 
 | 886 | +            "vae": vae,  | 
 | 887 | +            "text_encoder": text_encoder,  | 
 | 888 | +            "tokenizer": tokenizer,  | 
 | 889 | +            "safety_checker": None,  | 
 | 890 | +            "feature_extractor": None,  | 
 | 891 | +            "image_encoder": None,  | 
 | 892 | +        }  | 
 | 893 | +        return components  | 
 | 894 | + | 
 | 895 | +    def test_deterministic_device(self):  | 
 | 896 | +        components = self.get_dummy_components_image_generation()  | 
 | 897 | + | 
 | 898 | +        pipe = StableDiffusionPipeline(**components)  | 
 | 899 | +        pipe.to(device=torch_device, dtype=torch.float32)  | 
 | 900 | + | 
 | 901 | +        pipe.unet.to(device="cpu")  | 
 | 902 | +        pipe.vae.to(device="cuda")  | 
 | 903 | +        pipe.text_encoder.to(device="cuda:0")  | 
 | 904 | + | 
 | 905 | +        pipe_device = pipe.device  | 
 | 906 | + | 
 | 907 | +        self.assertEqual(  | 
 | 908 | +            self.expected_pipe_device,  | 
 | 909 | +            pipe_device,  | 
 | 910 | +            f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",  | 
 | 911 | +        )  | 
 | 912 | + | 
 | 913 | +    def test_deterministic_dtype(self):  | 
 | 914 | +        components = self.get_dummy_components_image_generation()  | 
 | 915 | + | 
 | 916 | +        pipe = StableDiffusionPipeline(**components)  | 
 | 917 | +        pipe.to(device=torch_device, dtype=torch.float32)  | 
 | 918 | + | 
 | 919 | +        pipe.unet.to(dtype=torch.float16)  | 
 | 920 | +        pipe.vae.to(dtype=torch.float32)  | 
 | 921 | +        pipe.text_encoder.to(dtype=torch.float64)  | 
 | 922 | + | 
 | 923 | +        pipe_dtype = pipe.dtype  | 
 | 924 | + | 
 | 925 | +        self.assertEqual(  | 
 | 926 | +            self.expected_pipe_dtype,  | 
 | 927 | +            pipe_dtype,  | 
 | 928 | +            f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",  | 
 | 929 | +        )  | 
0 commit comments