Skip to content

Commit 32fc92b

Browse files
committed
support qlora
ghstack-source-id: db613d3 Pull Request resolved: #2487
1 parent 8a322c0 commit 32fc92b

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

torchtitan/components/lora.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,29 @@ class Config(Configurable.Config):
9696
Requires base model to be loaded from HF/initial_load_path on resume.
9797
Set to False to save full model weights for debugging without pretrained base."""
9898

99+
quantize_base: str = ""
100+
"""Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA).
101+
NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision."""
102+
103+
nf4_scaler_block_size: int = 128
104+
"""Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
105+
The default torchao value (256) may be too large for sharded tensors."""
106+
99107
def __init__(self, config: Config, **kwargs):
    """Initialize the LoRA converter from its config.

    Copies the LoRA hyperparameters and QLoRA settings off *config*,
    rejects unknown quantization modes up front, and logs the active
    configuration.

    Raises:
        ValueError: if ``config.quantize_base`` is a non-empty string
            other than ``"nf4"``.
    """
    self.rank = config.rank
    self.alpha = config.alpha
    self.save_adapter_only = config.save_adapter_only
    self.quantize_base = config.quantize_base
    self.nf4_scaler_block_size = config.nf4_scaler_block_size

    # Fail fast at construction time rather than at the first forward pass.
    if self.quantize_base and self.quantize_base != "nf4":
        raise ValueError(
            f"Unsupported quantize_base value: '{self.quantize_base}'. "
            "Supported values: '' (none), 'nf4'."
        )

    quant_suffix = (
        f", quantize_base={self.quantize_base}" if self.quantize_base else ""
    )
    logger.info(
        f"LoRA training active with rank={self.rank}, alpha={self.alpha}"
        + quant_suffix
    )
104122

105123
def convert(self, model: nn.Module) -> None:
106124
model.requires_grad_(False)
@@ -134,6 +152,58 @@ def converter_key_filter(key: str) -> bool:
134152
object.__setattr__(module, "converter_key_filter", converter_key_filter)
135153
object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)
136154

155+
# Register a one-shot forward pre-hook to quantize base weights after
156+
# checkpoint load but before the first forward pass (QLoRA).
157+
if self.quantize_base == "nf4":
158+
from torch.distributed.tensor import DTensor
159+
from torchao.dtypes.nf4tensor import to_nf4
160+
161+
lora_classes = tuple(_lora_class_cache.values())
162+
nf4_scaler_block_size = self.nf4_scaler_block_size
163+
164+
def _to_nf4_tensor(weight: torch.Tensor) -> torch.Tensor:
165+
"""Convert weight to NF4, handling both regular tensors and DTensors."""
166+
nf4_block_size = 64 # NF4 default block size
167+
is_dtensor = isinstance(weight, DTensor)
168+
local_weight = weight.to_local() if is_dtensor else weight
169+
170+
num_scalers = local_weight.numel() // nf4_block_size
171+
assert num_scalers % nf4_scaler_block_size == 0, (
172+
f"NF4 quantization failed: num_scalers ({num_scalers}) is not "
173+
f"divisible by nf4_scaler_block_size ({nf4_scaler_block_size}). "
174+
f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config "
175+
f"(e.g., 64, 32, or 1)."
176+
)
177+
178+
nf4_local = to_nf4(
179+
local_weight, scaler_block_size=nf4_scaler_block_size
180+
)
181+
182+
if is_dtensor:
183+
return DTensor.from_local(
184+
nf4_local, weight.device_mesh, weight.placements
185+
)
186+
return nf4_local
187+
188+
def _quantize_hook(
189+
mod: nn.Module, args: Any, handle: torch.utils.hooks.RemovableHandle
190+
) -> None:
191+
for sub in mod.modules():
192+
if isinstance(sub, lora_classes):
193+
sub.weight = nn.Parameter(
194+
_to_nf4_tensor(sub.weight.data), requires_grad=False
195+
)
196+
logger.info("QLoRA: quantized base weights to NF4")
197+
handle.remove()
198+
199+
# Use a list to allow the closure to reference the handle before it exists
200+
handle_ref: list[torch.utils.hooks.RemovableHandle] = []
201+
handle_ref.append(
202+
module.register_forward_pre_hook(
203+
lambda mod, args: _quantize_hook(mod, args, handle_ref[0])
204+
)
205+
)
206+
137207
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
138208
pass
139209

torchtitan/models/llama3/config_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,20 @@ def llama3_debugmodel_lora() -> Trainer.Config:
130130
return config
131131

132132

133+
def llama3_debugmodel_qlora() -> Trainer.Config:
    """Debug-model QLoRA recipe.

    Starts from the LoRA debug config and swaps in a LoRA converter with
    NF4 quantization of the base (non-LoRA) weights enabled.
    """
    config = llama3_debugmodel_lora()
    qlora_converter = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
        quantize_base="nf4",
    )
    config.model_converters = ModelConvertersContainer.Config(
        converters=[qlora_converter],
    )
    return config
145+
146+
133147
def llama3_8b() -> Trainer.Config:
134148
return Trainer.Config(
135149
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments
 (0)