-
Notifications
You must be signed in to change notification settings - Fork 755
[3/N] Support qlora by nf4tensor #2487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/mori360/3/base
Are you sure you want to change the base?
Changes from all commits
f5e39ab
565a005
9b82e66
ad1834c
aa570e4
7a2b16b
2b8644f
11e32ae
039dd19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,11 +93,29 @@ class Config(Configurable.Config): | |
| Requires base model to be loaded from HF/initial_load_path on resume. | ||
| Set to False to save full model weights for debugging without pretrained base.""" | ||
|
|
||
| quantize_base: str = "" | ||
| """Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA). | ||
| NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision.""" | ||
|
|
||
| nf4_scaler_block_size: int = 128 | ||
| """Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs. | ||
| The default torchao value (256) may be too large for sharded tensors.""" | ||
|
|
||
| def __init__(self, config: Config, **kwargs): | ||
| self.rank = config.rank | ||
| self.alpha = config.alpha | ||
| self.save_adapter_only = config.save_adapter_only | ||
| logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}") | ||
| self.quantize_base = config.quantize_base | ||
| self.nf4_scaler_block_size = config.nf4_scaler_block_size | ||
| if self.quantize_base and self.quantize_base != "nf4": | ||
| raise ValueError( | ||
| f"Unsupported quantize_base value: '{self.quantize_base}'. " | ||
| "Supported values: '' (none), 'nf4'." | ||
| ) | ||
| logger.info( | ||
| f"LoRA training active with rank={self.rank}, alpha={self.alpha}" | ||
| + (f", quantize_base={self.quantize_base}" if self.quantize_base else "") | ||
| ) | ||
|
|
||
| def convert(self, model: nn.Module) -> None: | ||
| model.requires_grad_(False) | ||
|
|
@@ -108,5 +126,78 @@ def _replace_linears_with_lora(self, module: nn.Module) -> None: | |
| if isinstance(child, nn.Linear): | ||
| apply_lora(child, self.rank, self.alpha) | ||
|
|
||
| # Expose a key filter and flag on the module so ModelWrapper can | ||
| # partition the state dict without knowing about LoRA internals. | ||
| def converter_key_filter(key: str) -> bool: | ||
| """Return True if key was added by this converter (LoRA adapter weights).""" | ||
| return ".lora_a." in key or ".lora_b." in key | ||
|
|
||
| object.__setattr__(module, "converter_key_filter", converter_key_filter) | ||
| object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only) | ||
|
|
||
| # Register a one-shot forward pre-hook to quantize base weights after | ||
| # checkpoint load but before the first forward pass (QLoRA). | ||
| # TODO: Prototype — move to torchao as a proper QuantizationConverter. | ||
| # to_nf4 on local tensors loses DTensor grad info, fine here since | ||
| # base weights are frozen and only LoRA adapters receive gradients. | ||
| if self.quantize_base == "nf4": | ||
| from torch.distributed.tensor import DTensor | ||
|
|
||
| try: | ||
| from torchao.dtypes.nf4tensor import to_nf4 | ||
| except ImportError as err: | ||
| raise ImportError( | ||
| "QLoRA requires torchao. Install with: pip install torchao" | ||
| ) from err | ||
|
|
||
| lora_classes = tuple(_lora_class_cache.values()) | ||
| nf4_scaler_block_size = self.nf4_scaler_block_size | ||
|
|
||
| def _to_nf4_tensor(weight: torch.Tensor) -> torch.Tensor: | ||
| """Convert weight to NF4, handling both regular tensors and DTensors.""" | ||
| nf4_block_size = 64 # NF4 default block size | ||
| is_dtensor = isinstance(weight, DTensor) | ||
| local_weight = weight.to_local() if is_dtensor else weight | ||
|
|
||
| num_scalers = local_weight.numel() // nf4_block_size | ||
| if num_scalers % nf4_scaler_block_size != 0: | ||
| raise ValueError( | ||
| f"NF4 quantization failed: num_scalers ({num_scalers}) is not " | ||
| f"divisible by nf4_scaler_block_size ({nf4_scaler_block_size}). " | ||
| f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config " | ||
| f"(e.g., 64, 32, or 1)." | ||
| ) | ||
|
|
||
| nf4_local = to_nf4( | ||
| local_weight, scaler_block_size=nf4_scaler_block_size | ||
| ) | ||
|
|
||
| if is_dtensor: | ||
| return DTensor.from_local( | ||
| nf4_local, # pyrefly: ignore [bad-argument-type] | ||
| weight.device_mesh, | ||
| weight.placements, | ||
| ) | ||
| return nf4_local # pyrefly: ignore [bad-return] | ||
|
|
||
| def _quantize_hook( | ||
| mod: nn.Module, args: Any, handle: torch.utils.hooks.RemovableHandle | ||
| ) -> None: | ||
| for sub in mod.modules(): | ||
| if isinstance(sub, lora_classes): | ||
| sub.weight = nn.Parameter( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we have to do this conversion every single time the forward is called?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What's the result of an nf4 @ nf4 matmul?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. activation: bf16, weight: nf4 |
||
| _to_nf4_tensor(sub.weight.data), requires_grad=False | ||
| ) | ||
| logger.info("QLoRA: quantized base weights to NF4") | ||
| handle.remove() | ||
|
|
||
| # Use a list to allow the closure to reference the handle before it exists | ||
| handle_ref: list[torch.utils.hooks.RemovableHandle] = [] | ||
| handle_ref.append( | ||
| module.register_forward_pre_hook( | ||
| lambda mod, args: _quantize_hook(mod, args, handle_ref[0]) | ||
| ) | ||
| ) | ||
|
|
||
| def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None: | ||
| pass | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dtensor.to_local() should never be used without
grad_placements, otherwise we could lose necessary reductions on gradients. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it's a prototype to prove that we could run with NF4. We only convert the original weights to nf4, which are frozen (requires_grad=False), so losing gradients here should be fine. If we decide to support QLoRA later, we will cooperate with torchao to better support DTensor.