@@ -95,11 +95,29 @@ class Config(Configurable.Config):
9595 Requires base model to be loaded from HF/initial_load_path on resume.
9696 Set to False to save full model weights for debugging without pretrained base."""
9797
98+ quantize_base : str = ""
99+ """Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA).
100+ NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision."""
101+
102+ nf4_scaler_block_size : int = 128
103+ """Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
104+ The default torchao value (256) may be too large for sharded tensors."""
105+
98106 def __init__ (self , config : Config , ** kwargs ):
99107 self .rank = config .rank
100108 self .alpha = config .alpha
101109 self .save_adapter_only = config .save_adapter_only
102- logger .info (f"LoRA training active with rank={ self .rank } , alpha={ self .alpha } " )
110+ self .quantize_base = config .quantize_base
111+ self .nf4_scaler_block_size = config .nf4_scaler_block_size
112+ if self .quantize_base and self .quantize_base != "nf4" :
113+ raise ValueError (
114+ f"Unsupported quantize_base value: '{ self .quantize_base } '. "
115+ "Supported values: '' (none), 'nf4'."
116+ )
117+ logger .info (
118+ f"LoRA training active with rank={ self .rank } , alpha={ self .alpha } "
119+ + (f", quantize_base={ self .quantize_base } " if self .quantize_base else "" )
120+ )
103121
104122 def convert (self , model : nn .Module ) -> None :
105123 model .requires_grad_ (False )
@@ -133,5 +151,64 @@ def converter_key_filter(key: str) -> bool:
133151 object .__setattr__ (module , "converter_key_filter" , converter_key_filter )
134152 object .__setattr__ (module , "save_converter_keys_only" , self .save_adapter_only )
135153
154+ # Register a one-shot forward pre-hook to quantize base weights after
155+ # checkpoint load but before the first forward pass (QLoRA).
156+ if self .quantize_base == "nf4" :
157+ from torch .distributed .tensor import DTensor
158+
159+ try :
160+ from torchao .dtypes .nf4tensor import to_nf4
161+ except ImportError :
162+ raise ImportError (
163+ "QLoRA requires torchao. Install with: pip install torchao"
164+ )
165+
166+ lora_classes = tuple (_lora_class_cache .values ())
167+ nf4_scaler_block_size = self .nf4_scaler_block_size
168+
169+ def _to_nf4_tensor (weight : torch .Tensor ) -> torch .Tensor :
170+ """Convert weight to NF4, handling both regular tensors and DTensors."""
171+ nf4_block_size = 64 # NF4 default block size
172+ is_dtensor = isinstance (weight , DTensor )
173+ local_weight = weight .to_local () if is_dtensor else weight
174+
175+ num_scalers = local_weight .numel () // nf4_block_size
176+ if num_scalers % nf4_scaler_block_size != 0 :
177+ raise ValueError (
178+ f"NF4 quantization failed: num_scalers ({ num_scalers } ) is not "
179+ f"divisible by nf4_scaler_block_size ({ nf4_scaler_block_size } ). "
180+ f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config "
181+ f"(e.g., 64, 32, or 1)."
182+ )
183+
184+ nf4_local = to_nf4 (
185+ local_weight , scaler_block_size = nf4_scaler_block_size
186+ )
187+
188+ if is_dtensor :
189+ return DTensor .from_local (
190+ nf4_local , weight .device_mesh , weight .placements
191+ )
192+ return nf4_local
193+
194+ def _quantize_hook (
195+ mod : nn .Module , args : Any , handle : torch .utils .hooks .RemovableHandle
196+ ) -> None :
197+ for sub in mod .modules ():
198+ if isinstance (sub , lora_classes ):
199+ sub .weight = nn .Parameter (
200+ _to_nf4_tensor (sub .weight .data ), requires_grad = False
201+ )
202+ logger .info ("QLoRA: quantized base weights to NF4" )
203+ handle .remove ()
204+
205+ # Use a list to allow the closure to reference the handle before it exists
206+ handle_ref : list [torch .utils .hooks .RemovableHandle ] = []
207+ handle_ref .append (
208+ module .register_forward_pre_hook (
209+ lambda mod , args : _quantize_hook (mod , args , handle_ref [0 ])
210+ )
211+ )
212+
136213 def post_optimizer_hook (self , model : nn .Module | list [nn .Module ]) -> None :
137214 pass
0 commit comments