@@ -93,11 +93,29 @@ class Config(Configurable.Config):
9393 Requires base model to be loaded from HF/initial_load_path on resume.
9494 Set to False to save full model weights for debugging without pretrained base."""
9595
# --- QLoRA configuration knobs ---------------------------------------------
quantize_base: str = ""
"""Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA).
NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision."""

nf4_scaler_block_size: int = 128
"""Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
The default torchao value (256) may be too large for sharded tensors."""
def __init__(self, config: "Config", **kwargs):
    """Capture LoRA/QLoRA settings from *config* and validate them.

    Args:
        config: Converter config providing ``rank``, ``alpha``,
            ``save_adapter_only``, ``quantize_base`` and
            ``nf4_scaler_block_size``.
        **kwargs: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If ``config.quantize_base`` is neither ``''`` nor ``'nf4'``.
    """
    self.rank = config.rank
    self.alpha = config.alpha
    self.save_adapter_only = config.save_adapter_only
    self.quantize_base = config.quantize_base
    self.nf4_scaler_block_size = config.nf4_scaler_block_size
    # Fail fast on unsupported quantization modes before any model surgery.
    if self.quantize_base and self.quantize_base != "nf4":
        raise ValueError(
            f"Unsupported quantize_base value: '{self.quantize_base}'. "
            "Supported values: '' (none), 'nf4'."
        )
    # Lazy %-style logging args: no string formatting work when INFO is off.
    quant_note = f", quantize_base={self.quantize_base}" if self.quantize_base else ""
    logger.info(
        "LoRA training active with rank=%s, alpha=%s%s",
        self.rank,
        self.alpha,
        quant_note,
    )
102120 def convert (self , model : nn .Module ) -> None :
103121 model .requires_grad_ (False )
@@ -108,5 +126,78 @@ def _replace_linears_with_lora(self, module: nn.Module) -> None:
108126 if isinstance (child , nn .Linear ):
109127 apply_lora (child , self .rank , self .alpha )
110128
# Expose a key filter and flag on the module so ModelWrapper can
# partition the state dict without knowing about LoRA internals.
def converter_key_filter(key: str) -> bool:
    """Return True if key was added by this converter (LoRA adapter weights)."""
    return ".lora_a." in key or ".lora_b." in key

# object.__setattr__ bypasses nn.Module's __setattr__ bookkeeping so these
# plain attributes are not registered as parameters/submodules.
object.__setattr__(module, "converter_key_filter", converter_key_filter)
object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)

# Register a one-shot forward pre-hook to quantize base weights after
# checkpoint load but before the first forward pass (QLoRA).
# TODO: Prototype — move to torchao as a proper QuantizationConverter.
# to_nf4 on local tensors loses DTensor grad info, fine here since
# base weights are frozen and only LoRA adapters receive gradients.
if self.quantize_base == "nf4":
    from torch.distributed.tensor import DTensor

    try:
        from torchao.dtypes.nf4tensor import to_nf4
    except ImportError as err:
        raise ImportError(
            "QLoRA requires torchao. Install with: pip install torchao"
        ) from err

    lora_classes = tuple(_lora_class_cache.values())
    nf4_scaler_block_size = self.nf4_scaler_block_size

    def _to_nf4_tensor(weight: torch.Tensor) -> torch.Tensor:
        """Convert weight to NF4, handling both regular tensors and DTensors."""
        nf4_block_size = 64  # NF4 default block size
        is_dtensor = isinstance(weight, DTensor)
        local_weight = weight.to_local() if is_dtensor else weight

        # Validate up front so users get an actionable message (pick a
        # smaller scaler block) instead of a cryptic torchao error.
        num_scalers = local_weight.numel() // nf4_block_size
        if num_scalers % nf4_scaler_block_size != 0:
            raise ValueError(
                f"NF4 quantization failed: num_scalers ({num_scalers}) is not "
                f"divisible by nf4_scaler_block_size ({nf4_scaler_block_size}). "
                f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config "
                f"(e.g., 64, 32, or 1)."
            )

        nf4_local = to_nf4(local_weight, scaler_block_size=nf4_scaler_block_size)

        if is_dtensor:
            return DTensor.from_local(
                nf4_local,  # pyrefly: ignore [bad-argument-type]
                weight.device_mesh,
                weight.placements,
            )
        return nf4_local  # pyrefly: ignore [bad-return]

    def _quantize_hook(mod: nn.Module, args: Any) -> None:
        """One-shot pre-forward hook: quantize frozen base weights, then detach itself."""
        for sub in mod.modules():
            if isinstance(sub, lora_classes):
                sub.weight = nn.Parameter(
                    _to_nf4_tensor(sub.weight.data), requires_grad=False
                )
        logger.info("QLoRA: quantized base weights to NF4")
        # `handle` is bound below, before the model's first forward can run,
        # so the closure resolves it by the time the hook fires. No list
        # indirection or wrapper lambda is needed.
        handle.remove()

    handle = module.register_forward_pre_hook(_quantize_hook)
201+
111202 def post_optimizer_hook (self , model : nn .Module | list [nn .Module ]) -> None :
112203 pass
0 commit comments