Skip to content
Open
107 changes: 107 additions & 0 deletions tests/unit_tests/test_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict

import pytest
import torch
import torch.nn as nn

from torchtitan.components.lora import LoRAConverter
from torchtitan.components.quantization.float8 import Float8LinearConverter
from torchtitan.config import ConfigManager
from torchtitan.distributed import ParallelDims
Expand Down Expand Up @@ -66,3 +71,105 @@ def test_build_model_converters_float8_converter():
assert isinstance(model_converters, ModelConvertersContainer)
assert len(model_converters.converters) == 1
assert isinstance(model_converters.converters[0], Float8LinearConverter)


def test_lora_before_quantization_raises():
    """LoRA must come after quantization converters."""
    # Deliberately put LoRA first so Config.__post_init__ rejects the order.
    bad_order = [
        LoRAConverter.Config(rank=8, alpha=16.0),
        Float8LinearConverter.Config(emulate=True),
    ]
    with pytest.raises(ValueError, match="LoRA converter must come after"):
        ModelConvertersContainer.Config(converters=bad_order)


def test_lora_freeze_and_trainability():
    """After convert: base params frozen, LoRA adapters present and trainable."""
    model = nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(64, 64)),
                ("relu", nn.ReLU()),
                ("fc2", nn.Linear(64, 64)),
            ]
        )
    )
    LoRAConverter(LoRAConverter.Config(rank=4, alpha=8.0)).convert(model)

    # Both linears should have gained the low-rank adapter pair.
    for layer in (model.fc1, model.fc2):
        assert hasattr(layer, "lora_a")
        assert hasattr(layer, "lora_b")

    # Partition every parameter by name and verify trainability per side.
    lora_param_names = []
    base_param_names = []
    for name, param in model.named_parameters():
        is_lora = "lora_a" in name or "lora_b" in name
        if is_lora:
            lora_param_names.append(name)
            assert param.requires_grad, f"LoRA param '{name}' should be trainable"
        else:
            base_param_names.append(name)
            assert not param.requires_grad, f"Base param '{name}' should be frozen"

    assert len(lora_param_names) > 0, "No LoRA params found"
    assert len(base_param_names) > 0, "No base params found"


def test_lora_trains_base_frozen():
    """Train for several steps: LoRA params should change, base params should not."""
    torch.manual_seed(42)
    model = nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(64, 64)),
                ("relu", nn.ReLU()),
                ("fc2", nn.Linear(64, 64)),
            ]
        )
    )
    LoRAConverter(LoRAConverter.Config(rank=4, alpha=8.0)).convert(model)

    def _is_lora(name: str) -> bool:
        return "lora_a" in name or "lora_b" in name

    # Snapshot both parameter groups before any optimizer step.
    base_before = {}
    lora_before = {}
    for name, param in model.named_parameters():
        target = lora_before if _is_lora(name) else base_before
        target[name] = param.data.clone()

    # Only trainable (i.e. LoRA) params are handed to the optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable, lr=0.1)

    # Train for 5 steps.
    for _ in range(5):
        batch = torch.randn(4, 64)
        model(batch).sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    # Frozen base weights must be bit-identical to their snapshots.
    for name, param in model.named_parameters():
        if name in base_before:
            assert torch.equal(
                param.data, base_before[name]
            ), f"Base param '{name}' changed during training"

    # At least one adapter weight must have moved.
    changed = [
        name
        for name, param in model.named_parameters()
        if name in lora_before and not torch.equal(param.data, lora_before[name])
    ]
    assert changed, "No LoRA param changed after 5 training steps"
107 changes: 107 additions & 0 deletions torchtitan/components/lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn

from torchtitan.config import Configurable
from torchtitan.models.common.linear import Linear
from torchtitan.tools.logging import logger

# Cache for dynamically created LoRA classes
_lora_class_cache: dict[type, type] = {}


def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear:
    """Convert ``linear`` in place into a LoRA-augmented linear layer.

    A subclass of ``type(linear)`` is created (and cached in
    ``_lora_class_cache``, keyed by the parent class) that attaches two
    low-rank adapters: ``lora_a`` (in_features -> rank) and ``lora_b``
    (rank -> out_features). The forward pass becomes
    ``base(x) + (alpha / rank) * lora_b(lora_a(x))``. The original weight
    and bias tensors are untouched; only ``__class__`` is swapped and the
    adapter modules attached, so quantized linear subclasses keep their
    specialized base behavior.

    Args:
        linear: The layer to augment; it is mutated in place.
        rank: Rank of the adapter factorization.
        alpha: Scaling numerator; the effective scale is ``alpha / rank``.

    Returns:
        The same ``linear`` object, now an instance of the LoRA subclass.
    """
    parent_cls = type(linear)
    assert issubclass(
        parent_cls, nn.Linear
    ), f"parent_cls must be a subclass of nn.Linear, got {parent_cls}"

    if parent_cls not in _lora_class_cache:

        class LoRALinear(parent_cls):  # type: ignore[valid-type, misc]
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Instances are only ever produced via from_linear, which
                # re-tags an existing linear instead of constructing one.
                raise RuntimeError("LoRALinear should not be instantiated directly.")

            @classmethod
            def from_linear(
                cls, linear: nn.Linear, rank: int, alpha: float
            ) -> "LoRALinear":
                # In-place class swap preserves the existing parameters.
                linear.__class__ = cls
                linear._init_lora(rank, alpha)  # type: ignore[attr-defined]
                return linear  # type: ignore[return-value]

            def _init_lora(
                self,
                rank: int,
                alpha: float,
                device: torch.device | None = None,
                dtype: torch.dtype | None = None,
            ) -> None:
                # Adapters follow the base weight's placement unless overridden.
                self._lora_scaling = alpha / rank
                device = device if device is not None else self.weight.device
                dtype = dtype if dtype is not None else self.weight.dtype
                self.lora_a = (
                    Linear.Config(bias=False)
                    .build(in_features=self.in_features, out_features=rank)
                    .to(device=device, dtype=dtype)
                )
                self.lora_b = (
                    Linear.Config(bias=False)
                    .build(in_features=rank, out_features=self.out_features)
                    .to(device=device, dtype=dtype)
                )

            def init_weights(self, **kwargs) -> None:
                super().init_weights(**kwargs)  # pyrefly: ignore [not-callable]
                # Standard LoRA init: lora_b starts at zero so the adapter
                # contributes nothing until training moves it.
                nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
                nn.init.zeros_(self.lora_b.weight)

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                base_out = super().forward(input)
                lora_out = self.lora_b(self.lora_a(input))
                return base_out + self._lora_scaling * lora_out

        LoRALinear.__name__ = f"LoRA{parent_cls.__name__}"
        LoRALinear.__qualname__ = f"LoRA{parent_cls.__name__}"
        _lora_class_cache[parent_cls] = LoRALinear

    # pyrefly: ignore [missing-attribute]
    return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha)


class LoRAConverter(Configurable):
    """Apply LoRA adapters to all Linear layers in a model."""

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        rank: int = 8
        """Rank of the LoRA matrices (lora_a: in_features x rank, lora_b: rank x out_features)."""

        alpha: float = 16.0
        """Scaling factor. Output is scaled by alpha/rank."""

    def __init__(self, config: Config, **kwargs):
        self.rank = config.rank
        self.alpha = config.alpha
        logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")

    def convert(self, model: nn.Module) -> None:
        # Freeze everything first; adapters attached afterwards are the
        # only parameters left with requires_grad=True.
        model.requires_grad_(False)
        self._replace_linears_with_lora(model)

    def _replace_linears_with_lora(self, module: nn.Module) -> None:
        # Snapshot the linears up front so the freshly attached lora_a /
        # lora_b modules are not themselves wrapped.
        linears = [
            child
            for child in list(module.modules())
            if isinstance(child, nn.Linear)
        ]
        for linear in linears:
            apply_lora(linear, self.rank, self.alpha)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        # LoRA requires no bookkeeping after optimizer steps.
        pass
14 changes: 14 additions & 0 deletions torchtitan/models/llama3/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.lora import LoRAConverter
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.metrics import MetricsProcessor
from torchtitan.components.optimizer import (
Expand Down Expand Up @@ -108,6 +109,19 @@ def llama3_debugmodel_float8_emulate() -> Trainer.Config:
return config


def llama3_debugmodel_lora() -> Trainer.Config:
    """Debug-model trainer config with a LoRA model converter enabled.

    Converter ordering matters (LoRA must follow any quantization
    converters); the container config validates this on construction.
    """
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[
            LoRAConverter.Config(
                rank=8,
                alpha=16.0,
            ),
        ],
    )
    return config


def llama3_8b() -> Trainer.Config:
return Trainer.Config(
hf_assets_path="./assets/hf/Llama-3.1-8B",
Expand Down
26 changes: 25 additions & 1 deletion torchtitan/protocols/model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ class Config(Configurable.Config):
print_after_conversion: bool = False
"""If true, model definition will be printed after converters are applied."""

def __post_init__(self):
    # Run the ordering check first, then the quantization-consistency
    # check, preserving which error surfaces for doubly-invalid configs.
    # NOTE(review): this cannot catch a wrong order introduced by mutating
    # `converters` after construction.
    for validate in (_validate_converter_ordering, _validate_quantization):
        validate(self.converters)

def __init__(
self,
config: Config,
*,
parallel_dims: ParallelDims,
model_compile_enabled: bool,
):
_validate_quantization(config.converters)
self.converters: list[ModelConverter] = [
cc.build(
parallel_dims=parallel_dims,
Expand All @@ -81,6 +84,27 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
mh.post_optimizer_hook(model)


def _validate_converter_ordering(converters: list[Configurable.Config]):
    """Validates that converters are in the correct order.

    LoRA must come after quantization because quantization replaces nn.Linear
    with specialized subclasses (e.g. Float8Linear), and LoRA dynamically
    inherits from whatever linear class it wraps.

    Raises:
        ValueError: if a quantization converter appears after a LoRA one.
    """
    # Imported locally to avoid a circular import with the lora module.
    from torchtitan.components.lora import LoRAConverter

    lora_seen = False
    for cfg in converters:
        if isinstance(cfg, LoRAConverter.Config):
            lora_seen = True
            continue
        if lora_seen and isinstance(cfg, QuantizationConverter.Config):
            raise ValueError(
                "LoRA converter must come after quantization converters. "
                "Quantization replaces nn.Linear with specialized subclasses, "
                "and LoRA must wrap the final linear class."
            )


def _validate_quantization(converters: list[Configurable.Config]):
"""Validates that all quantization converters use the same quantization type.

Expand Down
Loading