2929 convert_unet_state_dict_to_peft ,
3030 delete_adapter_layers ,
3131 get_adapter_name ,
32- get_peft_kwargs ,
3332 is_peft_available ,
3433 is_peft_version ,
3534 logging ,
3635 set_adapter_layers ,
3736 set_weights_and_activate_adapters ,
3837)
38+ from ..utils .peft_utils import _create_lora_config , _maybe_warn_for_unhandled_keys
3939from .lora_base import _fetch_state_dict , _func_optionally_disable_offloading
4040from .unet_loader_utils import _maybe_expand_lora_scales
4141
6464}
6565
6666
67- def _maybe_raise_error_for_ambiguity (config ):
68- rank_pattern = config ["rank_pattern" ].copy ()
69- target_modules = config ["target_modules" ]
70-
71- for key in list (rank_pattern .keys ()):
72- # try to detect ambiguity
73- # `target_modules` can also be a str, in which case this loop would loop
74- # over the chars of the str. The technically correct way to match LoRA keys
75- # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
76- # But this cuts it for now.
77- exact_matches = [mod for mod in target_modules if mod == key ]
78- substring_matches = [mod for mod in target_modules if key in mod and mod != key ]
79-
80- if exact_matches and substring_matches :
81- if is_peft_version ("<" , "0.14.1" ):
82- raise ValueError (
83- "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
84- )
85-
86-
8767class PeftAdapterMixin :
8868 """
8969 A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -191,7 +171,7 @@ def load_lora_adapter(
191171 LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
192172 initialize `LoraConfig`.
193173 """
194- from peft import LoraConfig , inject_adapter_in_model , set_peft_model_state_dict
174+ from peft import inject_adapter_in_model , set_peft_model_state_dict
195175 from peft .tuners .tuners_utils import BaseTunerLayer
196176
197177 cache_dir = kwargs .pop ("cache_dir" , None )
@@ -216,7 +196,6 @@ def load_lora_adapter(
216196 )
217197
218198 user_agent = {"file_type" : "attn_procs_weights" , "framework" : "pytorch" }
219-
220199 state_dict , metadata = _fetch_state_dict (
221200 pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict ,
222201 weight_name = weight_name ,
@@ -275,38 +254,8 @@ def load_lora_adapter(
275254 k .removeprefix (f"{ prefix } ." ): v for k , v in network_alphas .items () if k in alpha_keys
276255 }
277256
278- if metadata is not None :
279- lora_config_kwargs = metadata
280- else :
281- lora_config_kwargs = get_peft_kwargs (
282- rank , network_alpha_dict = network_alphas , peft_state_dict = state_dict
283- )
284- _maybe_raise_error_for_ambiguity (lora_config_kwargs )
285-
286- if "use_dora" in lora_config_kwargs :
287- if lora_config_kwargs ["use_dora" ]:
288- if is_peft_version ("<" , "0.9.0" ):
289- raise ValueError (
290- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
291- )
292- else :
293- if is_peft_version ("<" , "0.9.0" ):
294- lora_config_kwargs .pop ("use_dora" )
295-
296- if "lora_bias" in lora_config_kwargs :
297- if lora_config_kwargs ["lora_bias" ]:
298- if is_peft_version ("<=" , "0.13.2" ):
299- raise ValueError (
300- "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
301- )
302- else :
303- if is_peft_version ("<=" , "0.13.2" ):
304- lora_config_kwargs .pop ("lora_bias" )
305-
306- try :
307- lora_config = LoraConfig (** lora_config_kwargs )
308- except TypeError as e :
309- raise TypeError ("`LoraConfig` class could not be instantiated." ) from e
257+ # create LoraConfig
258+ lora_config = _create_lora_config (state_dict , network_alphas , metadata , rank )
310259
311260 # adapter_name
312261 if adapter_name is None :
@@ -317,9 +266,8 @@ def load_lora_adapter(
317266 # Now we remove any existing hooks to `_pipeline`.
318267
319268 # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
320- # otherwise loading LoRA weights will lead to an error
269+ # otherwise loading LoRA weights will lead to an error.
321270 is_model_cpu_offload , is_sequential_cpu_offload = self ._optionally_disable_offloading (_pipeline )
322-
323271 peft_kwargs = {}
324272 if is_peft_version (">=" , "0.13.1" ):
325273 peft_kwargs ["low_cpu_mem_usage" ] = low_cpu_mem_usage
@@ -403,30 +351,7 @@ def map_state_dict_for_hotswap(sd):
403351 logger .error (f"Loading { adapter_name } was unsuccessful with the following error: \n { e } " )
404352 raise
405353
406- warn_msg = ""
407- if incompatible_keys is not None :
408- # Check only for unexpected keys.
409- unexpected_keys = getattr (incompatible_keys , "unexpected_keys" , None )
410- if unexpected_keys :
411- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k ]
412- if lora_unexpected_keys :
413- warn_msg = (
414- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
415- f" { ', ' .join (lora_unexpected_keys )} . "
416- )
417-
418- # Filter missing keys specific to the current adapter.
419- missing_keys = getattr (incompatible_keys , "missing_keys" , None )
420- if missing_keys :
421- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k ]
422- if lora_missing_keys :
423- warn_msg += (
424- f"Loading adapter weights from state_dict led to missing keys in the model:"
425- f" { ', ' .join (lora_missing_keys )} ."
426- )
427-
428- if warn_msg :
429- logger .warning (warn_msg )
354+ _maybe_warn_for_unhandled_keys (incompatible_keys , adapter_name )
430355
431356 # Offload back.
432357 if is_model_cpu_offload :
@@ -436,10 +361,11 @@ def map_state_dict_for_hotswap(sd):
436361 # Unsafe code />
437362
438363 if prefix is not None and not state_dict :
364+ model_class_name = self .__class__ .__name__
439365 logger .warning (
440- f"No LoRA keys associated to { self . __class__ . __name__ } found with the { prefix = } . "
366+ f"No LoRA keys associated to { model_class_name } found with the { prefix = } . "
441367 "This is safe to ignore if LoRA state dict didn't originally have any "
442- f"{ self . __class__ . __name__ } related params. You can also try specifying `prefix=None` "
368+ f"{ model_class_name } related params. You can also try specifying `prefix=None` "
443369 "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
444370 "https://github.com/huggingface/diffusers/issues/new"
445371 )
0 commit comments