@@ -200,7 +200,8 @@ def save_model_card(
         "diffusers",
         "diffusers-training",
         lora,
-        "template:sd-lora" "stable-diffusion",
+        "template:sd-lora",
+        "stable-diffusion",
         "stable-diffusion-diffusers",
     ]
     model_card = populate_model_card(model_card, tags=tags)
@@ -724,9 +725,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
         idx = 0
         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
             assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
-            assert all(
-                isinstance(tok, str) for tok in inserting_toks
-            ), "All elements in inserting_toks should be strings."
+            assert all(isinstance(tok, str) for tok in inserting_toks), (
+                "All elements in inserting_toks should be strings."
+            )
 
             self.inserting_toks = inserting_toks
             special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
@@ -746,9 +747,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
                 .to(dtype=self.dtype)
                 * std_token_embedding
             )
-            self.embeddings_settings[
-                f"original_embeddings_{idx}"
-            ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+            self.embeddings_settings[f"original_embeddings_{idx}"] = (
+                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+            )
             self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
 
             inu = torch.ones((len(tokenizer),), dtype=torch.bool)
@@ -1322,7 +1323,7 @@ def load_model_hook(models, input_dir):
 
         lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
 
-        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
         unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
         incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
         if incompatible_keys is not None: