@@ -22,6 +22,8 @@
     USE_PEFT_BACKEND,
     deprecate,
     get_submodule_by_name,
+    is_bitsandbytes_available,
+    is_gguf_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -68,6 +70,49 @@
 _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
 
 
+def _maybe_dequantize_weight_for_expanded_lora(model, module):
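+    # Returns a dequantized copy of `module`'s weight so that
+    # `_maybe_expand_transformer_param_shape_or_error_` can expand it into a larger
+    # `nn.Linear`; non-quantized weights pass through unchanged.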
+    if is_bitsandbytes_available():
+        from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+    if is_gguf_available():
+        from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
+    if is_gguf_quantized and not is_gguf_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+        )
+
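+    # Dequantization runs on GPU; track CPU-resident weights so the dequantized
+    # result can be moved back to the original device afterwards.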
+    weight_on_cpu = False
+    if not module.weight.is_cuda:
+        weight_on_cpu = True
+
+    if is_bnb_4bit_quantized:
+        module_weight = dequantize_bnb_weight(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+            state=module.weight.quant_state,
+            dtype=model.dtype,
+        ).data
+    elif is_gguf_quantized:
+        module_weight = dequantize_gguf_tensor(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+        )
+        module_weight = module_weight.to(model.dtype)
+    else:
+        module_weight = module.weight.data
+
+    if weight_on_cpu:
+        module_weight = module_weight.cpu()
+
+    return module_weight
+
+
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -2267,6 +2312,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         overwritten_params = {}
 
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
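+        # Models loaded with quantization expose an `hf_quantizer` attribute.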
+        is_quantized = hasattr(transformer, "hf_quantizer")
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2291,9 +2337,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 if tuple(module_weight_shape) == (out_features, in_features):
                     continue
 
-                # TODO (sayakpaul): We still need to consider if the module we're expanding is
-                # quantized and handle it accordingly if that is the case.
-                module_out_features, module_in_features = module_weight.shape
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
@@ -2316,6 +2360,10 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 parent_module_name, _, current_module_name = name.rpartition(".")
                 parent_module = transformer.get_submodule(parent_module_name)
 
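+                # The quantized storage cannot be copied directly into the expanded
+                # layer below, so work on a dequantized copy of the original weight.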
+                if is_quantized:
+                    module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
+
+                # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                 with torch.device("meta"):
                     expanded_module = torch.nn.Linear(
                         in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 new_weight = torch.zeros_like(
                     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                 )
-                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
                 tmp_state_dict = {"weight": new_weight}
                 if module_bias is not None:
@@ -2416,7 +2464,12 @@ def _calculate_module_shape(
         base_weight_param_name: str = None,
     ) -> "torch.Size":
         def _get_weight_shape(weight: torch.Tensor):
-            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
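+            # `Params4bit` stores packed 4-bit data, so the logical 2D shape must be
+            # read from `quant_state`; `GGUFParameter` records it in `quant_shape`.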
+            if weight.__class__.__name__ == "Params4bit":
+                return weight.quant_state.shape
+            elif weight.__class__.__name__ == "GGUFParameter":
+                return weight.quant_shape
+            else:
+                return weight.shape
 
         if base_module is not None:
             return _get_weight_shape(base_module.weight)