|
14 | 14 |
|
15 | 15 | from torchtitan.config import Configurable |
16 | 16 | from torchtitan.models.common.linear import Linear |
| 17 | +from torchtitan.models.common.moe.moe import GroupedExperts, TokenChoiceTopKRouter |
17 | 18 | from torchtitan.tools.logging import logger |
18 | 19 |
|
19 | 20 | # Cache for dynamically created LoRA classes |
20 | 21 | _lora_class_cache: dict[type, type] = {} |
21 | 22 |
|
| 23 | +# Cache for dynamically created expert LoRA classes |
| 24 | +_expert_lora_class_cache: dict[type, type] = {} |
| 25 | + |
22 | 26 |
|
23 | 27 | def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear: |
24 | 28 | parent_cls = type(linear) |
@@ -79,8 +83,164 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: |
79 | 83 | return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha) |
80 | 84 |
|
81 | 85 |
|
| 86 | +def _compute_expert_lora_delta( |
| 87 | + lora_a: torch.Tensor, |
| 88 | + lora_b: torch.Tensor, |
| 89 | + scaling: float, |
| 90 | + target_weight: nn.Parameter, |
| 91 | +) -> torch.Tensor: |
| 92 | + """Compute the LoRA weight delta for expert weights. |
| 93 | +
|
| 94 | + Args: |
| 95 | + lora_a: (E, in, r) — projects input dim to rank. |
| 96 | + lora_b: (E, r, out) — projects rank to output dim. |
| 97 | + scaling: alpha / rank. |
| 98 | + target_weight: The base weight parameter to match DTensor placements. |
| 99 | +
|
| 100 | + Returns: |
| 101 | + delta matching target_weight's shape and placements. |
| 102 | + Math: delta = scaling * B^T @ A^T → shape (E, out, in). |
| 103 | + """ |
| 104 | + from torch.distributed.tensor import distribute_tensor, DTensor |
| 105 | + |
| 106 | + delta = scaling * torch.bmm(lora_b.transpose(-2, -1), lora_a.transpose(-2, -1)) |
| 107 | + # When the base weight is a DTensor (TP/EP sharded), distribute the delta |
| 108 | + # to match its placements so the in-place add_/sub_ operates on matching shapes. |
| 109 | + if isinstance(target_weight, DTensor) and not isinstance(delta, DTensor): |
| 110 | + delta = distribute_tensor( |
| 111 | + delta, target_weight.device_mesh, target_weight.placements |
| 112 | + ) |
| 113 | + return delta |
| 114 | + |
| 115 | + |
| 116 | +def apply_expert_lora( |
| 117 | + experts: GroupedExperts, rank: int, alpha: float |
| 118 | +) -> GroupedExperts: |
| 119 | + """Apply LoRA adapters to a GroupedExperts module via class swapping. |
| 120 | +
|
| 121 | + LoRA parameters are registered as direct parameters on the module. EP partition |
| 122 | + functions that use ``named_parameters(recurse=False)`` with ``Shard(0)`` will |
| 123 | + correctly shard them on the expert dimension. TP/ETP partition functions only |
| 124 | + touch w1/w2/w3 by name and leave LoRA parameters unsharded. |
| 125 | +
|
| 126 | + Forward uses merge-per-forward: LoRA deltas are merged into base weights before |
| 127 | + calling the base forward, then unmerged after. This reuses the base |
| 128 | + GroupedExperts.forward without duplicating its DTensor/EP/padding logic. |
| 129 | + """ |
| 130 | + parent_cls = type(experts) |
| 131 | + assert issubclass( |
| 132 | + parent_cls, GroupedExperts |
| 133 | + ), f"parent_cls must be a subclass of GroupedExperts, got {parent_cls}" |
| 134 | + |
| 135 | + if parent_cls not in _expert_lora_class_cache: |
| 136 | + |
| 137 | + class LoRAGroupedExperts(parent_cls): # type: ignore[valid-type, misc] |
| 138 | + def __init__(self, *args: Any, **kwargs: Any) -> None: |
| 139 | + raise RuntimeError( |
| 140 | + "LoRAGroupedExperts should not be instantiated directly." |
| 141 | + ) |
| 142 | + |
| 143 | + @classmethod |
| 144 | + def from_experts( |
| 145 | + cls, experts: GroupedExperts, rank: int, alpha: float |
| 146 | + ) -> "LoRAGroupedExperts": |
| 147 | + experts.__class__ = cls |
| 148 | + experts._init_expert_lora(rank, alpha) # type: ignore[attr-defined] |
| 149 | + return experts # type: ignore[return-value] |
| 150 | + |
| 151 | + def _init_expert_lora(self, rank: int, alpha: float) -> None: |
| 152 | + self._lora_scaling = alpha / rank |
| 153 | + num_experts = self.num_experts |
| 154 | + # w1: (E, hidden_dim, dim) -> A1: (E, dim, r), B1: (E, r, hidden_dim) |
| 155 | + dim_w1_in = self.w1.shape[2] # dim |
| 156 | + dim_w1_out = self.w1.shape[1] # hidden_dim |
| 157 | + # w2: (E, dim, hidden_dim) -> A2: (E, hidden_dim, r), B2: (E, r, dim) |
| 158 | + dim_w2_in = self.w2.shape[2] # hidden_dim |
| 159 | + dim_w2_out = self.w2.shape[1] # dim |
| 160 | + # w3: (E, hidden_dim, dim) -> A3: (E, dim, r), B3: (E, r, hidden_dim) |
| 161 | + dim_w3_in = self.w3.shape[2] # dim |
| 162 | + dim_w3_out = self.w3.shape[1] # hidden_dim |
| 163 | + |
| 164 | + device = self.w1.device |
| 165 | + dtype = self.w1.dtype |
| 166 | + |
| 167 | + self.lora_a_w1 = nn.Parameter( |
| 168 | + torch.empty( |
| 169 | + num_experts, dim_w1_in, rank, device=device, dtype=dtype |
| 170 | + ) |
| 171 | + ) |
| 172 | + self.lora_b_w1 = nn.Parameter( |
| 173 | + torch.empty( |
| 174 | + num_experts, rank, dim_w1_out, device=device, dtype=dtype |
| 175 | + ) |
| 176 | + ) |
| 177 | + self.lora_a_w2 = nn.Parameter( |
| 178 | + torch.empty( |
| 179 | + num_experts, dim_w2_in, rank, device=device, dtype=dtype |
| 180 | + ) |
| 181 | + ) |
| 182 | + self.lora_b_w2 = nn.Parameter( |
| 183 | + torch.empty( |
| 184 | + num_experts, rank, dim_w2_out, device=device, dtype=dtype |
| 185 | + ) |
| 186 | + ) |
| 187 | + self.lora_a_w3 = nn.Parameter( |
| 188 | + torch.empty( |
| 189 | + num_experts, dim_w3_in, rank, device=device, dtype=dtype |
| 190 | + ) |
| 191 | + ) |
| 192 | + self.lora_b_w3 = nn.Parameter( |
| 193 | + torch.empty( |
| 194 | + num_experts, rank, dim_w3_out, device=device, dtype=dtype |
| 195 | + ) |
| 196 | + ) |
| 197 | + |
| 198 | + def init_weights(self, init_std: float) -> None: |
| 199 | + super().init_weights(init_std) |
| 200 | + for name in ("lora_a_w1", "lora_a_w2", "lora_a_w3"): |
| 201 | + nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5)) |
| 202 | + for name in ("lora_b_w1", "lora_b_w2", "lora_b_w3"): |
| 203 | + nn.init.zeros_(getattr(self, name)) |
| 204 | + |
| 205 | + def forward( |
| 206 | + self, |
| 207 | + x: torch.Tensor, |
| 208 | + num_tokens_per_expert: torch.Tensor, |
| 209 | + ) -> torch.Tensor: |
| 210 | + # Merge LoRA deltas into base weights, run base forward, unmerge. |
| 211 | + # This reuses all base GroupedExperts logic (DTensor, EP, padding). |
| 212 | + deltas = {} |
| 213 | + for w_name, a_name, b_name in ( |
| 214 | + ("w1", "lora_a_w1", "lora_b_w1"), |
| 215 | + ("w2", "lora_a_w2", "lora_b_w2"), |
| 216 | + ("w3", "lora_a_w3", "lora_b_w3"), |
| 217 | + ): |
| 218 | + lora_a = getattr(self, a_name) |
| 219 | + lora_b = getattr(self, b_name) |
| 220 | + w = getattr(self, w_name) |
| 221 | + delta = _compute_expert_lora_delta( |
| 222 | + lora_a, lora_b, self._lora_scaling, w |
| 223 | + ) |
| 224 | + w.data.add_(delta) |
| 225 | + deltas[w_name] = delta |
| 226 | + |
| 227 | + try: |
| 228 | + return super().forward(x, num_tokens_per_expert) |
| 229 | + finally: |
| 230 | + # Unmerge: subtract deltas to restore original weights |
| 231 | + for w_name, delta in deltas.items(): |
| 232 | + getattr(self, w_name).data.sub_(delta) |
| 233 | + |
| 234 | + LoRAGroupedExperts.__name__ = f"LoRA{parent_cls.__name__}" |
| 235 | + LoRAGroupedExperts.__qualname__ = f"LoRA{parent_cls.__name__}" |
| 236 | + _expert_lora_class_cache[parent_cls] = LoRAGroupedExperts |
| 237 | + |
| 238 | + # pyrefly: ignore [missing-attribute] |
| 239 | + return _expert_lora_class_cache[parent_cls].from_experts(experts, rank, alpha) |
| 240 | + |
| 241 | + |
82 | 242 | class LoRAConverter(Configurable): |
83 | | - """Apply LoRA adapters to all Linear layers in a model.""" |
| 243 | + """Apply LoRA adapters to all Linear layers and GroupedExperts in a model.""" |
84 | 244 |
|
85 | 245 | @dataclass(kw_only=True, slots=True) |
86 | 246 | class Config(Configurable.Config): |
@@ -125,9 +285,18 @@ def convert(self, model: nn.Module) -> None: |
125 | 285 | } |
126 | 286 |
|
127 | 287 | def _replace_linears_with_lora(self, module: nn.Module) -> None: |
| 288 | + # Collect router gate linears so we can skip them — routing scores |
| 289 | + # must stay frozen to preserve expert load balancing. |
| 290 | + router_gate_ids: set[int] = set() |
| 291 | + for child in module.modules(): |
| 292 | + if isinstance(child, TokenChoiceTopKRouter): |
| 293 | + router_gate_ids.add(id(child.gate)) |
| 294 | + |
128 | 295 | for _, child in list(module.named_modules()): |
129 | | - if isinstance(child, nn.Linear): |
| 296 | + if isinstance(child, nn.Linear) and id(child) not in router_gate_ids: |
130 | 297 | apply_lora(child, self.rank, self.alpha) |
| 298 | + elif isinstance(child, GroupedExperts): |
| 299 | + apply_expert_lora(child, self.rank, self.alpha) |
131 | 300 |
|
132 | 301 | def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None: |
133 | 302 | pass |
|
0 commit comments