@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
         self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))

+    def test_lora_unload_with_parameter_expanded_shapes(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
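+        # DEBUG level is set so that CaptureLogger below can observe the
+        # "Expanding the nn.Linear input/output features ..." message emitted during loading.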
+
+        # Change the transformer config to mimic a real use case.
+        num_channels_without_control = 4
+        transformer = FluxTransformer2DModel.from_config(
+            components["transformer"].config, in_channels=num_channels_without_control
+        ).to(torch_device)
+        self.assertTrue(
+            transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+        )
+
+        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+        components["transformer"] = transformer
+        pipe = FluxPipeline(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        control_image = inputs.pop("control_image")
+        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
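+        # The control pipeline variant is built from the same `components` dict,
+        # so it shares the transformer instance with `pipe` above.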
+        control_pipe = self.pipeline_class(**components)
+        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
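+        # `lora_A` takes 2 * in_features inputs, mimicking a control LoRA whose state
+        # dict forces `x_embedder` to be expanded to twice its original input width.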
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
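+        # The control variant consumes `control_image`, so re-attach it for the LoRA run.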
+        inputs["control_image"] = control_image
+        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
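+        # Unloading the LoRA should shrink `x_embedder` (and the config) back to the
+        # original number of input channels.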
+        control_pipe.unload_lora_weights()
+        self.assertTrue(
+            control_pipe.transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+        )
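+        # A pipeline created via `from_pipe` must also see the restored config.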
+        loaded_pipe = FluxPipeline.from_pipe(control_pipe)
+        self.assertTrue(
+            loaded_pipe.transformer.config.in_channels == num_channels_without_control,
+            f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
+        )
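+        # Drop `control_image` again: the plain FluxPipeline does not accept it.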
+        inputs.pop("control_image")
+        unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+        self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
+        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
+        self.assertTrue(pipe.transformer.config.in_channels == in_features)
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass