Skip to content

Commit 37b61a5

Browse files
committed
add lora adapters
ghstack-source-id: 03b26d3 Pull Request resolved: #2484
1 parent de32920 commit 37b61a5

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed

torchtitan/components/lora.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
from dataclasses import dataclass
9+
from typing import Any
10+
11+
import torch
12+
import torch.nn as nn
13+
14+
from torchtitan.config import Configurable
15+
from torchtitan.tools.logging import logger
16+
17+
# Cache for dynamically created LoRA classes
18+
_lora_class_cache: dict[type, type] = {}
19+
20+
21+
def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear:
22+
parent_cls = type(linear)
23+
assert issubclass(
24+
parent_cls, nn.Linear
25+
), f"parent_cls must be a subclass of nn.Linear, got {parent_cls}"
26+
27+
if parent_cls not in _lora_class_cache:
28+
29+
class LoRALinear(parent_cls): # type: ignore[valid-type, misc]
30+
def __init__(self, *args: Any, **kwargs: Any) -> None:
31+
raise RuntimeError("LoRALinear should not be instantiated directly.")
32+
33+
@classmethod
34+
def from_linear(
35+
cls, linear: nn.Linear, rank: int, alpha: float
36+
) -> "LoRALinear":
37+
linear.__class__ = cls
38+
linear._init_lora(rank, alpha) # type: ignore[attr-defined]
39+
return linear # type: ignore[return-value]
40+
41+
def _init_lora(
42+
self,
43+
rank: int,
44+
alpha: float,
45+
device: torch.device | None = None,
46+
dtype: torch.dtype | None = None,
47+
) -> None:
48+
self._lora_scaling = alpha / rank
49+
device = device if device is not None else self.weight.device
50+
dtype = dtype if dtype is not None else self.weight.dtype
51+
self.lora_a = nn.Linear(
52+
self.in_features,
53+
rank,
54+
bias=False,
55+
device=device,
56+
dtype=dtype,
57+
)
58+
self.lora_b = nn.Linear(
59+
rank,
60+
self.out_features,
61+
bias=False,
62+
device=device,
63+
dtype=dtype,
64+
)
65+
self._init_weight()
66+
67+
def _init_weight(self) -> None:
68+
nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
69+
nn.init.zeros_(self.lora_b.weight)
70+
71+
def forward(self, input: torch.Tensor) -> torch.Tensor:
72+
base_out = super().forward(input)
73+
lora_out = self.lora_b(self.lora_a(input))
74+
return base_out + self._lora_scaling * lora_out
75+
76+
LoRALinear.__name__ = f"LoRA{parent_cls.__name__}"
77+
LoRALinear.__qualname__ = f"LoRA{parent_cls.__name__}"
78+
_lora_class_cache[parent_cls] = LoRALinear
79+
80+
return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha)
81+
82+
class LoRAConverter(Configurable):
    """Apply LoRA adapters to all Linear layers in a model."""

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        rank: int = 8
        """Rank of the LoRA matrices (lora_a: in_features x rank, lora_b: rank x out_features)."""

        alpha: float = 16.0
        """Scaling factor. Output is scaled by alpha/rank."""

    def __init__(self, config: Config, **kwargs):
        # NOTE(review): does not call super().__init__ — confirm Configurable
        # requires no base initialization.
        self.rank = config.rank
        self.alpha = config.alpha
        logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")

    def convert(self, model: nn.Module) -> None:
        """Freeze all base parameters, then wrap every Linear with LoRA.

        Only the adapter weights created by ``apply_lora`` remain trainable
        (they are fresh ``nn.Linear`` modules, created after the freeze).
        """
        model.requires_grad_(False)
        self._replace_linears_with_lora(model)

    def _replace_linears_with_lora(self, module: nn.Module) -> None:
        """Convert Linears in place and patch ``module.init_weights``.

        Raises:
            AssertionError: if a LoRA-converted submodule lost its
                ``_init_weight`` method (should not happen).
        """
        # Snapshot the module list: apply_lora mutates each Linear in place
        # (adds lora_a/lora_b children), so iterating live would be unsafe.
        for _, child in list(module.named_modules()):
            if isinstance(child, nn.Linear):
                apply_lora(child, self.rank, self.alpha)

        # Patch init_weights to also reinitialize LoRA adapters
        original_init_weights = getattr(module, "init_weights", None)

        def new_model_init_weights(*args: Any, **kwargs: Any) -> None:
            # Run the model's own init first (it may clobber nothing LoRA owns,
            # but order matters: adapters are re-zeroed afterwards).
            if original_init_weights is not None and callable(original_init_weights):
                original_init_weights(*args, **kwargs)
            # Re-init every LoRA-converted submodule (identified by the
            # dynamically generated classes in the cache).
            for sub_module in module.modules():
                if type(sub_module) in _lora_class_cache.values():
                    _init_weight = getattr(sub_module, "_init_weight", None)
                    assert _init_weight is not None and callable(_init_weight)
                    _init_weight()

        # object.__setattr__ deliberately bypasses nn.Module.__setattr__ so the
        # function is stored as a plain attribute, not registered as a module.
        object.__setattr__(module, "init_weights", new_model_init_weights)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        """No-op: LoRA needs no post-optimizer maintenance."""
        pass
123+
124+
125+
def find_lora_config(converters: list) -> "LoRAConverter.Config | None":
    """Return the first ``LoRAConverter.Config`` in *converters*, else ``None``."""
    for candidate in converters:
        if isinstance(candidate, LoRAConverter.Config):
            return candidate
    return None

torchtitan/models/llama3/config_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
OptimizersContainer,
1212
OptimizersInBackwardContainer,
1313
)
14+
from torchtitan.components.lora import LoRAConverter
1415
from torchtitan.components.quantization.float8 import Float8LinearConverter
1516
from torchtitan.components.validate import Validator
1617
from torchtitan.config import (
@@ -108,6 +109,19 @@ def llama3_debugmodel_float8_emulate() -> Trainer.Config:
108109
return config
109110

110111

112+
def llama3_debugmodel_lora() -> Trainer.Config:
    """Debug-model trainer config with a single LoRA converter (rank=8, alpha=16)."""
    lora_config = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
    )
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[lora_config],
    )
    return config
123+
124+
111125
def llama3_8b() -> Trainer.Config:
112126
return Trainer.Config(
113127
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments
 (0)