@@ -96,11 +96,29 @@ class Config(Configurable.Config):
9696 Requires base model to be loaded from HF/initial_load_path on resume.
9797 Set to False to save full model weights for debugging without pretrained base."""
9898
quantize_base: str = ""
"""Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA).
NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision."""

nf4_scaler_block_size: int = 128
"""Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
The default torchao value (256) may be too large for sharded tensors."""
106+
def __init__(self, config: Config, **kwargs):
    """Capture LoRA settings from *config*, validate them, and log the result.

    Raises:
        ValueError: if ``config.quantize_base`` names an unsupported scheme
            (anything other than "" or "nf4").
    """
    self.rank = config.rank
    self.alpha = config.alpha
    self.save_adapter_only = config.save_adapter_only
    self.quantize_base = config.quantize_base
    self.nf4_scaler_block_size = config.nf4_scaler_block_size

    # Only NF4 (QLoRA) is implemented; fail fast on anything else.
    if self.quantize_base and self.quantize_base != "nf4":
        raise ValueError(
            f"Unsupported quantize_base value: '{self.quantize_base}'. "
            "Supported values: '' (none), 'nf4'."
        )

    quant_note = f", quantize_base={self.quantize_base}" if self.quantize_base else ""
    logger.info(
        f"LoRA training active with rank={self.rank}, alpha={self.alpha}" + quant_note
    )
104122
105123 def convert (self , model : nn .Module ) -> None :
106124 model .requires_grad_ (False )
@@ -134,6 +152,58 @@ def converter_key_filter(key: str) -> bool:
134152 object .__setattr__ (module , "converter_key_filter" , converter_key_filter )
135153 object .__setattr__ (module , "save_converter_keys_only" , self .save_adapter_only )
136154
155+ # Register a one-shot forward pre-hook to quantize base weights after
156+ # checkpoint load but before the first forward pass (QLoRA).
157+ if self .quantize_base == "nf4" :
158+ from torch .distributed .tensor import DTensor
159+ from torchao .dtypes .nf4tensor import to_nf4
160+
161+ lora_classes = tuple (_lora_class_cache .values ())
162+ nf4_scaler_block_size = self .nf4_scaler_block_size
163+
def _to_nf4_tensor(weight: torch.Tensor) -> torch.Tensor:
    """Convert a base weight to NF4, handling both plain tensors and DTensors.

    For a DTensor, the local shard is quantized and re-wrapped with the
    original device mesh and placements so sharding is preserved.

    Raises:
        ValueError: if the number of NF4 scalers is not divisible by
            ``nf4_scaler_block_size`` (closure from the enclosing config).
    """
    nf4_block_size = 64  # NF4 default block size
    is_dtensor = isinstance(weight, DTensor)
    local_weight = weight.to_local() if is_dtensor else weight

    num_scalers = local_weight.numel() // nf4_block_size
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let an invalid block size reach torchao
    # and fail there with a far less actionable error.
    if num_scalers % nf4_scaler_block_size != 0:
        raise ValueError(
            f"NF4 quantization failed: num_scalers ({num_scalers}) is not "
            f"divisible by nf4_scaler_block_size ({nf4_scaler_block_size}). "
            f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config "
            f"(e.g., 64, 32, or 1)."
        )

    nf4_local = to_nf4(
        local_weight, scaler_block_size=nf4_scaler_block_size
    )

    if is_dtensor:
        return DTensor.from_local(
            nf4_local, weight.device_mesh, weight.placements
        )
    return nf4_local
187+
def _quantize_hook(
    mod: nn.Module, args: Any, handle: torch.utils.hooks.RemovableHandle
) -> None:
    """One-shot pre-forward hook: NF4-quantize every LoRA layer's base
    weight, then unregister itself via *handle* so it never runs again."""
    lora_layers = [sub for sub in mod.modules() if isinstance(sub, lora_classes)]
    for layer in lora_layers:
        quantized = _to_nf4_tensor(layer.weight.data)
        layer.weight = nn.Parameter(quantized, requires_grad=False)
    logger.info("QLoRA: quantized base weights to NF4")
    handle.remove()
198+
199+ # Use a list to allow the closure to reference the handle before it exists
200+ handle_ref : list [torch .utils .hooks .RemovableHandle ] = []
201+ handle_ref .append (
202+ module .register_forward_pre_hook (
203+ lambda mod , args : _quantize_hook (mod , args , handle_ref [0 ])
204+ )
205+ )
206+
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
    """Intentional no-op: this converter requires no per-step maintenance
    after the optimizer step. LoRA adapters update through the optimizer
    normally; presumably the base weights need no post-step fix-up — NOTE(review):
    confirm nothing here must re-sync quantized base weights."""
    pass
139209
0 commit comments