@@ -4813,22 +4813,43 @@ def _maybe_expand_t2v_lora_for_i2v(
48134813 if transformer .config .image_dim is None :
48144814 return state_dict
48154815
4816+ target_device = transformer .device
4817+
48164818 if any (k .startswith ("transformer.blocks." ) for k in state_dict ):
4817- num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in state_dict })
4819+ num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in state_dict if "blocks." in k })
48184820 is_i2v_lora = any ("add_k_proj" in k for k in state_dict ) and any ("add_v_proj" in k for k in state_dict )
4821+ has_bias = any (".lora_B.bias" in k for k in state_dict )
48194822
48204823 if is_i2v_lora :
48214824 return state_dict
48224825
48234826 for i in range (num_blocks ):
48244827 for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
4828+ # These keys should exist if the block `i` was part of the T2V LoRA.
4829+ ref_key_lora_A = f"transformer.blocks.{ i } .attn2.to_k.lora_A.weight"
4830+ ref_key_lora_B = f"transformer.blocks.{ i } .attn2.to_k.lora_B.weight"
4831+
4832+ if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict :
4833+ continue
4834+
48254835 state_dict [f"transformer.blocks.{ i } .attn2.{ c } .lora_A.weight" ] = torch .zeros_like (
4826- state_dict [f"transformer.blocks.{ i } .attn2.to_k.lora_A.weight" ]
4836+ state_dict [f"transformer.blocks.{ i } .attn2.to_k.lora_A.weight" ], device = target_device
48274837 )
48284838 state_dict [f"transformer.blocks.{ i } .attn2.{ c } .lora_B.weight" ] = torch .zeros_like (
4829- state_dict [f"transformer.blocks.{ i } .attn2.to_k.lora_B.weight" ]
4839+ state_dict [f"transformer.blocks.{ i } .attn2.to_k.lora_B.weight" ], device = target_device
48304840 )
48314841
4842+ # If the original LoRA had biases (indicated by has_bias)
4843+ # AND the specific reference bias key exists for this block.
4844+
4845+ ref_key_lora_B_bias = f"transformer.blocks.{ i } .attn2.to_k.lora_B.bias"
4846+ if has_bias and ref_key_lora_B_bias in state_dict :
4847+ ref_lora_B_bias_tensor = state_dict [ref_key_lora_B_bias ]
4848+ state_dict [f"transformer.blocks.{ i } .attn2.{ c } .lora_B.bias" ] = torch .zeros_like (
4849+ ref_lora_B_bias_tensor ,
4850+ device = target_device ,
4851+ )
4852+
48324853 return state_dict
48334854
48344855 def load_lora_weights (
0 commit comments