@@ -1825,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18251825 is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
18261826 lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
18271827 lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
1828+ has_time_projection_weight = any(
1829+ k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
1830+ )
18281831
1829- diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
1830- if diff_keys:
1831- for diff_k in diff_keys:
1832- param = original_state_dict[diff_k]
1833- # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are up to 1e-3,
1834- # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
1835- # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
1836- # is okay to ignore because they do not affect the model output in a significant manner.
1837- threshold = 1.6e-2
1838- absdiff = param.abs().max() - param.abs().min()
1839- all_zero = torch.all(param == 0).item()
1840- all_absdiff_lower_than_threshold = absdiff < threshold
1841- if all_zero or all_absdiff_lower_than_threshold:
1842- logger.debug(
1843- f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
1844- )
1845- original_state_dict.pop(diff_k)
1832+ for key in list(original_state_dict.keys()):
1833+ if key.endswith((".diff", ".diff_b")) and "norm" in key:
1834+ # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
1835+ # in the future if needed and they are not zeroed.
1836+ original_state_dict.pop(key)
1837+ logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
1838+
1839+ if "time_projection" in key and not has_time_projection_weight:
1840+ # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
1841+ # our lora config adds the time proj lora layers, but we don't have the weights for them.
1842+ # CausVid lora has the weight keys and the bias keys.
1843+ original_state_dict.pop(key)
18461844
18471845 # For the `diff_b` keys, we treat them as lora_bias.
18481846 # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
0 commit comments