Skip to content

Commit 6b775ca

Browse files
committed
support qlora
ghstack-source-id: 1353f02 Pull Request resolved: #2487
1 parent 46eebfb commit 6b775ca

File tree

3 files changed

+137
-1
lines changed

3 files changed

+137
-1
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,36 @@ def test_lora_trains_base_frozen():
165165
if name in lora_before
166166
)
167167
assert any_lora_changed, "No LoRA param changed after 5 training steps"
168+
169+
170+
def test_qlora_base_weights_quantized_adapters_full_precision():
    """After first forward: base weights are NF4, LoRA adapters remain full precision."""
    # Skip cleanly when torchao is absent (QLoRA is optional); the module
    # object itself is not needed, so don't bind the return value.
    pytest.importorskip("torchao")
    from torchao.dtypes.nf4tensor import NF4Tensor

    model = SimpleModel()
    # nf4_scaler_block_size=1 always satisfies the divisibility check in the
    # converter, regardless of the tiny test layer sizes.
    converter = LoRAConverter(
        LoRAConverter.Config(
            rank=4, alpha=8.0, quantize_base="nf4", nf4_scaler_block_size=1
        )
    )
    converter.convert(model)

    # Before first forward: base weights are regular tensors (hook not fired yet)
    assert not isinstance(model.fc1.weight.data, NF4Tensor)

    # Trigger first forward to fire the one-shot quantization hook
    model(torch.randn(2, 64))

    # After first forward: base weights should be NF4, adapters stay float32
    for name in ("fc1", "fc2"):
        layer = getattr(model, name)
        assert isinstance(
            layer.weight.data, NF4Tensor
        ), f"{name}.weight should be NF4 after first forward"
        assert (
            layer.lora_a.weight.dtype == torch.float32
        ), f"{name}.lora_a.weight should be float32"
        assert (
            layer.lora_b.weight.dtype == torch.float32
        ), f"{name}.lora_b.weight should be float32"

torchtitan/components/lora.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,29 @@ class Config(Configurable.Config):
9393
Requires base model to be loaded from HF/initial_load_path on resume.
9494
Set to False to save full model weights for debugging without pretrained base."""
9595

96+
quantize_base: str = ""
97+
"""Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA).
98+
NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision."""
99+
100+
nf4_scaler_block_size: int = 128
101+
"""Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
102+
The default torchao value (256) may be too large for sharded tensors."""
103+
96104
def __init__(self, config: Config, **kwargs):
    """Capture LoRA settings from *config* and validate the quantization mode.

    Raises:
        ValueError: if ``config.quantize_base`` is neither "" nor "nf4".
    """
    self.rank = config.rank
    self.alpha = config.alpha
    self.save_adapter_only = config.save_adapter_only
    self.quantize_base = config.quantize_base
    self.nf4_scaler_block_size = config.nf4_scaler_block_size

    # Only the empty string (disabled) and "nf4" are recognized modes.
    if self.quantize_base not in ("", "nf4"):
        raise ValueError(
            f"Unsupported quantize_base value: '{self.quantize_base}'. "
            "Supported values: '' (none), 'nf4'."
        )

    message = f"LoRA training active with rank={self.rank}, alpha={self.alpha}"
    if self.quantize_base:
        message += f", quantize_base={self.quantize_base}"
    logger.info(message)
101119

102120
def convert(self, model: nn.Module) -> None:
103121
model.requires_grad_(False)
@@ -108,5 +126,76 @@ def _replace_linears_with_lora(self, module: nn.Module) -> None:
108126
if isinstance(child, nn.Linear):
109127
apply_lora(child, self.rank, self.alpha)
110128

129+
# Expose a key filter and flag on the module so ModelWrapper can
130+
# partition the state dict without knowing about LoRA internals.
131+
def converter_key_filter(key: str) -> bool:
    """Return True if key was added by this converter (LoRA adapter weights)."""
    adapter_markers = (".lora_a.", ".lora_b.")
    return any(marker in key for marker in adapter_markers)
134+
135+
object.__setattr__(module, "converter_key_filter", converter_key_filter)
136+
object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)
137+
138+
# Register a one-shot forward pre-hook to quantize base weights after
139+
# checkpoint load but before the first forward pass (QLoRA).
140+
# TODO: Prototype — move to torchao as a proper QuantizationConverter.
141+
# to_nf4 on local tensors loses DTensor grad info, fine here since
142+
# base weights are frozen and only LoRA adapters receive gradients.
143+
if self.quantize_base == "nf4":
144+
from torch.distributed.tensor import DTensor
145+
146+
try:
147+
from torchao.dtypes.nf4tensor import to_nf4
148+
except ImportError as err:
149+
raise ImportError(
150+
"QLoRA requires torchao. Install with: pip install torchao"
151+
) from err
152+
153+
lora_classes = tuple(_lora_class_cache.values())
154+
nf4_scaler_block_size = self.nf4_scaler_block_size
155+
156+
def _to_nf4_tensor(weight: torch.Tensor) -> torch.Tensor:
    """Convert weight to NF4, handling both regular tensors and DTensors.

    Raises:
        ValueError: when the weight's scaler count is not evenly divisible
            by the configured ``nf4_scaler_block_size``.
    """
    nf4_block_size = 64  # NF4 default block size
    is_dtensor = isinstance(weight, DTensor)
    # Quantization operates on the local shard; DTensors are unwrapped first.
    local_weight = weight.to_local() if is_dtensor else weight

    # One scaler per nf4_block_size elements; the divisibility check mirrors
    # the tiling requirement to_nf4 enforces, but fails with an actionable hint.
    num_scalers = local_weight.numel() // nf4_block_size
    if num_scalers % nf4_scaler_block_size != 0:
        raise ValueError(
            f"NF4 quantization failed: num_scalers ({num_scalers}) is not "
            f"divisible by nf4_scaler_block_size ({nf4_scaler_block_size}). "
            f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config "
            f"(e.g., 64, 32, or 1)."
        )

    nf4_local = to_nf4(
        local_weight, scaler_block_size=nf4_scaler_block_size
    )

    if is_dtensor:
        # Re-wrap with the original mesh/placements. NOTE(review): assumes the
        # NF4 shard still matches these placements — confirm for uneven shards.
        return DTensor.from_local(
            nf4_local, weight.device_mesh, weight.placements
        )
    return nf4_local
180+
181+
def _quantize_hook(
    mod: nn.Module, args: Any, handle: torch.utils.hooks.RemovableHandle
) -> None:
    """One-shot pre-forward hook: quantize every LoRA base weight, then deregister."""
    for sub in mod.modules():
        if isinstance(sub, lora_classes):
            # requires_grad=False: base weights stay frozen; only the LoRA
            # adapters receive gradients.
            sub.weight = nn.Parameter(
                _to_nf4_tensor(sub.weight.data), requires_grad=False
            )
    logger.info("QLoRA: quantized base weights to NF4")
    # Removing the handle guarantees quantization runs exactly once.
    handle.remove()

# Use a list to allow the closure to reference the handle before it exists
handle_ref: list[torch.utils.hooks.RemovableHandle] = []
handle_ref.append(
    module.register_forward_pre_hook(
        lambda mod, args: _quantize_hook(mod, args, handle_ref[0])
    )
)
199+
111200
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
    """No-op: this converter needs no per-step work after the optimizer."""
    pass

torchtitan/models/llama3/config_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ def llama3_debugmodel_lora() -> Trainer.Config:
130130
return config
131131

132132

133+
def llama3_debugmodel_qlora() -> Trainer.Config:
    """Debug-model LoRA config with NF4-quantized (QLoRA) base weights."""
    config = llama3_debugmodel_lora()
    # Swap in a converter that quantizes base weights to NF4.
    qlora_converter = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
        quantize_base="nf4",
    )
    config.model_converters = ModelConvertersContainer.Config(
        converters=[qlora_converter],
    )
    return config
145+
146+
133147
def llama3_8b() -> Trainer.Config:
134148
return Trainer.Config(
135149
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments
 (0)