
Commit be8aff7

Merge branch 'cogvideox1.1-5b' of github.com:zRzRzRzRzRzRzR/diffusers into cogvideox1.1-5b
2 parents be80dbf + 5e96cae commit be8aff7

File tree

11 files changed: +351 −44 lines

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 34 additions & 2 deletions

@@ -39,7 +39,7 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -59,12 +59,13 @@
 )
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
+    convert_unet_state_dict_to_peft,
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -1319,6 +1320,37 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_])
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models)
         lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
         StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
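For context, the core move in this hook is splitting one combined LoRA state dict (keys prefixed "unet." / "text_encoder.") into per-component dicts before handing them to PEFT. A minimal sketch of that prefix split with placeholder string values; in the script the dict holds tensors and comes from StableDiffusionPipeline.lora_state_dict(input_dir):

# Toy stand-in for the combined LoRA state dict loaded from a checkpoint dir.
lora_state_dict = {
    "unet.down_blocks.0.attn.to_q.lora_A.weight": "tensor-A",
    "unet.down_blocks.0.attn.to_q.lora_B.weight": "tensor-B",
    "text_encoder.text_model.encoder.layers.0.q_proj.lora_A.weight": "tensor-C",
}

# Strip the "unet." prefix so keys match bare UNet module names, mirroring
# the hook's dict comprehension before convert_unet_state_dict_to_peft().
unet_state_dict = {
    k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")
}

print(sorted(unet_state_dict))
# ['down_blocks.0.attn.to_q.lora_A.weight', 'down_blocks.0.attn.to_q.lora_B.weight']

The remaining "text_encoder." keys take the other path, via _set_state_dict_into_text_encoder with the matching prefix.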

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 22 additions & 14 deletions

@@ -67,6 +67,7 @@
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
+    is_peft_version,
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -1183,26 +1184,33 @@ def main(args):
         text_encoder_one.gradient_checkpointing_enable()
         text_encoder_two.gradient_checkpointing_enable()
 
+    def get_lora_config(rank, use_dora, target_modules):
+        base_config = {
+            "r": rank,
+            "lora_alpha": rank,
+            "init_lora_weights": "gaussian",
+            "target_modules": target_modules,
+        }
+        if use_dora:
+            if is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            else:
+                base_config["use_dora"] = True
+
+        return LoraConfig(**base_config)
+
     # now we will add new LoRA weights to the attention layers
-    unet_lora_config = LoraConfig(
-        r=args.rank,
-        use_dora=args.use_dora,
-        lora_alpha=args.rank,
-        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
-    )
+    unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+    unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
     unet.add_adapter(unet_lora_config)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
-            use_dora=args.use_dora,
-            lora_alpha=args.rank,
-            init_lora_weights="gaussian",
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
-        )
+        text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+        text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
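A stand-alone sketch of how the new helper deduplicates the two LoraConfig call sites; the peft version guard is elided here (the script raises for peft < 0.9.0 when DoRA is requested), and the rank value is an arbitrary example:

from peft import LoraConfig

def get_lora_config(rank, use_dora, target_modules):
    # One factory for both the UNet and the text encoders; alpha tied to rank.
    base_config = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }
    if use_dora:
        base_config["use_dora"] = True  # needs peft >= 0.9.0; the script checks via is_peft_version
    return LoraConfig(**base_config)

unet_cfg = get_lora_config(rank=4, use_dora=False, target_modules=["to_k", "to_q", "to_v", "to_out.0"])
text_cfg = get_lora_config(rank=4, use_dora=False, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
print(unet_cfg.r, unet_cfg.target_modules)

The design point is that the DoRA/version check now lives in one place instead of being repeated per adapter.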

src/diffusers/loaders/ip_adapter.py

Lines changed: 8 additions & 6 deletions

@@ -33,16 +33,14 @@
 
 
 if is_transformers_available():
-    from transformers import (
-        CLIPImageProcessor,
-        CLIPVisionModelWithProjection,
-    )
+    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 from ..models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
+    IPAdapterXFormersAttnProcessor,
 )
 
 logger = logging.get_logger(__name__)
@@ -284,7 +282,9 @@ def set_ip_adapter_scale(self, scale):
         scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
 
         for attn_name, attn_processor in unet.attn_processors.items():
-            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+            if isinstance(
+                attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+            ):
                 if len(scale_configs) != len(attn_processor.scale):
                     raise ValueError(
                         f"Cannot assign {len(scale_configs)} scale_configs to "
@@ -342,7 +342,9 @@ def unload_ip_adapter(self):
             )
             attn_procs[name] = (
                 attn_processor_class
-                if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+                if isinstance(
+                    value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+                )
                 else value.__class__()
             )
         self.unet.set_attn_processor(attn_procs)
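Why the isinstance tuple matters here: only processors listed in it are treated as IP-Adapter processors, so before this change an xformers-backed IP-Adapter processor fell through to the generic path in both set_ip_adapter_scale and unload_ip_adapter. A toy illustration with stand-in classes (the real ones live in diffusers.models.attention_processor):

# Stand-ins for the real attention processor classes.
class IPAdapterAttnProcessor: ...
class IPAdapterAttnProcessor2_0: ...
class IPAdapterXFormersAttnProcessor: ...
class AttnProcessor2_0: ...  # plain processor, should be left alone

IP_ADAPTER_PROCS = (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)

procs = {"down.attn1": AttnProcessor2_0(), "down.attn2": IPAdapterXFormersAttnProcessor()}
for name, proc in procs.items():
    if isinstance(proc, IP_ADAPTER_PROCS):
        print(f"{name}: IP-Adapter processor, scale would be updated")
    else:
        print(f"{name}: plain processor, skipped")

Dropping IPAdapterXFormersAttnProcessor from the tuple would route "down.attn2" through the wrong branch, which is exactly the bug this diff closes.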

src/diffusers/loaders/unet.py

Lines changed: 9 additions & 4 deletions

@@ -765,6 +765,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
         from ..models.attention_processor import (
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
+            IPAdapterXFormersAttnProcessor,
         )
 
         if low_cpu_mem_usage:
@@ -804,11 +805,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
             if cross_attention_dim is None or "motion_modules" in name:
                 attn_processor_class = self.attn_processors[name].__class__
                 attn_procs[name] = attn_processor_class()
-
             else:
-                attn_processor_class = (
-                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
-                )
+                if "XFormers" in str(self.attn_processors[name].__class__):
+                    attn_processor_class = IPAdapterXFormersAttnProcessor
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
                 num_image_text_embeds = []
                 for state_dict in state_dicts:
                     if "proj.weight" in state_dict["image_proj"]:
