 import unittest
 
 import torch
+from parameterized import parameterized
 
 from diffusers import DiffusionPipeline, QuantoConfig
 from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     is_transformers_available,
     require_accelerate,
     require_bitsandbytes_version_greater,
@@ -188,3 +191,55 @@ def test_saving_loading(self):
         output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
 
         self.assertTrue(torch.allclose(output_1, output_2))
+
+    @parameterized.expand(["quant_kwargs", "quant_mapping"])
+    def test_warn_invalid_component(self, method):
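+        # "foo" does not match any component of the pipeline, so loading should
+        # still succeed but warn about the unknown entry.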
+        invalid_component = "foo"
+        if method == "quant_kwargs":
+            components_to_quantize = ["transformer", invalid_component]
+            quant_config = PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=components_to_quantize,
+            )
+        else:
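+            # quant_mapping path: per-component configs, with the invalid name
+            # mapped to a backend of its own.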
+            quant_config = PipelineQuantizationConfig(
+                quant_mapping={
+                    "transformer": QuantoConfig("int8"),
+                    invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
+                }
+            )
+
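+        # The loader warns about unrecognized component names; capture its logger
+        # and check the invalid name appears in the output.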
+        logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
+        logger.setLevel(logging.WARNING)
+        with CaptureLogger(logger) as cap_logger:
+            _ = DiffusionPipeline.from_pretrained(
+                self.model_name,
+                quantization_config=quant_config,
+                torch_dtype=torch.bfloat16,
+            )
+        self.assertTrue(invalid_component in cap_logger.out)
+
+    @parameterized.expand(["quant_kwargs", "quant_mapping"])
+    def test_no_quantization_for_all_invalid_components(self, method):
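+        # When every requested name is invalid, no component should be quantized.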
+        invalid_component = "foo"
+        if method == "quant_kwargs":
+            components_to_quantize = [invalid_component]
+            quant_config = PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=components_to_quantize,
+            )
+        else:
+            quant_config = PipelineQuantizationConfig(
+                quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
+            )
+
+        pipe = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            quantization_config=quant_config,
+            torch_dtype=torch.bfloat16,
+        )
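+        # None of the loaded modules should carry a quantization_config.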
+        for name, component in pipe.components.items():
+            if isinstance(component, torch.nn.Module):
+                self.assertTrue(not hasattr(component.config, "quantization_config"))