Skip to content

Commit 695a256

Browse files
committed
support qlora
ghstack-source-id: b6a8ff6 Pull Request resolved: #2487
1 parent 0f97a33 commit 695a256

File tree

3 files changed

+147
-1
lines changed

3 files changed

+147
-1
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,44 @@ def test_lora_trains_base_frozen():
173173
if name in lora_before
174174
)
175175
assert any_lora_changed, "No LoRA param changed after 5 training steps"
176+
177+
178+
def test_qlora_base_weights_quantized_adapters_full_precision():
    """After first forward: base weights are NF4, LoRA adapters remain full precision.

    QLoRA quantization is deferred to a one-shot forward pre-hook, so this test
    checks the weight types both before and after the first forward pass.
    """
    # Skip on environments without torchao; the module object itself is not
    # needed afterwards (NF4Tensor is imported directly), so don't bind it.
    pytest.importorskip("torchao")
    from torchao.dtypes.nf4tensor import NF4Tensor

    model = nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(64, 64)),
                ("relu", nn.ReLU()),
                ("fc2", nn.Linear(64, 64)),
            ]
        )
    )
    # nf4_scaler_block_size=1 keeps the scaler count divisible for these
    # small 64x64 layers (the default 128 targets larger sharded tensors).
    converter = LoRAConverter(
        LoRAConverter.Config(
            rank=4, alpha=8.0, quantize_base="nf4", nf4_scaler_block_size=1
        )
    )
    converter.convert(model)

    # Before first forward: base weights are regular tensors
    assert not isinstance(model.fc1.weight.data, NF4Tensor)

    # Trigger first forward to fire the quantization hook
    model(torch.randn(2, 64))

    # After first forward: base weights should be NF4, adapters stay float32
    for name in ("fc1", "fc2"):
        layer = getattr(model, name)
        assert isinstance(
            layer.weight.data, NF4Tensor
        ), f"{name}.weight should be NF4 after first forward"
        assert (
            layer.lora_a.weight.dtype == torch.float32
        ), f"{name}.lora_a.weight should be float32"
        assert (
            layer.lora_b.weight.dtype == torch.float32
        ), f"{name}.lora_b.weight should be float32"

torchtitan/components/lora.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,29 @@ class Config(Configurable.Config):
9393
Requires base model to be loaded from HF/initial_load_path on resume.
9494
Set to False to save full model weights for debugging without pretrained base."""
9595

96+
quantize_base: str = ""
97+
"""Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA).
98+
NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision."""
99+
100+
nf4_scaler_block_size: int = 128
101+
"""Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
102+
The default torchao value (256) may be too large for sharded tensors."""
103+
96104
def __init__(self, config: Config, **kwargs):
    """Capture LoRA/QLoRA settings from *config* and validate the quantization mode."""
    self.rank = config.rank
    self.alpha = config.alpha
    self.save_adapter_only = config.save_adapter_only
    self.quantize_base = config.quantize_base
    self.nf4_scaler_block_size = config.nf4_scaler_block_size

    mode = self.quantize_base
    # Only NF4 (QLoRA) is implemented; reject any other non-empty mode early.
    if mode and mode != "nf4":
        raise ValueError(
            f"Unsupported quantize_base value: '{mode}'. "
            "Supported values: '' (none), 'nf4'."
        )

    # Mention the quantization mode in the startup log only when it is set.
    suffix = f", quantize_base={mode}" if mode else ""
    logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}{suffix}")
101119

102120
def convert(self, model: nn.Module) -> None:
103121
model.requires_grad_(False)
@@ -108,5 +126,78 @@ def _replace_linears_with_lora(self, module: nn.Module) -> None:
108126
if isinstance(child, nn.Linear):
109127
apply_lora(child, self.rank, self.alpha)
110128

129+
# Expose a key filter and flag on the module so ModelWrapper can
130+
# partition the state dict without knowing about LoRA internals.
131+
def converter_key_filter(key: str) -> bool:
132+
"""Return True if key was added by this converter (LoRA adapter weights)."""
133+
return ".lora_a." in key or ".lora_b." in key
134+
135+
object.__setattr__(module, "converter_key_filter", converter_key_filter)
136+
object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)
137+
138+
# Register a one-shot forward pre-hook to quantize base weights after
139+
# checkpoint load but before the first forward pass (QLoRA).
140+
# TODO: Prototype — move to torchao as a proper QuantizationConverter.
141+
# to_nf4 on local tensors loses DTensor grad info, fine here since
142+
# base weights are frozen and only LoRA adapters receive gradients.
143+
if self.quantize_base == "nf4":
144+
from torch.distributed.tensor import DTensor
145+
146+
try:
147+
from torchao.dtypes.nf4tensor import to_nf4
148+
except ImportError as err:
149+
raise ImportError(
150+
"QLoRA requires torchao. Install with: pip install torchao"
151+
) from err
152+
153+
lora_classes = tuple(_lora_class_cache.values())
154+
nf4_scaler_block_size = self.nf4_scaler_block_size
155+
156+
def _to_nf4_tensor(weight: torch.Tensor) -> torch.Tensor:
157+
"""Convert weight to NF4, handling both regular tensors and DTensors."""
158+
nf4_block_size = 64 # NF4 default block size
159+
is_dtensor = isinstance(weight, DTensor)
160+
local_weight = weight.to_local() if is_dtensor else weight
161+
162+
num_scalers = local_weight.numel() // nf4_block_size
163+
if num_scalers % nf4_scaler_block_size != 0:
164+
raise ValueError(
165+
f"NF4 quantization failed: num_scalers ({num_scalers}) is not "
166+
f"divisible by nf4_scaler_block_size ({nf4_scaler_block_size}). "
167+
f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config "
168+
f"(e.g., 64, 32, or 1)."
169+
)
170+
171+
nf4_local = to_nf4(
172+
local_weight, scaler_block_size=nf4_scaler_block_size
173+
)
174+
175+
if is_dtensor:
176+
return DTensor.from_local(
177+
nf4_local, # pyrefly: ignore [bad-argument-type]
178+
weight.device_mesh,
179+
weight.placements,
180+
)
181+
return nf4_local # pyrefly: ignore [bad-return]
182+
183+
def _quantize_hook(
184+
mod: nn.Module, args: Any, handle: torch.utils.hooks.RemovableHandle
185+
) -> None:
186+
for sub in mod.modules():
187+
if isinstance(sub, lora_classes):
188+
sub.weight = nn.Parameter(
189+
_to_nf4_tensor(sub.weight.data), requires_grad=False
190+
)
191+
logger.info("QLoRA: quantized base weights to NF4")
192+
handle.remove()
193+
194+
# Use a list to allow the closure to reference the handle before it exists
195+
handle_ref: list[torch.utils.hooks.RemovableHandle] = []
196+
handle_ref.append(
197+
module.register_forward_pre_hook(
198+
lambda mod, args: _quantize_hook(mod, args, handle_ref[0])
199+
)
200+
)
201+
111202
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
112203
pass

torchtitan/models/llama3/config_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ def llama3_debugmodel_lora() -> Trainer.Config:
130130
return config
131131

132132

133+
def llama3_debugmodel_qlora() -> Trainer.Config:
    """QLoRA variant of the LoRA debugmodel config: base weights quantized to NF4."""
    config = llama3_debugmodel_lora()
    # Replace the converter stack with a single NF4-quantizing LoRA converter.
    qlora_converter = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
        quantize_base="nf4",
    )
    config.model_converters = ModelConvertersContainer.Config(
        converters=[qlora_converter],
    )
    return config
145+
146+
133147
def llama3_8b() -> Trainer.Config:
134148
return Trainer.Config(
135149
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments (0)