 # limitations under the License.
 
 import re
+from typing import List
 
 import torch
 
-from ..utils import is_peft_version, logging
+from ..utils import is_peft_version, logging, state_dict_all_zero
 
 
 logger = logging.get_logger(__name__)
 
 
+def swap_scale_shift(weight):
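+    # Reorders an adaLN modulation weight from (shift, scale) into the (scale, shift) layout diffusers uses.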
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
 def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
     # 1. get all state_dict_keys
     all_keys = list(state_dict.keys())
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
     # Be aware that this is the new diffusers convention and the rest of the code might
     # not utilize it yet.
     diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+
     return diffusers_name
 
 
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
 
 
 # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
-# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
-# All credits go to `kohya-ss`.
+# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
 def _convert_kohya_flux_lora_to_diffusers(state_dict):
     def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
         if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
 
         # scale weight by alpha and dim
         rank = down_weight.shape[0]
-        alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
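+        # If the checkpoint does not carry an ".alpha", default it to the rank so that `scale` below is 1.0 (a no-op).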
+        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()  # alpha is scalar
         scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
 
         # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +370,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
         sd_lora_rank = down_weight.shape[0]
 
         # scale weight by alpha and dim
-        alpha = sds_sd.pop(sds_key + ".alpha")
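+        # Same fallback as in `_convert_to_ai_toolkit`: a missing ".alpha" defaults to the rank, i.e. a scale of 1.0.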
+        default_alpha = torch.tensor(
+            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+        )
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
         scale = alpha / sd_lora_rank
 
         # calculate scale_down and scale_up
@@ -516,10 +527,103 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
                 f"transformer.single_transformer_blocks.{i}.norm.linear",
             )
 
+        # TODO: alphas.
+        def assign_remaining_weights(assignments, source):
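+            # Pops each listed sds-style ("lora_down"/"lora_up") weight from `source` and stores it in
+            # `ait_sd` under the diffusers-style ("lora_A"/"lora_B") key, applying `transform` if given.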
+            for lora_key in ["lora_A", "lora_B"]:
+                orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+                for target_fmt, source_fmt, transform in assignments:
+                    target_key = target_fmt.format(lora_key=lora_key)
+                    source_key = source_fmt.format(orig_lora_key=orig_lora_key)
+                    value = source.pop(source_key)
+                    if transform:
+                        value = transform(value)
+                    ait_sd[target_key] = value
+
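+        # Map the non-block modules (guidance_in, img_in, txt_in, time_in, vector_in, final_layer)
+        # to their diffusers counterparts whenever the checkpoint contains them.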
+        if any("guidance_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("img_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
+        if any("txt_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
+        if any("time_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("vector_in" in k for k in sds_sd):
+            assign_remaining_weights(
+                [
+                    (
+                        "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
+                        "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                    (
+                        "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
+                        "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
+                        None,
+                    ),
+                ],
+                sds_sd,
+            )
+
+        if any("final_layer" in k for k in sds_sd):
+            # Notice the swap in processing for "final_layer".
+            assign_remaining_weights(
+                [
+                    (
+                        "norm_out.linear.{lora_key}.weight",
+                        "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
+                        swap_scale_shift,
+                    ),
+                    ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
+                ],
+                sds_sd,
+            )
+
         remaining_keys = list(sds_sd.keys())
         te_state_dict = {}
         if remaining_keys:
-            if not all(k.startswith("lora_te") for k in remaining_keys):
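+            # Some trainers (e.g. kohya-ss' Flux scripts) prefix the CLIP text encoder keys with "lora_te1" instead of "lora_te".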
+            if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
                 raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
             for key in remaining_keys:
                 if not key.endswith("lora_down.weight"):
@@ -680,10 +784,98 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
     if has_peft_state_dict:
         state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
         return state_dict
+
     # Another weird one.
     has_mixture = any(
         k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
     )
+
+    # ComfyUI.
+    if not has_mixture:
+        state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
+        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
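+        # Normalize ComfyUI-style prefixes to the kohya-style ones the rest of this converter expects.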
+
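+        # The unsupported key families below are filtered out either way; the log message only notes
+        # whether their params were all zeros (harmless to drop) or carried actual values.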
+        has_position_embedding = any("position_embedding" in k for k in state_dict)
+        if has_position_embedding:
+            zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
+            if zero_status_pe:
+                logger.info(
+                    "The `position_embedding` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them from the current state dict to make loading possible."
+                )
+
+            else:
+                logger.info(
+                    "The state dict has `position_embedding` LoRA params, which we currently do not support. "
+                    "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
+
+        has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+        if has_t5xxl:
+            zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
+            if zero_status_t5:
+                logger.info(
+                    "The `t5xxl` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them from the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out. "
+                    "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
+
+        has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
+        if has_diffb:
+            zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
+            if zero_status_diff_b:
+                logger.info(
+                    "The `diff_b` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them from the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "`diff_b` keys found in the state dict, which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
+
+        has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
+        if has_norm_diff:
+            zero_status_diff = state_dict_all_zero(state_dict, ".diff")
+            if zero_status_diff:
+                logger.info(
+                    "The `diff` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them from the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "Normalization diff keys found in the state dict, which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
+
+        limit_substrings = ["lora_down", "lora_up"]
+        if any("alpha" in k for k in state_dict):
+            limit_substrings.append("alpha")
+
+        state_dict = {
+            _custom_replace(k, limit_substrings): v
+            for k, v in state_dict.items()
+            if k.startswith(("lora_unet_", "lora_te_"))
+        }
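+        # `_custom_replace` flattens the dots before lora_down / lora_up / alpha into underscores, e.g.
+        # "lora_unet_double_blocks.0.img_attn.qkv.lora_down.weight" -> "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight".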
+
+        if any("text_projection" in k for k in state_dict):
+            logger.info(
+                "`text_projection` keys found in the `state_dict`, which are unexpected. "
+                "So, we will filter out those keys. Open an issue if this is a problem - "
+                "https://github.com/huggingface/diffusers/issues/new."
+            )
+            state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
+
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)
 
@@ -798,6 +990,26 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
     return new_state_dict
 
 
+def _custom_replace(key: str, substrings: List[str]) -> str:
+    # Replaces the "."s with "_"s up to the first occurrence of any of the `substrings`.
+    # Example:
+    # lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
+    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
+
+    match = re.search(pattern, key)
+    if match:
+        start_sub = match.start()
+        if start_sub > 0 and key[start_sub - 1] == ".":
+            boundary = start_sub - 1
+        else:
+            boundary = start_sub
+        left = key[:boundary].replace(".", "_")
+        right = key[boundary:]
+        return left + right
+    else:
+        return key.replace(".", "_")
+
+
 def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     converted_state_dict = {}
     original_state_dict_keys = list(original_state_dict.keys())
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     inner_dim = 3072
     mlp_ratio = 4.0
 
-    def swap_scale_shift(weight):
-        shift, scale = weight.chunk(2, dim=0)
-        new_weight = torch.cat([scale, shift], dim=0)
-        return new_weight
-
     for lora_key in ["lora_A", "lora_B"]:
         ## time_text_embed.timestep_embedder <- time_in
         converted_state_dict[