Skip to content

Commit bbaeed5

Browse files
committed
add lora adapters
ghstack-source-id: 61dae2b Pull Request resolved: #2484
1 parent 0691f51 commit bbaeed5

File tree

4 files changed

+253
-1
lines changed

4 files changed

+253
-1
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
from collections import OrderedDict
7+
68
import pytest
9+
import torch
10+
import torch.nn as nn
711

12+
from torchtitan.components.lora import LoRAConverter
813
from torchtitan.components.quantization.float8 import Float8LinearConverter
914
from torchtitan.config import ConfigManager
1015
from torchtitan.distributed import ParallelDims
@@ -66,3 +71,105 @@ def test_build_model_converters_float8_converter():
6671
assert isinstance(model_converters, ModelConvertersContainer)
6772
assert len(model_converters.converters) == 1
6873
assert isinstance(model_converters.converters[0], Float8LinearConverter)
74+
75+
76+
def test_lora_before_quantization_raises():
    """A converter list that puts LoRA before a quantization converter
    must be rejected when the container config is constructed."""
    bad_ordering = [
        LoRAConverter.Config(rank=8, alpha=16.0),
        Float8LinearConverter.Config(emulate=True),
    ]
    with pytest.raises(ValueError, match="LoRA converter must come after"):
        ModelConvertersContainer.Config(converters=bad_ordering)
85+
86+
87+
def test_lora_freeze_and_trainability():
    """After convert(): base params are frozen, LoRA adapters exist and train."""
    layers = OrderedDict()
    layers["fc1"] = nn.Linear(64, 64)
    layers["relu"] = nn.ReLU()
    layers["fc2"] = nn.Linear(64, 64)
    model = nn.Sequential(layers)

    converter = LoRAConverter(LoRAConverter.Config(rank=4, alpha=8.0))
    converter.convert(model)

    # Every linear layer must have received the low-rank adapter pair.
    for linear in (model.fc1, model.fc2):
        assert hasattr(linear, "lora_a")
        assert hasattr(linear, "lora_b")

    # Walk all parameters: adapters trainable, everything else frozen.
    lora_param_names: list[str] = []
    base_param_names: list[str] = []
    for name, param in model.named_parameters():
        if "lora_a" in name or "lora_b" in name:
            lora_param_names.append(name)
            assert param.requires_grad, f"LoRA param '{name}' should be trainable"
        else:
            base_param_names.append(name)
            assert not param.requires_grad, f"Base param '{name}' should be frozen"

    assert len(lora_param_names) > 0, "No LoRA params found"
    assert len(base_param_names) > 0, "No base params found"
120+
121+
122+
def test_lora_trains_base_frozen():
    """Run a few optimizer steps: LoRA params must move, frozen base must not."""
    torch.manual_seed(42)
    model = nn.Sequential(
        OrderedDict(
            fc1=nn.Linear(64, 64),
            relu=nn.ReLU(),
            fc2=nn.Linear(64, 64),
        )
    )
    converter = LoRAConverter(LoRAConverter.Config(rank=4, alpha=8.0))
    converter.convert(model)

    def _is_lora(name: str) -> bool:
        # Adapter parameters are the only trainable ones after convert().
        return "lora_a" in name or "lora_b" in name

    # Snapshot every parameter before training, split into base vs. adapter.
    base_before: dict[str, torch.Tensor] = {}
    lora_before: dict[str, torch.Tensor] = {}
    for name, param in model.named_parameters():
        target = lora_before if _is_lora(name) else base_before
        target[name] = param.data.clone()

    # The optimizer only ever sees the trainable (LoRA) parameters.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable, lr=0.1)

    # Five SGD steps on random batches.
    for _ in range(5):
        batch = torch.randn(4, 64)
        model(batch).sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    # Frozen base weights must be bit-identical to the snapshot.
    for name, param in model.named_parameters():
        if name in base_before:
            assert torch.equal(
                param.data, base_before[name]
            ), f"Base param '{name}' changed during training"

    # At least one adapter weight must have moved.
    any_lora_changed = any(
        not torch.equal(param.data, lora_before[name])
        for name, param in model.named_parameters()
        if name in lora_before
    )
    assert any_lora_changed, "No LoRA param changed after 5 training steps"

torchtitan/components/lora.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
from dataclasses import dataclass
9+
from typing import Any
10+
11+
import torch
12+
import torch.nn as nn
13+
14+
from torchtitan.config import Configurable
15+
from torchtitan.models.common.linear import Linear
16+
from torchtitan.tools.logging import logger
17+
18+
# Cache for dynamically created LoRA classes
19+
_lora_class_cache: dict[type, type] = {}
20+
21+
22+
def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear:
    """Mutate ``linear`` in place into a LoRA-augmented layer and return it.

    A subclass of ``type(linear)`` is created (and cached per parent class in
    ``_lora_class_cache``) that adds two low-rank projections: ``lora_a``
    (in_features -> rank) and ``lora_b`` (rank -> out_features). The module's
    ``__class__`` is swapped to that subclass, so existing references to the
    module — and its registered parameters/buffers — stay valid.

    Args:
        linear: Layer to augment; must be ``nn.Linear`` or a subclass
            (e.g. a quantized linear produced by an earlier converter).
        rank: Rank of the low-rank adapter matrices.
        alpha: Scaling factor; adapter output is scaled by ``alpha / rank``.

    Returns:
        The same module instance, now an instance of the dynamic LoRA class.
    """
    parent_cls = type(linear)
    assert issubclass(
        parent_cls, nn.Linear
    ), f"parent_cls must be a subclass of nn.Linear, got {parent_cls}"

    # Build the dynamic subclass only once per concrete linear class.
    if parent_cls not in _lora_class_cache:

        class LoRALinear(parent_cls):  # type: ignore[valid-type, misc]
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Instances are only ever produced by from_linear(), which
                # rebinds an existing module's __class__.
                raise RuntimeError("LoRALinear should not be instantiated directly.")

            @classmethod
            def from_linear(
                cls, linear: nn.Linear, rank: int, alpha: float
            ) -> "LoRALinear":
                # Swap the class in place rather than constructing a new
                # module, so parameter identity and device placement survive.
                linear.__class__ = cls
                linear._init_lora(rank, alpha)  # type: ignore[attr-defined]
                return linear  # type: ignore[return-value]

            def _init_lora(
                self,
                rank: int,
                alpha: float,
                device: torch.device | None = None,
                dtype: torch.dtype | None = None,
            ) -> None:
                # Standard LoRA scaling: the adapter branch contributes
                # alpha/rank of its raw output.
                self._lora_scaling = alpha / rank
                # Default to the wrapped weight's placement so the adapters
                # live alongside the base parameters.
                device = device if device is not None else self.weight.device
                dtype = dtype if dtype is not None else self.weight.dtype
                self.lora_a = (
                    Linear.Config(bias=False)
                    .build(in_features=self.in_features, out_features=rank)
                    .to(device=device, dtype=dtype)
                )
                self.lora_b = (
                    Linear.Config(bias=False)
                    .build(in_features=rank, out_features=self.out_features)
                    .to(device=device, dtype=dtype)
                )

            def init_weights(self, **kwargs) -> None:
                # NOTE(review): assumes parent_cls defines init_weights (true
                # for torchtitan's Linear, not for plain nn.Linear) — confirm.
                super().init_weights(**kwargs)  # pyrefly: ignore [not-callable]
                # Standard LoRA init: lora_b starts at zero so the adapter is
                # a no-op until trained; lora_a gets kaiming-uniform.
                nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
                nn.init.zeros_(self.lora_b.weight)

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                base_out = super().forward(input)
                lora_out = self.lora_b(self.lora_a(input))
                return base_out + self._lora_scaling * lora_out

        # Give the generated class a readable name for repr/debugging.
        LoRALinear.__name__ = f"LoRA{parent_cls.__name__}"
        LoRALinear.__qualname__ = f"LoRA{parent_cls.__name__}"
        _lora_class_cache[parent_cls] = LoRALinear

    # pyrefly: ignore [missing-attribute]
    return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha)
79+
80+
81+
class LoRAConverter(Configurable):
    """Model converter that attaches LoRA adapters to every Linear layer
    and freezes all pre-existing parameters."""

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        rank: int = 8
        """Rank of the LoRA matrices (lora_a: in_features x rank, lora_b: rank x out_features)."""

        alpha: float = 16.0
        """Scaling factor. Output is scaled by alpha/rank."""

    def __init__(self, config: Config, **kwargs):
        self.rank = config.rank
        self.alpha = config.alpha
        logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")

    def convert(self, model: nn.Module) -> None:
        # Freeze the whole model first; the adapter parameters created below
        # are new, so they default to requires_grad=True and stay trainable.
        model.requires_grad_(False)
        self._replace_linears_with_lora(model)

    def _replace_linears_with_lora(self, module: nn.Module) -> None:
        # Snapshot the module tree up front so the lora_a/lora_b linears we
        # insert are not themselves revisited and wrapped a second time.
        targets = [m for _, m in list(module.named_modules()) if isinstance(m, nn.Linear)]
        for linear in targets:
            apply_lora(linear, self.rank, self.alpha)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        # LoRA requires no per-step maintenance after the optimizer runs.
        pass

torchtitan/models/llama3/config_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from torchtitan.components.checkpoint import CheckpointManager
8+
from torchtitan.components.lora import LoRAConverter
89
from torchtitan.components.lr_scheduler import LRSchedulersContainer
910
from torchtitan.components.metrics import MetricsProcessor
1011
from torchtitan.components.optimizer import (
@@ -108,6 +109,19 @@ def llama3_debugmodel_float8_emulate() -> Trainer.Config:
108109
return config
109110

110111

112+
def llama3_debugmodel_lora() -> Trainer.Config:
    """Debug-model recipe with a single LoRA model converter enabled."""
    config = llama3_debugmodel()
    lora_config = LoRAConverter.Config(rank=8, alpha=16.0)
    config.model_converters = ModelConvertersContainer.Config(
        converters=[lora_config],
    )
    return config
123+
124+
111125
def llama3_8b() -> Trainer.Config:
112126
return Trainer.Config(
113127
hf_assets_path="./assets/hf/Llama-3.1-8B",

torchtitan/protocols/model_converter.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,17 @@ class Config(Configurable.Config):
5353
print_after_conversion: bool = False
5454
"""If true, model definition will be printed after converters are applied."""
5555

56+
def __post_init__(self):
    # Validate at config-construction time so misordered or mixed converter
    # lists fail fast, before any model or converter is built.
    _validate_converter_ordering(self.converters)
    _validate_quantization(self.converters)
59+
5660
def __init__(
5761
self,
5862
config: Config,
5963
*,
6064
parallel_dims: ParallelDims,
6165
model_compile_enabled: bool,
6266
):
63-
_validate_quantization(config.converters)
6467
self.converters: list[ModelConverter] = [
6568
cc.build(
6669
parallel_dims=parallel_dims,
@@ -81,6 +84,27 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
8184
mh.post_optimizer_hook(model)
8285

8386

87+
def _validate_converter_ordering(converters: list[Configurable.Config]):
    """Reject converter lists where a quantization converter follows LoRA.

    Quantization swaps nn.Linear for specialized subclasses (e.g.
    Float8Linear), and LoRA dynamically subclasses whatever linear class it
    wraps — so LoRA must always run last.
    """
    # Imported lazily to avoid a circular import with components.lora.
    from torchtitan.components.lora import LoRAConverter

    lora_seen = False
    for cfg in converters:
        if isinstance(cfg, LoRAConverter.Config):
            lora_seen = True
            continue
        if lora_seen and isinstance(cfg, QuantizationConverter.Config):
            raise ValueError(
                "LoRA converter must come after quantization converters. "
                "Quantization replaces nn.Linear with specialized subclasses, "
                "and LoRA must wrap the final linear class."
            )
106+
107+
84108
def _validate_quantization(converters: list[Configurable.Config]):
85109
"""Validates that all quantization converters use the same quantization type.
86110

0 commit comments

Comments
 (0)