
[3/N] Support qlora by nf4tensor#2487

Draft
mori360 wants to merge 9 commits into gh/mori360/3/base from
gh/mori360/3/head
Conversation


@mori360 mori360 commented Mar 5, 2026

Summary

  • Add QLoRA (quantized LoRA) support via NF4 quantization of base weights using torchao's to_nf4
  • Base weights are quantized to NF4 via a one-shot forward pre-hook that fires before the first forward pass and self-removes, while LoRA adapters remain in
    full precision (float32)
  • Add quantize_base and nf4_scaler_block_size config fields to LoRAConverter.Config
  • Handle DTensor for distributed training compatibility
  • Proper error handling: ImportError with chained exception for missing torchao, ValueError for invalid quantize_base values and incompatible tensor
    dimensions
  • Add llama3_debugmodel_qlora debug config (derives from llama3_debugmodel_lora to inherit checkpoint settings)
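The one-shot pre-hook described above can be sketched in plain PyTorch. This is only an illustration of the hook lifecycle, not the PR's implementation: `install_one_shot_quantize` is a hypothetical name, and the crude rounding step stands in for torchao's `to_nf4` (no real `NF4Tensor` is constructed here).

```python
import torch
import torch.nn as nn

def install_one_shot_quantize(model: nn.Module) -> None:
    """Register a forward pre-hook that rewrites base weights once, then removes itself."""
    def hook(mod, args):
        for sub in mod.modules():
            if isinstance(sub, nn.Linear):
                # Stand-in for to_nf4(): crude rounding to mimic low-bit quantization;
                # the new parameter is frozen, like QLoRA base weights.
                sub.weight = nn.Parameter(
                    torch.round(sub.weight.data * 16) / 16, requires_grad=False
                )
        handle.remove()  # self-remove: the hook fires only before the first forward

    handle = model.register_forward_pre_hook(hook)

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
install_one_shot_quantize(model)
assert model[0].weight.requires_grad        # not yet "quantized"
model(torch.randn(2, 8))                    # first forward: hook fires once
assert not model[0].weight.requires_grad    # base weights now frozen
model(torch.randn(2, 8))                    # later calls skip the hook entirely
```

Firing after checkpoint load but before the first forward means full-precision weights are loaded normally and compressed exactly once, with no per-step overhead afterwards.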

Test Plan

  • test_qlora_base_weights_quantized_adapters_full_precision
    For QLoRA with NF4Tensor: before the first forward, asserts that fc1.weight is NOT an NF4Tensor (quantization hasn't fired yet). Calls model(torch.randn(2, 64)), which triggers the one-shot register_forward_pre_hook that calls to_nf4() on all base weights. After the first forward, asserts for both fc1 and fc2 that weight.data IS an NF4Tensor while lora_a.weight.dtype and lora_b.weight.dtype remain torch.float32. This validates the core QLoRA property: compressed base weights plus full-precision adapters.

Stack from ghstack (oldest at bottom):

[ghstack-poisoned]
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 5, 2026
mori360 added a commit that referenced this pull request Mar 5, 2026
ghstack-source-id: 146c7a3
Pull Request resolved: #2487
[ghstack-poisoned]
mori360 added a commit that referenced this pull request Mar 5, 2026
ghstack-source-id: db613d3
Pull Request resolved: #2487
@mori360 mori360 changed the title support qlora [3/N] Support qlora Mar 5, 2026
@mori360 mori360 changed the title [3/N] Support qlora [3/N] Support qlora by nf4tensor Mar 5, 2026
@mori360 mori360 marked this pull request as draft March 5, 2026 00:51
mori360 added 3 commits March 5, 2026 11:41
QLoRA (LoRA + NF4-quantized base weights): Add quantize_base="nf4" option to LoRAConverter.Config that registers a one-shot forward_pre_hook to quantize base weights to NF4 after checkpoint load but before the first forward pass. Uses torchao.dtypes.nf4tensor.to_nf4 with DTensor support for distributed training. Configurable nf4_scaler_block_size (default 128) to handle sharded tensor sizes. Reduces base weight memory ~4x while keeping LoRA adapters in full precision.




[ghstack-poisoned]
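The commit message above notes that nf4_scaler_block_size is configurable "to handle sharded tensor sizes": blockwise scaling needs the local shard's element count to divide evenly into scaler blocks. A hypothetical helper (`pick_scaler_block_size` is not part of the PR) illustrating that constraint:

```python
def pick_scaler_block_size(local_numel: int, preferred: int = 128) -> int:
    """Fall back to smaller power-of-two block sizes until one divides the shard evenly."""
    size = preferred
    while size > 1 and local_numel % size != 0:
        size //= 2
    return size

# Full tensor: the default block size divides evenly.
assert pick_scaler_block_size(4096) == 128
# A shard whose size is not a multiple of 128 needs a smaller block.
assert pick_scaler_block_size(96) == 32
```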
"""Convert weight to NF4, handling both regular tensors and DTensors."""
nf4_block_size = 64 # NF4 default block size
is_dtensor = isinstance(weight, DTensor)
local_weight = weight.to_local() if is_dtensor else weight
Contributor
dtensor.to_local() should never be used without grad_placements; otherwise we could lose necessary reductions on gradients.

Contributor Author
Yeah, it's a prototype to prove that we can run with NF4. We only convert the original weights to NF4, and those are frozen (requires_grad=False), so losing gradients here should be fine. If we decide to fully support QLoRA later, we will work with torchao on better DTensor support.

) -> None:
    for sub in mod.modules():
        if isinstance(sub, lora_classes):
            sub.weight = nn.Parameter(
Contributor
Do we have to do this conversion every single time the forward is called?

Contributor
What's the result of an nf4 @ nf4 matmul?

  • if it's bf16, then you have to convert the activations to a LoRA linear
  • if it's nf4, then other modules like rmsnorm and attention also need to be able to work with nf4

Contributor Author
Activation: bf16, weight: nf4.
During forward we first dequantize NF4 to bf16, so the computation happens in bf16:
nf4 @ bf16 -> bf16
We only quantize the original weights, and the outputs are all bf16, so later computation is unaffected; we only save the memory for the original weights (bf16 -> nf4).
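The dataflow in that reply can be sketched in plain PyTorch. This is not torchao's NF4 code path: blockwise absmax int8 stands in for NF4 (and float32 for bf16) to keep the example dependency-free, but the shape of the idea is the same, weights stored compressed and dequantized to the activation dtype before the matmul, so downstream ops only ever see full-precision activations.

```python
import torch

def quantize_blockwise(w: torch.Tensor, block: int = 64):
    """Blockwise absmax quantization: one float scale per `block` elements."""
    flat = w.reshape(-1, block)
    scale = flat.abs().amax(dim=1, keepdim=True) / 127
    q = torch.clamp((flat / scale).round(), -127, 127).to(torch.int8)
    return q, scale

def dequantize_blockwise(q, scale, shape):
    """Reconstruct a full-precision weight from the compressed representation."""
    return (q.float() * scale).reshape(shape)

w = torch.randn(128, 64)                    # "original" base weight
q, s = quantize_blockwise(w)                # stored compressed (int8 + scales)
w_hat = dequantize_blockwise(q, s, w.shape) # dequantize before compute
x = torch.randn(2, 64)                      # full-precision activations
y = x @ w_hat.T                             # matmul runs in float32
assert y.dtype == torch.float32             # outputs stay full precision
assert (w - w_hat).abs().max() < 0.05       # small quantization error
```

Memory savings come from storing `q` (1 byte/element plus scales) instead of `w`; everything after the dequantize step is unchanged, which is why rmsnorm, attention, and the LoRA adapters never need to handle the quantized dtype.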

@tianyu-l tianyu-l requested a review from danielvegamyhre March 8, 2026 22:41
mori360 added 4 commits March 12, 2026 12:45
@mori360 mori360 mentioned this pull request Mar 13, 2026

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.
