@@ -430,6 +430,122 @@ def test_correct_lora_configs_with_different_ranks(self):
430430 self .assertTrue (not np .allclose (original_output , lora_output_diff_alpha , atol = 1e-3 , rtol = 1e-3 ))
431431 self .assertTrue (not np .allclose (lora_output_diff_alpha , lora_output_same_rank , atol = 1e-3 , rtol = 1e-3 ))
432432
433+ def test_lora_expanding_shape_with_normal_lora_raises_error (self ):
434+ # TODO: This test checks if an error is raised when a lora expands shapes (like control loras) but
435+ # another lora with correct shapes is loaded. This is not supported at the moment and should raise an error.
436+ # When we do support it, this test should be removed. Context: https://github.com/huggingface/diffusers/issues/10180
437+ components , _ , _ = self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
438+
439+ # Change the transformer config to mimic a real use case.
440+ num_channels_without_control = 4
441+ transformer = FluxTransformer2DModel .from_config (
442+ components ["transformer" ].config , in_channels = num_channels_without_control
443+ ).to (torch_device )
444+ components ["transformer" ] = transformer
445+
446+ pipe = self .pipeline_class (** components )
447+ pipe = pipe .to (torch_device )
448+ pipe .set_progress_bar_config (disable = None )
449+
450+ logger = logging .get_logger ("diffusers.loaders.lora_pipeline" )
451+ logger .setLevel (logging .DEBUG )
452+
453+ out_features , in_features = pipe .transformer .x_embedder .weight .shape
454+ rank = 4
455+
456+ shape_expander_lora_A = torch .nn .Linear (2 * in_features , rank , bias = False )
457+ shape_expander_lora_B = torch .nn .Linear (rank , out_features , bias = False )
458+ lora_state_dict = {
459+ "transformer.x_embedder.lora_A.weight" : shape_expander_lora_A .weight ,
460+ "transformer.x_embedder.lora_B.weight" : shape_expander_lora_B .weight ,
461+ }
462+ with CaptureLogger (logger ) as cap_logger :
463+ pipe .load_lora_weights (lora_state_dict , "adapter-1" )
464+
465+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
466+ self .assertTrue (pipe .get_active_adapters () == ["adapter-1" ])
467+ self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] == 2 * in_features )
468+ self .assertTrue (pipe .transformer .config .in_channels == 2 * in_features )
469+ self .assertTrue (cap_logger .out .startswith ("Expanding the nn.Linear input/output features for module" ))
470+
471+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
472+ lora_output = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
473+
474+ normal_lora_A = torch .nn .Linear (in_features , rank , bias = False )
475+ normal_lora_B = torch .nn .Linear (rank , out_features , bias = False )
476+ lora_state_dict = {
477+ "transformer.x_embedder.lora_A.weight" : normal_lora_A .weight ,
478+ "transformer.x_embedder.lora_B.weight" : normal_lora_B .weight ,
479+ }
480+
481+ # The first lora expanded the input features of x_embedder. Here, we are trying to load a lora with the correct
482+ # input features before expansion. This should raise an error about the weight shapes being incompatible.
483+ self .assertRaisesRegex (
484+ RuntimeError ,
485+ "size mismatch for x_embedder.lora_A.adapter-2.weight" ,
486+ pipe .load_lora_weights ,
487+ lora_state_dict ,
488+ "adapter-2" ,
489+ )
490+ # We should have `adapter-1` as the only adapter.
491+ self .assertTrue (pipe .get_active_adapters () == ["adapter-1" ])
492+
493+ # Check if the output is the same after lora loading error
494+ lora_output_after_error = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
495+ self .assertTrue (np .allclose (lora_output , lora_output_after_error , atol = 1e-3 , rtol = 1e-3 ))
496+
497+ # Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
498+ # This should raise a runtime error on input shapes being incompatible. But it doesn't. This is because PEFT renames the
499+ # original layers as `base_layer` and the lora layers with the adapter names. This makes our logic to check if a lora
500+ # weight is compatible with the current model inadequate. This should be addressed when attempting support for
501+ # https://github.com/huggingface/diffusers/issues/10180 (TODO)
502+ components , _ , _ = self .get_dummy_components (FlowMatchEulerDiscreteScheduler )
503+ # Change the transformer config to mimic a real use case.
504+ num_channels_without_control = 4
505+ transformer = FluxTransformer2DModel .from_config (
506+ components ["transformer" ].config , in_channels = num_channels_without_control
507+ ).to (torch_device )
508+ components ["transformer" ] = transformer
509+
510+ pipe = self .pipeline_class (** components )
511+ pipe = pipe .to (torch_device )
512+ pipe .set_progress_bar_config (disable = None )
513+
514+ logger = logging .get_logger ("diffusers.loaders.lora_pipeline" )
515+ logger .setLevel (logging .DEBUG )
516+
517+ out_features , in_features = pipe .transformer .x_embedder .weight .shape
518+ rank = 4
519+
520+ lora_state_dict = {
521+ "transformer.x_embedder.lora_A.weight" : normal_lora_A .weight ,
522+ "transformer.x_embedder.lora_B.weight" : normal_lora_B .weight ,
523+ }
524+
525+ with CaptureLogger (logger ) as cap_logger :
526+ pipe .load_lora_weights (lora_state_dict , "adapter-1" )
527+ self .assertTrue (check_if_lora_correctly_set (pipe .transformer ), "Lora not correctly set in denoiser" )
528+
529+ self .assertTrue (pipe .transformer .x_embedder .weight .data .shape [1 ] == in_features )
530+ self .assertTrue (pipe .transformer .config .in_channels == in_features )
531+ self .assertFalse (cap_logger .out .startswith ("Expanding the nn.Linear input/output features for module" ))
532+
533+ lora_state_dict = {
534+ "transformer.x_embedder.lora_A.weight" : shape_expander_lora_A .weight ,
535+ "transformer.x_embedder.lora_B.weight" : shape_expander_lora_B .weight ,
536+ }
537+
538+ # We should check for input shapes being incompatible here. But because above mentioned issue is
539+ # not a supported use case, and because of the PEFT renaming, we will currently have a shape
540+ # mismatch error.
541+ self .assertRaisesRegex (
542+ RuntimeError ,
543+ "size mismatch for x_embedder.lora_A.adapter-2.weight" ,
544+ pipe .load_lora_weights ,
545+ lora_state_dict ,
546+ "adapter-2" ,
547+ )
548+
433549 @unittest .skip ("Not supported in Flux." )
434550 def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options (self ):
435551 pass
0 commit comments