Skip to content

Commit d1db4f8

Browse files
authored
[LoRA] fix flux lora loader when return_metadata is true for non-diffusers (huggingface#11716)
* fix flux lora loader when return_metadata is true for non-diffusers
* remove annotation
1 parent 8adc600 commit d1db4f8

File tree

3 files changed

+45
-12
lines changed

3 files changed

+45
-12
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,18 +2031,36 @@ def lora_state_dict(
20312031
if is_kohya:
20322032
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
20332033
# Kohya already takes care of scaling the LoRA parameters with alpha.
2034-
return (state_dict, None) if return_alphas else state_dict
2034+
return cls._prepare_outputs(
2035+
state_dict,
2036+
metadata=metadata,
2037+
alphas=None,
2038+
return_alphas=return_alphas,
2039+
return_metadata=return_lora_metadata,
2040+
)
20352041

20362042
is_xlabs = any("processor" in k for k in state_dict)
20372043
if is_xlabs:
20382044
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
20392045
# xlabs doesn't use `alpha`.
2040-
return (state_dict, None) if return_alphas else state_dict
2046+
return cls._prepare_outputs(
2047+
state_dict,
2048+
metadata=metadata,
2049+
alphas=None,
2050+
return_alphas=return_alphas,
2051+
return_metadata=return_lora_metadata,
2052+
)
20412053

20422054
is_bfl_control = any("query_norm.scale" in k for k in state_dict)
20432055
if is_bfl_control:
20442056
state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
2045-
return (state_dict, None) if return_alphas else state_dict
2057+
return cls._prepare_outputs(
2058+
state_dict,
2059+
metadata=metadata,
2060+
alphas=None,
2061+
return_alphas=return_alphas,
2062+
return_metadata=return_lora_metadata,
2063+
)
20462064

20472065
# For state dicts like
20482066
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
@@ -2061,12 +2079,13 @@ def lora_state_dict(
20612079
)
20622080

20632081
if return_alphas or return_lora_metadata:
2064-
outputs = [state_dict]
2065-
if return_alphas:
2066-
outputs.append(network_alphas)
2067-
if return_lora_metadata:
2068-
outputs.append(metadata)
2069-
return tuple(outputs)
2082+
return cls._prepare_outputs(
2083+
state_dict,
2084+
metadata=metadata,
2085+
alphas=network_alphas,
2086+
return_alphas=return_alphas,
2087+
return_metadata=return_lora_metadata,
2088+
)
20702089
else:
20712090
return state_dict
20722091

@@ -2785,6 +2804,15 @@ def _get_weight_shape(weight: torch.Tensor):
27852804

27862805
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
27872806

2807+
@staticmethod
2808+
def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
2809+
outputs = [state_dict]
2810+
if return_alphas:
2811+
outputs.append(alphas)
2812+
if return_metadata:
2813+
outputs.append(metadata)
2814+
return tuple(outputs) if (return_alphas or return_metadata) else state_dict
2815+
27882816

27892817
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
27902818
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.

src/diffusers/loaders/peft.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,9 @@ def load_lora_adapter(
187187
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
188188
limitations to this technique, which are documented here:
189189
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
190-
metadata: TODO
190+
metadata:
191+
LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
192+
initialize `LoraConfig`.
191193
"""
192194
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
193195
from peft.tuners.tuners_utils import BaseTunerLayer

src/diffusers/utils/state_dict_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,5 +359,8 @@ def _load_sft_state_dict_metadata(model_file: str):
359359
metadata = f.metadata() or {}
360360

361361
metadata.pop("format", None)
362-
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
363-
return json.loads(raw) if raw else None
362+
if metadata:
363+
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
364+
return json.loads(raw) if raw else None
365+
else:
366+
return None

0 commit comments

Comments (0)