[3/N] Support qlora by nf4tensor #2487
Conversation
QLoRA (LoRA + NF4-quantized base weights): Add a quantize_base="nf4" option to LoRAConverter.Config that registers a one-shot forward_pre_hook to quantize base weights to NF4 after checkpoint load but before the first forward pass. Uses torchao.dtypes.nf4tensor.to_nf4 with DTensor support for distributed training. A configurable nf4_scaler_block_size (default 128) handles sharded tensor sizes. Reduces base weight memory ~4x while keeping LoRA adapters in full precision.
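The one-shot hook mechanism described above can be sketched as a self-removing forward pre-hook. This is a minimal illustration, not the PR's actual code: the quantization step is a stand-in for torchao's to_nf4 (so the mechanics are testable without torchao installed), and all names here are illustrative.

```python
import torch
import torch.nn as nn


def quantize_base_weights(module: nn.Module) -> None:
    # Stand-in for converting frozen base weights with torchao's to_nf4().
    # Here we only record which parameters would be converted, plus a call
    # counter, so the one-shot behavior can be verified.
    module._quant_calls = getattr(module, "_quant_calls", 0) + 1
    module._converted = [
        name for name, p in module.named_parameters() if not p.requires_grad
    ]


def register_one_shot_quantize(module: nn.Module):
    """Register a forward pre-hook that fires once, then removes itself."""
    handle = None

    def hook(mod, args):
        quantize_base_weights(mod)
        handle.remove()  # self-remove: runs only before the first forward

    handle = module.register_forward_pre_hook(hook)
    return handle
```

Because the handle is removed inside the hook, the conversion cost is paid exactly once, before the first forward pass, rather than on every call.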
| """Convert weight to NF4, handling both regular tensors and DTensors.""" | ||
| nf4_block_size = 64 # NF4 default block size | ||
| is_dtensor = isinstance(weight, DTensor) | ||
| local_weight = weight.to_local() if is_dtensor else weight |
dtensor.to_local() should never be used without grad_placements; otherwise we could lose necessary reductions on gradients.
Yeah, it's a prototype to prove that we can run with NF4. We only convert the original weights to NF4, and those are frozen (requires_grad=False), so losing gradient reductions here should be fine. If we decide to fully support QLoRA later, we will collaborate with torchao on better DTensor support.
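The configurable nf4_scaler_block_size exists because NF4 requires the tensor's element count to be divisible by the block size, and under sharding the local shard can be smaller than the default would allow. A hypothetical helper (an illustrative assumption, not the PR's actual selection logic) shows the constraint:

```python
def pick_scaler_block_size(numel: int, preferred: int = 128) -> int:
    """Largest power-of-two block size <= preferred that divides numel.

    NF4 quantization needs numel % block_size == 0; when a weight is
    sharded as a DTensor, the local shard's numel may force a smaller
    block size than the unsharded default.
    """
    size = preferred
    while size > 1 and numel % size != 0:
        size //= 2
    return size
```

For example, a full weight of 1024 elements can use the default 128, while a shard of 192 elements has to fall back to 64.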
) -> None:
    for sub in mod.modules():
        if isinstance(sub, lora_classes):
            sub.weight = nn.Parameter(
Do we have to do this conversion every single time the forward is called?
What's the result of an nf4 @ nf4 matmul?
- if it's bf16, then you have to convert the activations inside the LoRA linear
- if it's nf4, then other modules like rmsnorm and attention also need to be able to work with nf4
Activation: bf16, weight: nf4.
During the forward pass, we first dequantize NF4 back to bf16, so the computation happens in bf16:
nf4 @ bf16 -> bf16
We only quantize the original weights, and the outputs are all bf16, so later computation is unaffected; we only save memory on the original weights (bf16 -> nf4).
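The dtype flow above can be sketched for a single LoRA linear. This is a simplified stand-in: base_weight_bf16 represents an NF4-stored weight already dequantized back to bf16 (in the real path torchao's NF4Tensor handles that), and the rank and scaling are illustrative.

```python
import torch


def lora_forward(x, base_weight_bf16, lora_a, lora_b, scaling=1.0):
    # x: (batch, in_features) activations in bf16.
    # Base path: (dequantized) weight matmul stays entirely in bf16,
    # so downstream modules never see NF4.
    base_out = x @ base_weight_bf16.t()
    # Adapter path: lora_a (rank, in) and lora_b (out, rank) stay in
    # full precision (float32); cast activations in, result back out.
    lora_out = (x.to(lora_a.dtype) @ lora_a.t()) @ lora_b.t()
    return base_out + scaling * lora_out.to(base_out.dtype)
```

The key property is visible in the dtypes: the output is bf16 regardless of how the base weight is stored, which is why only the base weights need to change.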
### Summary
- Add QLoRA (quantized LoRA) support via NF4 quantization of base weights using torchao's to_nf4
- Base weights are quantized to NF4 via a one-shot forward pre-hook that fires before the first forward pass and self-removes, while LoRA adapters remain in full precision (float32)
- Add quantize_base and nf4_scaler_block_size config fields to LoRAConverter.Config
- Handle DTensor for distributed training compatibility
- Proper error handling: ImportError with chained exception for missing torchao, ValueError for invalid quantize_base values and incompatible tensor dimensions
- Add llama3_debugmodel_qlora debug config (derives from llama3_debugmodel_lora to inherit checkpoint settings)

### Test Plan
- test_qlora_base_weights_quantized_adapters_full_precision: Before the first forward, asserts fc1.weight is NOT an NF4Tensor (quantization hasn't fired yet). Calls model(torch.randn(2, 64)), which triggers the one-shot register_forward_pre_hook that calls to_nf4() on all base weights. After the first forward: for both fc1 and fc2, asserts weight.data IS an NF4Tensor, while lora_a.weight.dtype and lora_b.weight.dtype remain torch.float32. This validates the core QLoRA property: compressed base weights + full-precision adapters.