Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tests/unit_tests/test_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,44 @@ def test_lora_trains_base_frozen():
if name in lora_before
)
assert any_lora_changed, "No LoRA param changed after 5 training steps"


def test_qlora_base_weights_quantized_adapters_full_precision():
    """After the first forward: base weights are NF4, LoRA adapters stay full precision."""
    pytest.importorskip("torchao")
    from torchao.dtypes.nf4tensor import NF4Tensor

    modules = OrderedDict(
        fc1=nn.Linear(64, 64),
        relu=nn.ReLU(),
        fc2=nn.Linear(64, 64),
    )
    model = nn.Sequential(modules)

    cfg = LoRAConverter.Config(
        rank=4, alpha=8.0, quantize_base="nf4", nf4_scaler_block_size=1
    )
    LoRAConverter(cfg).convert(model)

    # convert() only registers the hook; weights stay regular tensors until forward.
    assert not isinstance(model.fc1.weight.data, NF4Tensor)

    # The first forward pass fires the one-shot quantization hook.
    model(torch.randn(2, 64))

    # Base weights should now be NF4 while LoRA adapters remain float32.
    for layer_name in ("fc1", "fc2"):
        layer = getattr(model, layer_name)
        assert isinstance(
            layer.weight.data, NF4Tensor
        ), f"{layer_name}.weight should be NF4 after first forward"
        assert (
            layer.lora_a.weight.dtype == torch.float32
        ), f"{layer_name}.lora_a.weight should be float32"
        assert (
            layer.lora_b.weight.dtype == torch.float32
        ), f"{layer_name}.lora_b.weight should be float32"
93 changes: 92 additions & 1 deletion torchtitan/components/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,29 @@ class Config(Configurable.Config):
Requires base model to be loaded from HF/initial_load_path on resume.
Set to False to save full model weights for debugging without pretrained base."""

quantize_base: str = ""
"""Quantize base (non-LoRA) weights. "" = no quantization, "nf4" = NF4 (QLoRA).
NF4 quantization reduces base weight memory ~4x while keeping LoRA adapters in full precision."""

nf4_scaler_block_size: int = 128
"""Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
The default torchao value (256) may be too large for sharded tensors."""

def __init__(self, config: Config, **kwargs):
    """Capture LoRA settings from *config* and validate the quantization mode.

    Raises:
        ValueError: if ``quantize_base`` is a non-empty value other than "nf4".
    """
    self.rank = config.rank
    self.alpha = config.alpha
    self.save_adapter_only = config.save_adapter_only
    self.quantize_base = config.quantize_base
    self.nf4_scaler_block_size = config.nf4_scaler_block_size
    # Only "" (disabled) and "nf4" are recognized quantization modes.
    if self.quantize_base and self.quantize_base != "nf4":
        raise ValueError(
            f"Unsupported quantize_base value: '{self.quantize_base}'. "
            "Supported values: '' (none), 'nf4'."
        )
    quantize_note = (
        f", quantize_base={self.quantize_base}" if self.quantize_base else ""
    )
    logger.info(
        f"LoRA training active with rank={self.rank}, alpha={self.alpha}"
        f"{quantize_note}"
    )

def convert(self, model: nn.Module) -> None:
model.requires_grad_(False)
Expand All @@ -108,5 +126,78 @@ def _replace_linears_with_lora(self, module: nn.Module) -> None:
if isinstance(child, nn.Linear):
apply_lora(child, self.rank, self.alpha)

# Expose a key filter and flag on the module so ModelWrapper can
# partition the state dict without knowing about LoRA internals.
def converter_key_filter(key: str) -> bool:
"""Return True if key was added by this converter (LoRA adapter weights)."""
return ".lora_a." in key or ".lora_b." in key

object.__setattr__(module, "converter_key_filter", converter_key_filter)
object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)

# Register a one-shot forward pre-hook to quantize base weights after
# checkpoint load but before the first forward pass (QLoRA).
# TODO: Prototype — move to torchao as a proper QuantizationConverter.
# to_nf4 on local tensors loses DTensor grad info, fine here since
# base weights are frozen and only LoRA adapters receive gradients.
if self.quantize_base == "nf4":
from torch.distributed.tensor import DTensor

try:
from torchao.dtypes.nf4tensor import to_nf4
except ImportError as err:
raise ImportError(
"QLoRA requires torchao. Install with: pip install torchao"
) from err

lora_classes = tuple(_lora_class_cache.values())
nf4_scaler_block_size = self.nf4_scaler_block_size

def _to_nf4_tensor(weight: torch.Tensor) -> torch.Tensor:
"""Convert weight to NF4, handling both regular tensors and DTensors."""
nf4_block_size = 64 # NF4 default block size
is_dtensor = isinstance(weight, DTensor)
local_weight = weight.to_local() if is_dtensor else weight
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtensor.to_local() should never be used without grad_placements; otherwise we could lose necessary reductions on gradients.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's a prototype to prove that we can run with NF4. We only convert the original weights to NF4, which are frozen (requires_grad=False), so losing gradients here should be fine. If we decide to support QLoRA later, we will collaborate with torchao to better support DTensor.


num_scalers = local_weight.numel() // nf4_block_size
if num_scalers % nf4_scaler_block_size != 0:
raise ValueError(
f"NF4 quantization failed: num_scalers ({num_scalers}) is not "
f"divisible by nf4_scaler_block_size ({nf4_scaler_block_size}). "
f"Try a smaller nf4_scaler_block_size in LoRAConverter.Config "
f"(e.g., 64, 32, or 1)."
)

nf4_local = to_nf4(
local_weight, scaler_block_size=nf4_scaler_block_size
)

if is_dtensor:
return DTensor.from_local(
nf4_local, # pyrefly: ignore [bad-argument-type]
weight.device_mesh,
weight.placements,
)
return nf4_local # pyrefly: ignore [bad-return]

def _quantize_hook(
mod: nn.Module, args: Any, handle: torch.utils.hooks.RemovableHandle
) -> None:
for sub in mod.modules():
if isinstance(sub, lora_classes):
sub.weight = nn.Parameter(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to do this conversion every single time the forward is called?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the result of a nf4 @ nf4 matmul

  • if it's bf16, then you have to convert the activations to a LoRA linear
  • if it's nf4, then other modules like rmsnorm and attention also needs to be able to work with nf4

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Activation: bf16, weight: nf4.
During the forward pass, we first dequantize NF4 to bf16, so the computation happens in bf16:
nf4 @ bf16 -> bf16.
We only quantize the original weights, and the outputs are all in bf16, so later computations are unaffected; we only save memory on the original weights (bf16 -> nf4).

_to_nf4_tensor(sub.weight.data), requires_grad=False
)
logger.info("QLoRA: quantized base weights to NF4")
handle.remove()

# Use a list to allow the closure to reference the handle before it exists
handle_ref: list[torch.utils.hooks.RemovableHandle] = []
handle_ref.append(
module.register_forward_pre_hook(
lambda mod, args: _quantize_hook(mod, args, handle_ref[0])
)
)

def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
pass
14 changes: 14 additions & 0 deletions torchtitan/models/llama3/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ def llama3_debugmodel_lora() -> Trainer.Config:
return config


def llama3_debugmodel_qlora() -> Trainer.Config:
    """Debug-model QLoRA recipe: the LoRA debug config with NF4-quantized base weights."""
    qlora_converter = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
        quantize_base="nf4",
    )
    config = llama3_debugmodel_lora()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[qlora_converter],
    )
    return config


def llama3_8b() -> Trainer.Config:
return Trainer.Config(
hf_assets_path="./assets/hf/Llama-3.1-8B",
Expand Down
Loading