Skip to content

Commit d13b0d6

Browse files
authored
[Flux] add lora integration tests. (huggingface#9353)
* add lora integration tests. * internal note * add a skip marker.
1 parent 5d476f5 commit d13b0d6

File tree

1 file changed

+95
-1
lines changed

1 file changed

+95
-1
lines changed

tests/lora/test_lora_layers_flux.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import gc
1516
import os
1617
import sys
1718
import tempfile
@@ -23,7 +24,14 @@
2324
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
2425

2526
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
26-
from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
27+
from diffusers.utils.testing_utils import (
28+
floats_tensor,
29+
is_peft_available,
30+
require_peft_backend,
31+
require_torch_gpu,
32+
slow,
33+
torch_device,
34+
)
2735

2836

2937
if is_peft_available():
@@ -145,3 +153,89 @@ def test_with_alpha_in_state_dict(self):
145153
"Loading from saved checkpoints should give same results.",
146154
)
147155
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
156+
157+
158+
@slow
159+
@require_torch_gpu
160+
@require_peft_backend
161+
@unittest.skip("We cannot run inference on this model with the current CI hardware")
162+
# TODO (DN6, sayakpaul): move these tests to a beefier GPU
163+
class FluxLoRAIntegrationTests(unittest.TestCase):
164+
"""internal note: The integration slices were obtained on audace."""
165+
166+
num_inference_steps = 10
167+
seed = 0
168+
169+
def setUp(self):
170+
super().setUp()
171+
172+
gc.collect()
173+
torch.cuda.empty_cache()
174+
175+
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
176+
177+
def tearDown(self):
178+
super().tearDown()
179+
180+
gc.collect()
181+
torch.cuda.empty_cache()
182+
183+
def test_flux_the_last_ben(self):
184+
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
185+
self.pipeline.fuse_lora()
186+
self.pipeline.unload_lora_weights()
187+
self.pipeline.enable_model_cpu_offload()
188+
189+
prompt = "jon snow eating pizza with ketchup"
190+
191+
out = self.pipeline(
192+
prompt,
193+
num_inference_steps=self.num_inference_steps,
194+
guidance_scale=4.0,
195+
output_type="np",
196+
generator=torch.manual_seed(self.seed),
197+
).images
198+
out_slice = out[0, -3:, -3:, -1].flatten()
199+
expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
200+
201+
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
202+
203+
def test_flux_kohya(self):
204+
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
205+
self.pipeline.fuse_lora()
206+
self.pipeline.unload_lora_weights()
207+
self.pipeline.enable_model_cpu_offload()
208+
209+
prompt = "The cat with a brain slug earring"
210+
out = self.pipeline(
211+
prompt,
212+
num_inference_steps=self.num_inference_steps,
213+
guidance_scale=4.5,
214+
output_type="np",
215+
generator=torch.manual_seed(self.seed),
216+
).images
217+
218+
out_slice = out[0, -3:, -3:, -1].flatten()
219+
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
220+
221+
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
222+
223+
def test_flux_xlabs(self):
224+
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
225+
self.pipeline.fuse_lora()
226+
self.pipeline.unload_lora_weights()
227+
self.pipeline.enable_model_cpu_offload()
228+
229+
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
230+
231+
out = self.pipeline(
232+
prompt,
233+
num_inference_steps=self.num_inference_steps,
234+
guidance_scale=3.5,
235+
output_type="np",
236+
generator=torch.manual_seed(self.seed),
237+
).images
238+
out_slice = out[0, -3:, -3:, -1].flatten()
239+
expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
240+
241+
assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)

0 commit comments

Comments
 (0)