Skip to content

Commit 559ebe1

Browse files
committed
support lora moe
ghstack-source-id: 726558b Pull Request resolved: #2569
1 parent fa02102 commit 559ebe1

File tree

3 files changed

+264
-2
lines changed

3 files changed

+264
-2
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,85 @@ def test_lora_key_remap_roundtrip():
205205
assert torch.equal(rt_sd[k], tt_sd[k])
206206

207207

208+
def test_lora_moe_freeze_and_trainability():
    """LoRA on MoE model: router frozen, expert LoRA adapters trainable, base weights frozen."""
    from torchtitan.models.common.moe.moe import GroupedExperts, TokenChoiceTopKRouter

    # Minimal MoE-like model: a router + grouped experts + one dense linear.
    n_experts, model_dim, ffn_dim = 4, 64, 128

    class TinyMoE(nn.Module):
        def __init__(self):
            super().__init__()
            self.router = TokenChoiceTopKRouter(
                dim=model_dim,
                num_experts=n_experts,
                num_expert_groups=None,
                num_limited_groups=None,
                top_k=2,
                score_func="softmax",
                route_norm=False,
                route_scale=1.0,
                gate_bias=False,
            )
            self.experts = GroupedExperts(
                dim=model_dim,
                hidden_dim=ffn_dim,
                num_experts=n_experts,
                use_grouped_mm=False,
            )
            self.output = nn.Linear(model_dim, model_dim)

        def forward(self, x):
            # Only LoRA parameter existence is under test — no full MoE forward.
            return self.output(x)

    model = TinyMoE()
    LoRAConverter(LoRAConverter.Config(rank=4, alpha=8.0)).convert(model)

    # Router gate must be skipped by LoRA and stay entirely frozen.
    assert not hasattr(model.router.gate, "lora_a"), "Router gate should not have LoRA"
    for param in model.router.parameters():
        assert not param.requires_grad, "Router params should be frozen"

    # The plain dense linear picks up standard LoRA adapters.
    assert hasattr(model.output, "lora_a")
    assert hasattr(model.output, "lora_b")

    # GroupedExperts gets per-expert adapters for each of w1/w2/w3.
    for suffix in ("w1", "w2", "w3"):
        assert hasattr(model.experts, f"lora_a_{suffix}")
        assert hasattr(model.experts, f"lora_b_{suffix}")

    # Trainability: exactly the LoRA params train; everything else is frozen.
    lora_names: list[str] = []
    base_names: list[str] = []
    for name, param in model.named_parameters():
        if "lora_a" in name or "lora_b" in name:
            lora_names.append(name)
            assert param.requires_grad, f"LoRA param '{name}' should be trainable"
        else:
            base_names.append(name)
            assert not param.requires_grad, f"Base param '{name}' should be frozen"

    assert len(lora_names) > 0, "No LoRA params found"
    assert len(base_names) > 0, "No base params found"

    # Expert LoRA shapes: A is (num_experts, in, rank), B is (num_experts, rank, out).
    expected_shapes = {
        "lora_a_w1": (n_experts, model_dim, 4),
        "lora_b_w1": (n_experts, 4, ffn_dim),
        "lora_a_w2": (n_experts, ffn_dim, 4),
        "lora_b_w2": (n_experts, 4, model_dim),
        "lora_a_w3": (n_experts, model_dim, 4),
        "lora_b_w3": (n_experts, 4, ffn_dim),
    }
    for attr, shape in expected_shapes.items():
        assert getattr(model.experts, attr).shape == shape
285+
286+
208287
def test_qat_preserves_weight_dtype():
209288
"""QAT converter should not change weight dtype (fake quantization happens in forward)."""
210289
pytest.importorskip("torchao")

torchtitan/components/lora.py

Lines changed: 171 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414

1515
from torchtitan.config import Configurable
1616
from torchtitan.models.common.linear import Linear
17+
from torchtitan.models.common.moe.moe import GroupedExperts, TokenChoiceTopKRouter
1718
from torchtitan.tools.logging import logger
1819

1920
# Cache for dynamically created LoRA classes
2021
_lora_class_cache: dict[type, type] = {}
2122

23+
# Cache for dynamically created expert LoRA classes
24+
_expert_lora_class_cache: dict[type, type] = {}
25+
2226

2327
def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear:
2428
parent_cls = type(linear)
@@ -79,8 +83,164 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
7983
return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha)
8084

8185

86+
def _compute_expert_lora_delta(
87+
lora_a: torch.Tensor,
88+
lora_b: torch.Tensor,
89+
scaling: float,
90+
target_weight: nn.Parameter,
91+
) -> torch.Tensor:
92+
"""Compute the LoRA weight delta for expert weights.
93+
94+
Args:
95+
lora_a: (E, in, r) — projects input dim to rank.
96+
lora_b: (E, r, out) — projects rank to output dim.
97+
scaling: alpha / rank.
98+
target_weight: The base weight parameter to match DTensor placements.
99+
100+
Returns:
101+
delta matching target_weight's shape and placements.
102+
Math: delta = scaling * B^T @ A^T → shape (E, out, in).
103+
"""
104+
from torch.distributed.tensor import distribute_tensor, DTensor
105+
106+
delta = scaling * torch.bmm(lora_b.transpose(-2, -1), lora_a.transpose(-2, -1))
107+
# When the base weight is a DTensor (TP/EP sharded), distribute the delta
108+
# to match its placements so the in-place add_/sub_ operates on matching shapes.
109+
if isinstance(target_weight, DTensor) and not isinstance(delta, DTensor):
110+
delta = distribute_tensor(
111+
delta, target_weight.device_mesh, target_weight.placements
112+
)
113+
return delta
114+
115+
116+
def apply_expert_lora(
    experts: GroupedExperts, rank: int, alpha: float
) -> GroupedExperts:
    """Apply LoRA adapters to a GroupedExperts module via class swapping.

    LoRA parameters are registered as direct parameters on the module. EP partition
    functions that use ``named_parameters(recurse=False)`` with ``Shard(0)`` will
    correctly shard them on the expert dimension. TP/ETP partition functions only
    touch w1/w2/w3 by name and leave LoRA parameters unsharded.

    Forward uses merge-per-forward: LoRA deltas are merged into base weights before
    calling the base forward, then unmerged after. This reuses the base
    GroupedExperts.forward without duplicating its DTensor/EP/padding logic.

    Args:
        experts: The GroupedExperts module to adapt in place (its class is swapped).
        rank: LoRA rank ``r``.
        alpha: LoRA alpha; the effective scaling is ``alpha / rank``.

    Returns:
        The same module instance, now an instance of a dynamically created
        ``LoRA<ParentCls>`` subclass cached per parent class.
    """
    parent_cls = type(experts)
    assert issubclass(
        parent_cls, GroupedExperts
    ), f"parent_cls must be a subclass of GroupedExperts, got {parent_cls}"

    if parent_cls not in _expert_lora_class_cache:

        class LoRAGroupedExperts(parent_cls):  # type: ignore[valid-type, misc]
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                raise RuntimeError(
                    "LoRAGroupedExperts should not be instantiated directly."
                )

            @classmethod
            def from_experts(
                cls, experts: GroupedExperts, rank: int, alpha: float
            ) -> "LoRAGroupedExperts":
                # Swap the class in place so existing references to the module
                # (optimizer groups, parent modules) keep working.
                experts.__class__ = cls
                experts._init_expert_lora(rank, alpha)  # type: ignore[attr-defined]
                return experts  # type: ignore[return-value]

            def _init_expert_lora(self, rank: int, alpha: float) -> None:
                """Register per-expert LoRA A/B parameters for w1, w2 and w3.

                Base expert weights have shape (E, out, in); each gets
                A: (E, in, r) and B: (E, r, out). Assigning nn.Parameter
                attributes registers them via nn.Module.__setattr__.
                """
                self._lora_scaling = alpha / rank
                num_experts = self.num_experts
                device = self.w1.device
                dtype = self.w1.dtype
                for w_name in ("w1", "w2", "w3"):
                    base = getattr(self, w_name)
                    # base: (E, out, in) — e.g. w1 is (E, hidden_dim, dim).
                    in_dim = base.shape[2]
                    out_dim = base.shape[1]
                    setattr(
                        self,
                        f"lora_a_{w_name}",
                        nn.Parameter(
                            torch.empty(
                                num_experts, in_dim, rank, device=device, dtype=dtype
                            )
                        ),
                    )
                    setattr(
                        self,
                        f"lora_b_{w_name}",
                        nn.Parameter(
                            torch.empty(
                                num_experts, rank, out_dim, device=device, dtype=dtype
                            )
                        ),
                    )

            def init_weights(self, init_std: float) -> None:
                super().init_weights(init_std)
                # Standard LoRA init: A random, B zero, so the initial delta
                # B^T @ A^T is exactly zero and training starts at the base model.
                for name in ("lora_a_w1", "lora_a_w2", "lora_a_w3"):
                    nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5))
                for name in ("lora_b_w1", "lora_b_w2", "lora_b_w3"):
                    nn.init.zeros_(getattr(self, name))

            def forward(
                self,
                x: torch.Tensor,
                num_tokens_per_expert: torch.Tensor,
            ) -> torch.Tensor:
                """Merge LoRA deltas into base weights, run base forward, unmerge.

                This reuses all base GroupedExperts logic (DTensor, EP, padding).
                The merge loop runs INSIDE the try so that a failure partway
                through merging (or inside the base forward) still rolls back
                every delta already applied — otherwise base weights would be
                permanently perturbed.
                """
                deltas: dict[str, torch.Tensor] = {}
                try:
                    for w_name in ("w1", "w2", "w3"):
                        w = getattr(self, w_name)
                        delta = _compute_expert_lora_delta(
                            getattr(self, f"lora_a_{w_name}"),
                            getattr(self, f"lora_b_{w_name}"),
                            self._lora_scaling,
                            w,
                        )
                        # Record only after the in-place add succeeds, so the
                        # finally block subtracts exactly the applied deltas.
                        w.data.add_(delta)
                        deltas[w_name] = delta

                    return super().forward(x, num_tokens_per_expert)
                finally:
                    # Unmerge: subtract deltas to restore original weights.
                    for w_name, delta in deltas.items():
                        getattr(self, w_name).data.sub_(delta)

        LoRAGroupedExperts.__name__ = f"LoRA{parent_cls.__name__}"
        LoRAGroupedExperts.__qualname__ = f"LoRA{parent_cls.__name__}"
        _expert_lora_class_cache[parent_cls] = LoRAGroupedExperts

    # pyrefly: ignore [missing-attribute]
    return _expert_lora_class_cache[parent_cls].from_experts(experts, rank, alpha)
240+
241+
82242
class LoRAConverter(Configurable):
83-
"""Apply LoRA adapters to all Linear layers in a model."""
243+
"""Apply LoRA adapters to all Linear layers and GroupedExperts in a model."""
84244

85245
@dataclass(kw_only=True, slots=True)
86246
class Config(Configurable.Config):
@@ -125,9 +285,18 @@ def convert(self, model: nn.Module) -> None:
125285
}
126286

127287
def _replace_linears_with_lora(self, module: nn.Module) -> None:
    """Swap eligible Linear and GroupedExperts submodules to their LoRA variants."""
    # Router gate linears are deliberately skipped — routing scores must
    # stay frozen to preserve expert load balancing.
    skip_gate_ids: set[int] = {
        id(m.gate) for m in module.modules() if isinstance(m, TokenChoiceTopKRouter)
    }

    # Snapshot the module list: apply_* swaps classes in place while we iterate.
    for submodule in list(module.modules()):
        if isinstance(submodule, nn.Linear) and id(submodule) not in skip_gate_ids:
            apply_lora(submodule, self.rank, self.alpha)
        elif isinstance(submodule, GroupedExperts):
            apply_expert_lora(submodule, self.rank, self.alpha)
131300

132301
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
133302
pass

torchtitan/models/deepseek_v3/config_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from torchtitan.components.checkpoint import CheckpointManager
8+
from torchtitan.components.lora import LoRAConverter
89
from torchtitan.components.lr_scheduler import LRSchedulersContainer
910
from torchtitan.components.metrics import MetricsProcessor
1011
from torchtitan.components.optimizer import OptimizersContainer
@@ -58,6 +59,19 @@ def deepseek_v3_debugmodel() -> Trainer.Config:
5859
)
5960

6061

62+
def deepseek_v3_debugmodel_lora() -> Trainer.Config:
    """Debug-model trainer config with a LoRA converter (rank 8, alpha 16) enabled."""
    lora_config = LoRAConverter.Config(rank=8, alpha=16.0)
    config = deepseek_v3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[lora_config],
    )
    return config
73+
74+
6175
def deepseek_v3_debugmodel_flex_attn() -> Trainer.Config:
6276
config = deepseek_v3_debugmodel()
6377
config.model_spec = model_registry("debugmodel_flex_attn")

0 commit comments

Comments
 (0)