@@ -2031,18 +2031,36 @@ def lora_state_dict(
20312031        if  is_kohya :
20322032            state_dict  =  _convert_kohya_flux_lora_to_diffusers (state_dict )
20332033            # Kohya already takes care of scaling the LoRA parameters with alpha. 
2034-             return  (state_dict , None ) if  return_alphas  else  state_dict 
2034+             return  cls ._prepare_outputs (
2035+                 state_dict ,
2036+                 metadata = metadata ,
2037+                 alphas = None ,
2038+                 return_alphas = return_alphas ,
2039+                 return_metadata = return_lora_metadata ,
2040+             )
20352041
20362042        is_xlabs  =  any ("processor"  in  k  for  k  in  state_dict )
20372043        if  is_xlabs :
20382044            state_dict  =  _convert_xlabs_flux_lora_to_diffusers (state_dict )
20392045            # xlabs doesn't use `alpha`. 
2040-             return  (state_dict , None ) if  return_alphas  else  state_dict 
2046+             return  cls ._prepare_outputs (
2047+                 state_dict ,
2048+                 metadata = metadata ,
2049+                 alphas = None ,
2050+                 return_alphas = return_alphas ,
2051+                 return_metadata = return_lora_metadata ,
2052+             )
20412053
20422054        is_bfl_control  =  any ("query_norm.scale"  in  k  for  k  in  state_dict )
20432055        if  is_bfl_control :
20442056            state_dict  =  _convert_bfl_flux_control_lora_to_diffusers (state_dict )
2045-             return  (state_dict , None ) if  return_alphas  else  state_dict 
2057+             return  cls ._prepare_outputs (
2058+                 state_dict ,
2059+                 metadata = metadata ,
2060+                 alphas = None ,
2061+                 return_alphas = return_alphas ,
2062+                 return_metadata = return_lora_metadata ,
2063+             )
20462064
20472065        # For state dicts like 
20482066        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA 
@@ -2061,12 +2079,13 @@ def lora_state_dict(
20612079                    )
20622080
20632081        if  return_alphas  or  return_lora_metadata :
2064-             outputs  =  [state_dict ]
2065-             if  return_alphas :
2066-                 outputs .append (network_alphas )
2067-             if  return_lora_metadata :
2068-                 outputs .append (metadata )
2069-             return  tuple (outputs )
2082+             return  cls ._prepare_outputs (
2083+                 state_dict ,
2084+                 metadata = metadata ,
2085+                 alphas = network_alphas ,
2086+                 return_alphas = return_alphas ,
2087+                 return_metadata = return_lora_metadata ,
2088+             )
20702089        else :
20712090            return  state_dict 
20722091
@@ -2785,6 +2804,15 @@ def _get_weight_shape(weight: torch.Tensor):
27852804
27862805        raise  ValueError ("Either `base_module` or `base_weight_param_name` must be provided." )
27872806
2807+     @staticmethod  
2808+     def  _prepare_outputs (state_dict , metadata , alphas = None , return_alphas = False , return_metadata = False ):
2809+         outputs  =  [state_dict ]
2810+         if  return_alphas :
2811+             outputs .append (alphas )
2812+         if  return_metadata :
2813+             outputs .append (metadata )
2814+         return  tuple (outputs ) if  (return_alphas  or  return_metadata ) else  state_dict 
2815+ 
27882816
27892817# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially 
27902818# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. 
0 commit comments