|
19 | 19 | import numpy as np |
20 | 20 | import pytest |
21 | 21 | from huggingface_hub import hf_hub_download |
| 22 | +from PIL import Image |
22 | 23 |
|
23 | 24 | from diffusers import ( |
24 | 25 | BitsAndBytesConfig, |
25 | 26 | DiffusionPipeline, |
| 27 | + FluxControlPipeline, |
26 | 28 | FluxTransformer2DModel, |
27 | 29 | SanaTransformer2DModel, |
28 | 30 | SD3Transformer2DModel, |
29 | 31 | logging, |
30 | 32 | ) |
| 33 | +from diffusers.quantizers import PipelineQuantizationConfig |
31 | 34 | from diffusers.utils import is_accelerate_version |
32 | 35 | from diffusers.utils.testing_utils import ( |
33 | 36 | CaptureLogger, |
|
39 | 42 | numpy_cosine_similarity_distance, |
40 | 43 | require_accelerate, |
41 | 44 | require_bitsandbytes_version_greater, |
| 45 | + require_peft_backend, |
42 | 46 | require_peft_version_greater, |
43 | 47 | require_torch, |
44 | 48 | require_torch_accelerator, |
@@ -697,6 +701,50 @@ def test_lora_loading(self): |
697 | 701 | self.assertTrue(max_diff < 1e-3) |
698 | 702 |
|
699 | 703 |
|
| 704 | +@require_transformers_version_greater("4.44.0") |
| 705 | +@require_peft_backend |
| 706 | +class SlowBnb8bitFluxControlWithLoraTests(Base8bitTests): |
| 707 | + def setUp(self) -> None: |
| 708 | + gc.collect() |
| 709 | + backend_empty_cache(torch_device) |
| 710 | + |
| 711 | + self.pipeline_8bit = FluxControlPipeline.from_pretrained( |
| 712 | + "black-forest-labs/FLUX.1-dev", |
| 713 | + quantization_config=PipelineQuantizationConfig( |
| 714 | + quant_backend="bitsandbytes_8bit", |
| 715 | + quant_kwargs={"load_in_8bit": True}, |
| 716 | + components_to_quantize=["transformer", "text_encoder_2"], |
| 717 | + ), |
| 718 | + torch_dtype=torch.float16, |
| 719 | + ) |
| 720 | + self.pipeline_8bit.enable_model_cpu_offload() |
| 721 | + |
| 722 | + def tearDown(self): |
| 723 | + del self.pipeline_8bit |
| 724 | + |
| 725 | + gc.collect() |
| 726 | + backend_empty_cache(torch_device) |
| 727 | + |
| 728 | + def test_lora_loading(self): |
| 729 | + self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") |
| 730 | + |
| 731 | + output = self.pipeline_8bit( |
| 732 | + prompt=self.prompt, |
| 733 | + control_image=Image.new(mode="RGB", size=(256, 256)), |
| 734 | + height=256, |
| 735 | + width=256, |
| 736 | + max_sequence_length=64, |
| 737 | + output_type="np", |
| 738 | + num_inference_steps=8, |
| 739 | + generator=torch.Generator().manual_seed(42), |
| 740 | + ).images |
| 741 | + out_slice = output[0, -3:, -3:, -1].flatten() |
| 742 | + expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292]) |
| 743 | + |
| 744 | + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) |
| 745 | + self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}") |
| 746 | + |
| 747 | + |
700 | 748 | @slow |
701 | 749 | class BaseBnb8bitSerializationTests(Base8bitTests): |
702 | 750 | def setUp(self): |
|
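For context, a minimal standalone sketch of the scenario this test exercises, outside the unittest harness. The checkpoint ids, quantization config, and pipeline kwargs mirror the diff above; the prompt string is a placeholder for the shared `self.prompt` fixture.

```python
import torch
from PIL import Image

from diffusers import FluxControlPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Quantize the Flux transformer and the T5 text encoder to 8-bit at load time.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=["transformer", "text_encoder_2"],
    ),
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

# Load the Canny control LoRA on top of the already-quantized transformer.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

image = pipe(
    prompt="a Canny-conditioned test render",  # placeholder prompt; the test uses self.prompt
    control_image=Image.new(mode="RGB", size=(256, 256)),
    height=256,
    width=256,
    max_sequence_length=64,
    num_inference_steps=8,
    generator=torch.Generator().manual_seed(42),
).images[0]
```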