Skip to content

Commit ffe6dba

Browse files
committed
support qlora
ghstack-source-id: 42d7480 Pull Request resolved: #2487
1 parent 0060a8e commit ffe6dba

File tree

3 files changed

+125
-1
lines changed

3 files changed

+125
-1
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,36 @@ def test_lora_trains_base_frozen():
154154
if name in lora_before
155155
)
156156
assert any_lora_changed, "No LoRA param changed after 5 training steps"
157+
158+
159+
def test_qlora_base_weights_quantized_adapters_full_precision():
    """After first forward: base weights are NF4, LoRA adapters remain full precision."""
    # Skip the whole test when torchao is absent; the module object itself is
    # not needed, so don't bind it (the original kept an unused local).
    pytest.importorskip("torchao")
    from torchao.dtypes.nf4tensor import NF4Tensor

    model = SimpleModel()
    converter = LoRAConverter(
        LoRAConverter.Config(
            rank=4, alpha=8.0, quantize_base="nf4", nf4_scaler_block_size=1
        )
    )
    converter.convert(model)

    # Before first forward: base weights are regular tensors
    assert not isinstance(model.fc1.weight.data, NF4Tensor)

    # Trigger first forward to fire the quantization hook
    model(torch.randn(2, 64))

    # After first forward: base weights should be NF4, adapters stay float32
    for name in ("fc1", "fc2"):
        layer = getattr(model, name)
        assert isinstance(
            layer.weight.data, NF4Tensor
        ), f"{name}.weight should be NF4 after first forward"
        assert (
            layer.lora_a.weight.dtype == torch.float32
        ), f"{name}.lora_a.weight should be float32"
        assert (
            layer.lora_b.weight.dtype == torch.float32
        ), f"{name}.lora_b.weight should be float32"

torchtitan/components/lora.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,29 @@ class Config(Configurable.Config):
9595
Requires base model to be loaded from HF/initial_load_path on resume.
9696
Set to False to save full model weights for debugging without pretrained base."""
9797

98+
quantize_base: str = ""
99+
"""Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA).
100+
NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision."""
101+
102+
nf4_scaler_block_size: int = 128
103+
"""Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
104+
The default torchao value (256) may be too large for sharded tensors."""
105+
98106
def __init__(self, config: Config, **kwargs):
    """Capture LoRA/QLoRA settings from *config*, rejecting unknown quantization modes."""
    self.rank = config.rank
    self.alpha = config.alpha
    self.save_adapter_only = config.save_adapter_only
    self.quantize_base = config.quantize_base
    self.nf4_scaler_block_size = config.nf4_scaler_block_size

    # Only NF4 is implemented; fail fast on any other non-empty mode.
    if self.quantize_base and self.quantize_base != "nf4":
        raise ValueError(
            f"Unsupported quantize_base value: '{self.quantize_base}'. "
            "Supported values: '' (none), 'nf4'."
        )

    message = f"LoRA training active with rank={self.rank}, alpha={self.alpha}"
    if self.quantize_base:
        message += f", quantize_base={self.quantize_base}"
    logger.info(message)
103121

104122
def convert(self, model: nn.Module) -> None:
105123
model.requires_grad_(False)
@@ -133,5 +151,64 @@ def converter_key_filter(key: str) -> bool:
133151
object.__setattr__(module, "converter_key_filter", converter_key_filter)
134152
object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)
135153

154+
# Register a one-shot forward pre-hook to quantize base weights after
155+
# checkpoint load but before the first forward pass (QLoRA).
156+
if self.quantize_base == "nf4":
157+
from torch.distributed.tensor import DTensor
158+
159+
try:
160+
from torchao.dtypes.nf4tensor import to_nf4
161+
except ImportError:
162+
raise ImportError(
163+
"QLoRA requires torchao. Install with: pip install torchao"
164+
)
165+
166+
lora_classes = tuple(_lora_class_cache.values())
167+
nf4_scaler_block_size = self.nf4_scaler_block_size
168+
169+
def _to_nf4_tensor(weight: torch.Tensor) -> torch.Tensor:
170+
"""Convert weight to NF4, handling both regular tensors and DTensors."""
171+
nf4_block_size = 64 # NF4 default block size
172+
is_dtensor = isinstance(weight, DTensor)
173+
local_weight = weight.to_local() if is_dtensor else weight
174+
175+
num_scalers = local_weight.numel() // nf4_block_size
176+
if num_scalers % nf4_scaler_block_size != 0:
177+
raise ValueError(
178+
f"NF4 quantization failed: num_scalers ({num_scalers}) is not "
179+
f"divisible by nf4_scaler_block_size ({nf4_scaler_block_size}). "
180+
f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config "
181+
f"(e.g., 64, 32, or 1)."
182+
)
183+
184+
nf4_local = to_nf4(
185+
local_weight, scaler_block_size=nf4_scaler_block_size
186+
)
187+
188+
if is_dtensor:
189+
return DTensor.from_local(
190+
nf4_local, weight.device_mesh, weight.placements
191+
)
192+
return nf4_local
193+
194+
def _quantize_hook(
195+
mod: nn.Module, args: Any, handle: torch.utils.hooks.RemovableHandle
196+
) -> None:
197+
for sub in mod.modules():
198+
if isinstance(sub, lora_classes):
199+
sub.weight = nn.Parameter(
200+
_to_nf4_tensor(sub.weight.data), requires_grad=False
201+
)
202+
logger.info("QLoRA: quantized base weights to NF4")
203+
handle.remove()
204+
205+
# Use a list to allow the closure to reference the handle before it exists
206+
handle_ref: list[torch.utils.hooks.RemovableHandle] = []
207+
handle_ref.append(
208+
module.register_forward_pre_hook(
209+
lambda mod, args: _quantize_hook(mod, args, handle_ref[0])
210+
)
211+
)
212+
136213
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
    """No-op: this converter performs no per-step work after the optimizer update."""
    pass

torchtitan/models/llama3/config_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ def llama3_debugmodel_lora() -> Trainer.Config:
130130
return config
131131

132132

133+
def llama3_debugmodel_qlora() -> Trainer.Config:
    """Debug-model QLoRA config: the LoRA debug config with NF4-quantized base weights."""
    qlora_converter = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
        quantize_base="nf4",
    )
    config = llama3_debugmodel_lora()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[qlora_converter],
    )
    return config
145+
146+
133147
def llama3_8b() -> Trainer.Config:
134148
return Trainer.Config(
135149
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments
 (0)