Skip to content

Commit 268f960

Browse files
committed
support lora moe
ghstack-source-id: 87f0bd5 Pull Request resolved: #2569
1 parent 7ca7141 commit 268f960

File tree

3 files changed

+266
-3
lines changed

3 files changed

+266
-3
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,85 @@ def test_qlora_base_weights_quantized_adapters_full_precision():
217217
), f"{name}.lora_b.weight should be float32"
218218

219219

220+
def test_lora_moe_freeze_and_trainability():
    """LoRA on MoE model: router frozen, expert LoRA adapters trainable, base weights frozen."""
    from torchtitan.models.common.moe.moe import GroupedExperts, TokenChoiceTopKRouter

    # Minimal MoE-like model: a router + grouped experts + a dense linear.
    num_experts, dim, hidden_dim = 4, 64, 128

    class SimpleMoEModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.router = TokenChoiceTopKRouter(
                dim=dim,
                num_experts=num_experts,
                num_expert_groups=None,
                num_limited_groups=None,
                top_k=2,
                score_func="softmax",
                route_norm=False,
                route_scale=1.0,
                gate_bias=False,
            )
            self.experts = GroupedExperts(
                dim=dim,
                hidden_dim=hidden_dim,
                num_experts=num_experts,
                use_grouped_mm=False,
            )
            self.output = nn.Linear(dim, dim)

        def forward(self, x):
            # Only LoRA parameter existence is under test — no full MoE forward.
            return self.output(x)

    model = SimpleMoEModel()
    LoRAConverter(LoRAConverter.Config(rank=4, alpha=8.0)).convert(model)

    # Router gate should be frozen (LoRA skips router gates).
    assert not hasattr(model.router.gate, "lora_a"), "Router gate should not have LoRA"
    for param in model.router.parameters():
        assert not param.requires_grad, "Router params should be frozen"

    # Dense linear should have LoRA adapters.
    assert hasattr(model.output, "lora_a")
    assert hasattr(model.output, "lora_b")

    # GroupedExperts should have expert LoRA adapters for each of w1/w2/w3.
    for adapter_name in (
        "lora_a_w1",
        "lora_b_w1",
        "lora_a_w2",
        "lora_b_w2",
        "lora_a_w3",
        "lora_b_w3",
    ):
        assert hasattr(model.experts, adapter_name)

    # Trainability split: LoRA params trainable, base params frozen.
    lora_param_names = []
    base_param_names = []
    for name, param in model.named_parameters():
        if "lora_a" in name or "lora_b" in name:
            assert param.requires_grad, f"LoRA param '{name}' should be trainable"
            lora_param_names.append(name)
        else:
            assert not param.requires_grad, f"Base param '{name}' should be frozen"
            base_param_names.append(name)

    assert len(lora_param_names) > 0, "No LoRA params found"
    assert len(base_param_names) > 0, "No base params found"

    # Expert LoRA shapes: A is (num_experts, in_dim, rank), B is (num_experts, rank, out_dim).
    expected_shapes = {
        "lora_a_w1": (num_experts, dim, 4),
        "lora_b_w1": (num_experts, 4, hidden_dim),
        "lora_a_w2": (num_experts, hidden_dim, 4),
        "lora_b_w2": (num_experts, 4, dim),
        "lora_a_w3": (num_experts, dim, 4),
        "lora_b_w3": (num_experts, 4, hidden_dim),
    }
    for adapter_name, shape in expected_shapes.items():
        assert getattr(model.experts, adapter_name).shape == shape
220299
def test_qat_preserves_weight_dtype():
221300
"""QAT converter should not change weight dtype (fake quantization happens in forward)."""
222301
pytest.importorskip("torchao")

torchtitan/components/lora.py

Lines changed: 173 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,18 @@
1010

1111
import torch
1212
import torch.nn as nn
13+
1314
from torchtitan.config import Configurable
1415
from torchtitan.models.common.linear import Linear
16+
from torchtitan.models.common.moe.moe import GroupedExperts, TokenChoiceTopKRouter
1517
from torchtitan.tools.logging import logger
1618

1719
# Cache for dynamically created LoRA classes
1820
_lora_class_cache: dict[type, type] = {}
1921

22+
# Cache for dynamically created expert LoRA classes
23+
_expert_lora_class_cache: dict[type, type] = {}
24+
2025

2126
def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear:
2227
parent_cls = type(linear)
@@ -77,8 +82,164 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
7782
return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha)
7883

7984

85+
def _compute_expert_lora_delta(
86+
lora_a: torch.Tensor,
87+
lora_b: torch.Tensor,
88+
scaling: float,
89+
target_weight: nn.Parameter,
90+
) -> torch.Tensor:
91+
"""Compute the LoRA weight delta for expert weights.
92+
93+
Args:
94+
lora_a: (E, in, r) — projects input dim to rank.
95+
lora_b: (E, r, out) — projects rank to output dim.
96+
scaling: alpha / rank.
97+
target_weight: The base weight parameter to match DTensor placements.
98+
99+
Returns:
100+
delta matching target_weight's shape and placements.
101+
Math: delta = scaling * B^T @ A^T → shape (E, out, in).
102+
"""
103+
from torch.distributed.tensor import distribute_tensor, DTensor
104+
105+
delta = scaling * torch.bmm(lora_b.transpose(-2, -1), lora_a.transpose(-2, -1))
106+
# When the base weight is a DTensor (TP/EP sharded), distribute the delta
107+
# to match its placements so the in-place add_/sub_ operates on matching shapes.
108+
if isinstance(target_weight, DTensor) and not isinstance(delta, DTensor):
109+
delta = distribute_tensor(
110+
delta, target_weight.device_mesh, target_weight.placements
111+
)
112+
return delta
113+
114+
115+
def apply_expert_lora(
    experts: GroupedExperts, rank: int, alpha: float
) -> GroupedExperts:
    """Apply LoRA adapters to a GroupedExperts module via class swapping.

    LoRA parameters are registered as direct parameters on the module. EP partition
    functions that use ``named_parameters(recurse=False)`` with ``Shard(0)`` will
    correctly shard them on the expert dimension. TP/ETP partition functions only
    touch w1/w2/w3 by name and leave LoRA parameters unsharded.

    Forward uses merge-per-forward: a merged weight ``w + delta`` is computed
    out-of-place and temporarily installed in place of each base parameter while
    the base forward runs, then the original parameters are restored. Keeping the
    merge out-of-place (rather than ``w.data.add_``) keeps the delta in the
    autograd graph so gradients actually reach the LoRA adapters, and restoring
    the original parameter objects avoids add-then-subtract float rounding drift
    in the base weights. This reuses the base GroupedExperts.forward without
    duplicating its DTensor/EP/padding logic.

    Args:
        experts: The GroupedExperts module to adapt (modified in place).
        rank: LoRA rank r.
        alpha: LoRA alpha; scaling = alpha / rank.

    Returns:
        The same module instance, with its class swapped to a LoRA subclass.
    """
    parent_cls = type(experts)
    assert issubclass(
        parent_cls, GroupedExperts
    ), f"parent_cls must be a subclass of GroupedExperts, got {parent_cls}"

    if parent_cls not in _expert_lora_class_cache:

        class LoRAGroupedExperts(parent_cls):  # type: ignore[valid-type, misc]
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                raise RuntimeError(
                    "LoRAGroupedExperts should not be instantiated directly."
                )

            @classmethod
            def from_experts(
                cls, experts: GroupedExperts, rank: int, alpha: float
            ) -> "LoRAGroupedExperts":
                """Swap the instance's class and attach LoRA parameters."""
                experts.__class__ = cls
                experts._init_expert_lora(rank, alpha)  # type: ignore[attr-defined]
                return experts  # type: ignore[return-value]

            def _init_expert_lora(self, rank: int, alpha: float) -> None:
                """Register lora_a_w{1,2,3} / lora_b_w{1,2,3} adapter parameters.

                Expert weights are (E, out, in): A maps in -> rank as (E, in, r),
                B maps rank -> out as (E, r, out).
                """
                # Standard LoRA scaling convention.
                self._lora_scaling = alpha / rank
                num_experts = self.num_experts
                for w_name in ("w1", "w2", "w3"):
                    w = getattr(self, w_name)
                    dim_out, dim_in = w.shape[1], w.shape[2]
                    self.register_parameter(
                        f"lora_a_{w_name}",
                        nn.Parameter(
                            torch.empty(
                                num_experts,
                                dim_in,
                                rank,
                                device=w.device,
                                dtype=w.dtype,
                            )
                        ),
                    )
                    self.register_parameter(
                        f"lora_b_{w_name}",
                        nn.Parameter(
                            torch.empty(
                                num_experts,
                                rank,
                                dim_out,
                                device=w.device,
                                dtype=w.dtype,
                            )
                        ),
                    )

            def init_weights(self, init_std: float) -> None:
                """Init base weights, then A ~ kaiming-uniform and B = 0.

                B = 0 makes the initial LoRA delta exactly zero, so the adapted
                model starts identical to the base model.
                """
                super().init_weights(init_std)
                for name in ("lora_a_w1", "lora_a_w2", "lora_a_w3"):
                    nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5))
                for name in ("lora_b_w1", "lora_b_w2", "lora_b_w3"):
                    nn.init.zeros_(getattr(self, name))

            def forward(
                self,
                x: torch.Tensor,
                num_tokens_per_expert: torch.Tensor,
            ) -> torch.Tensor:
                """Run the base experts forward with LoRA-merged weights.

                The merged weight is computed out-of-place (``w + delta``) so
                autograd records the dependency on lora_a/lora_b — an in-place
                ``w.data.add_`` would detach the delta and the adapters would
                never receive gradients. The merged tensor temporarily shadows
                the parameter; originals are restored in ``finally``.
                """
                saved_params: dict[str, nn.Parameter] = {}
                try:
                    for w_name, a_name, b_name in (
                        ("w1", "lora_a_w1", "lora_b_w1"),
                        ("w2", "lora_a_w2", "lora_b_w2"),
                        ("w3", "lora_a_w3", "lora_b_w3"),
                    ):
                        w = getattr(self, w_name)
                        delta = _compute_expert_lora_delta(
                            getattr(self, a_name),
                            getattr(self, b_name),
                            self._lora_scaling,
                            w,
                        )
                        merged = w + delta  # out-of-place: stays in the graph
                        # Remove the parameter entry, then install the merged
                        # tensor as a plain attribute that shadows it for the
                        # duration of the base forward.
                        saved_params[w_name] = self._parameters.pop(w_name)
                        setattr(self, w_name, merged)
                    return super().forward(x, num_tokens_per_expert)
                finally:
                    # Restore the original parameter objects exactly.
                    for w_name, param in saved_params.items():
                        self.__dict__.pop(w_name, None)
                        self._parameters[w_name] = param

        LoRAGroupedExperts.__name__ = f"LoRA{parent_cls.__name__}"
        LoRAGroupedExperts.__qualname__ = f"LoRA{parent_cls.__name__}"
        _expert_lora_class_cache[parent_cls] = LoRAGroupedExperts

    # pyrefly: ignore [missing-attribute]
    return _expert_lora_class_cache[parent_cls].from_experts(experts, rank, alpha)
239+
240+
80241
class LoRAConverter(Configurable):
81-
"""Apply LoRA adapters to all Linear layers in a model."""
242+
"""Apply LoRA adapters to all Linear layers and GroupedExperts in a model."""
82243

83244
@dataclass(kw_only=True, slots=True)
84245
class Config(Configurable.Config):
@@ -122,15 +283,24 @@ def convert(self, model: nn.Module) -> None:
122283
self._replace_linears_with_lora(model)
123284

124285
def _replace_linears_with_lora(self, module: nn.Module) -> None:
    """Swap eligible Linear / GroupedExperts submodules for LoRA variants."""
    # Router gates must keep their original weights: adapting the routing
    # scores would disturb expert load balancing, so collect their ids
    # up front and skip them below.
    router_gate_ids: set[int] = {
        id(m.gate) for m in module.modules() if isinstance(m, TokenChoiceTopKRouter)
    }

    for _, child in list(module.named_modules()):
        if isinstance(child, GroupedExperts):
            apply_expert_lora(child, self.rank, self.alpha)
        elif isinstance(child, nn.Linear) and id(child) not in router_gate_ids:
            apply_lora(child, self.rank, self.alpha)

    # Expose a key filter and flag on the module so ModelWrapper can
    # partition the state dict without knowing about LoRA internals.
    def converter_key_filter(key: str) -> bool:
        """Return True if key was added by this converter (LoRA adapter weights)."""
        return ".lora_a" in key or ".lora_b" in key

    object.__setattr__(module, "converter_key_filter", converter_key_filter)
    object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)

torchtitan/models/deepseek_v3/config_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from torchtitan.components.checkpoint import CheckpointManager
8+
from torchtitan.components.lora import LoRAConverter
89
from torchtitan.components.lr_scheduler import LRSchedulersContainer
910
from torchtitan.components.metrics import MetricsProcessor
1011
from torchtitan.components.optimizer import OptimizersContainer
@@ -58,6 +59,19 @@ def deepseek_v3_debugmodel() -> Trainer.Config:
5859
)
5960

6061

62+
def deepseek_v3_debugmodel_lora() -> Trainer.Config:
    """Debug DeepSeek-V3 config with a LoRA model converter (rank 8, alpha 16)."""
    config = deepseek_v3_debugmodel()
    lora_config = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
    )
    config.model_converters = ModelConvertersContainer.Config(
        converters=[lora_config],
    )
    return config
73+
74+
6175
def deepseek_v3_debugmodel_flex_attn() -> Trainer.Config:
6276
config = deepseek_v3_debugmodel()
6377
config.model_spec = model_registry("debugmodel_flex_attn")

0 commit comments

Comments
 (0)