@@ -10,13 +10,18 @@
 
 import torch
 import torch.nn as nn
+
 from torchtitan.config import Configurable
 from torchtitan.models.common.linear import Linear
+from torchtitan.models.common.moe.moe import GroupedExperts, TokenChoiceTopKRouter
 from torchtitan.tools.logging import logger
 
 # Cache for dynamically created LoRA classes
 _lora_class_cache: dict[type, type] = {}
 
+# Cache for dynamically created expert LoRA classes
+_expert_lora_class_cache: dict[type, type] = {}
+
 
 def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear:
     parent_cls = type(linear)
@@ -77,8 +82,164 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: |
     return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha)
 
 
+def _compute_expert_lora_delta(
+    lora_a: torch.Tensor,
+    lora_b: torch.Tensor,
+    scaling: float,
+    target_weight: nn.Parameter,
+) -> torch.Tensor:
+    """Compute the LoRA weight delta for expert weights.
+
+    Args:
+        lora_a: (E, in, r), projects the input dim down to rank.
+        lora_b: (E, r, out), projects rank up to the output dim.
+        scaling: alpha / rank.
+        target_weight: The base weight parameter whose DTensor placements to match.
+
+    Returns:
+        A delta matching target_weight's shape and placements.
+        Math: delta = scaling * B^T @ A^T, with shape (E, out, in).
+    """
+    from torch.distributed.tensor import distribute_tensor, DTensor
+
+    delta = scaling * torch.bmm(lora_b.transpose(-2, -1), lora_a.transpose(-2, -1))
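+    # Shape check: lora_b^T is (E, out, r) and lora_a^T is (E, r, in), so the
+    # batched matmul yields (E, out, in), matching the base weight layout.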
+    # When the base weight is a DTensor (TP/EP sharded), distribute the delta
+    # to match its placements so the in-place add_/sub_ operates on matching shapes.
+    if isinstance(target_weight, DTensor) and not isinstance(delta, DTensor):
+        delta = distribute_tensor(
+            delta, target_weight.device_mesh, target_weight.placements
+        )
+    return delta
+
+
+def apply_expert_lora(
+    experts: GroupedExperts, rank: int, alpha: float
+) -> GroupedExperts:
+    """Apply LoRA adapters to a GroupedExperts module via class swapping.
+
+    LoRA parameters are registered as direct parameters on the module. EP partition
+    functions that use ``named_parameters(recurse=False)`` with ``Shard(0)`` will
+    correctly shard them on the expert dimension. TP/ETP partition functions only
+    touch w1/w2/w3 by name and leave the LoRA parameters unsharded.
+
+    Forward uses merge-per-forward: LoRA deltas are merged into the base weights
+    before calling the base forward, then unmerged after. This reuses the base
+    GroupedExperts.forward without duplicating its DTensor/EP/padding logic.
+    """
+    parent_cls = type(experts)
+    assert issubclass(
+        parent_cls, GroupedExperts
+    ), f"parent_cls must be a subclass of GroupedExperts, got {parent_cls}"
+
+    if parent_cls not in _expert_lora_class_cache:
+
+        class LoRAGroupedExperts(parent_cls):  # type: ignore[valid-type, misc]
+            def __init__(self, *args: Any, **kwargs: Any) -> None:
+                raise RuntimeError(
+                    "LoRAGroupedExperts should not be instantiated directly."
+                )
+
+            @classmethod
+            def from_experts(
+                cls, experts: GroupedExperts, rank: int, alpha: float
+            ) -> "LoRAGroupedExperts":
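+                # Swap the instance's class in place (same trick as apply_lora)
+                # so existing references and already-sharded DTensor parameters
+                # on the module are preserved.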
+                experts.__class__ = cls
+                experts._init_expert_lora(rank, alpha)  # type: ignore[attr-defined]
+                return experts  # type: ignore[return-value]
+
+            def _init_expert_lora(self, rank: int, alpha: float) -> None:
+                self._lora_scaling = alpha / rank
+                num_experts = self.num_experts
+                # w1: (E, hidden_dim, dim) -> A1: (E, dim, r), B1: (E, r, hidden_dim)
+                dim_w1_in = self.w1.shape[2]  # dim
+                dim_w1_out = self.w1.shape[1]  # hidden_dim
+                # w2: (E, dim, hidden_dim) -> A2: (E, hidden_dim, r), B2: (E, r, dim)
+                dim_w2_in = self.w2.shape[2]  # hidden_dim
+                dim_w2_out = self.w2.shape[1]  # dim
+                # w3: (E, hidden_dim, dim) -> A3: (E, dim, r), B3: (E, r, hidden_dim)
+                dim_w3_in = self.w3.shape[2]  # dim
+                dim_w3_out = self.w3.shape[1]  # hidden_dim
+
+                device = self.w1.device
+                dtype = self.w1.dtype
+
+                self.lora_a_w1 = nn.Parameter(
+                    torch.empty(
+                        num_experts, dim_w1_in, rank, device=device, dtype=dtype
+                    )
+                )
+                self.lora_b_w1 = nn.Parameter(
+                    torch.empty(
+                        num_experts, rank, dim_w1_out, device=device, dtype=dtype
+                    )
+                )
+                self.lora_a_w2 = nn.Parameter(
+                    torch.empty(
+                        num_experts, dim_w2_in, rank, device=device, dtype=dtype
+                    )
+                )
+                self.lora_b_w2 = nn.Parameter(
+                    torch.empty(
+                        num_experts, rank, dim_w2_out, device=device, dtype=dtype
+                    )
+                )
+                self.lora_a_w3 = nn.Parameter(
+                    torch.empty(
+                        num_experts, dim_w3_in, rank, device=device, dtype=dtype
+                    )
+                )
+                self.lora_b_w3 = nn.Parameter(
+                    torch.empty(
+                        num_experts, rank, dim_w3_out, device=device, dtype=dtype
+                    )
+                )
+
+            def init_weights(self, init_std: float) -> None:
+                super().init_weights(init_std)
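+                # Standard LoRA init: A gets Kaiming-uniform, B starts at zero,
+                # so the initial delta B^T @ A^T is zero and the adapted experts
+                # initially match the base model exactly.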
+                for name in ("lora_a_w1", "lora_a_w2", "lora_a_w3"):
+                    nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5))
+                for name in ("lora_b_w1", "lora_b_w2", "lora_b_w3"):
+                    nn.init.zeros_(getattr(self, name))
+
+            def forward(
+                self,
+                x: torch.Tensor,
+                num_tokens_per_expert: torch.Tensor,
+            ) -> torch.Tensor:
+                # Merge LoRA deltas into base weights, run base forward, unmerge.
+                # This reuses all base GroupedExperts logic (DTensor, EP, padding).
+                deltas = {}
+                for w_name, a_name, b_name in (
+                    ("w1", "lora_a_w1", "lora_b_w1"),
+                    ("w2", "lora_a_w2", "lora_b_w2"),
+                    ("w3", "lora_a_w3", "lora_b_w3"),
+                ):
+                    lora_a = getattr(self, a_name)
+                    lora_b = getattr(self, b_name)
+                    w = getattr(self, w_name)
+                    delta = _compute_expert_lora_delta(
+                        lora_a, lora_b, self._lora_scaling, w
+                    )
+                    w.data.add_(delta)
+                    deltas[w_name] = delta
+
+                try:
+                    return super().forward(x, num_tokens_per_expert)
+                finally:
+                    # Unmerge: subtract deltas to restore original weights
+                    for w_name, delta in deltas.items():
+                        getattr(self, w_name).data.sub_(delta)
+
+        LoRAGroupedExperts.__name__ = f"LoRA{parent_cls.__name__}"
+        LoRAGroupedExperts.__qualname__ = f"LoRA{parent_cls.__name__}"
+        _expert_lora_class_cache[parent_cls] = LoRAGroupedExperts
+
+    # pyrefly: ignore [missing-attribute]
+    return _expert_lora_class_cache[parent_cls].from_experts(experts, rank, alpha)
+
+
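+# Hypothetical usage sketch (rank/alpha values are illustrative):
+#
+#     for mod in model.modules():
+#         if isinstance(mod, GroupedExperts):
+#             apply_expert_lora(mod, rank=8, alpha=16.0)
+#
+# LoRAConverter below automates this, covering Linear layers as well and
+# skipping router gates.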
 class LoRAConverter(Configurable):
-    """Apply LoRA adapters to all Linear layers in a model."""
+    """Apply LoRA adapters to all Linear layers and GroupedExperts in a model."""
 
     @dataclass(kw_only=True, slots=True)
     class Config(Configurable.Config):
@@ -122,15 +283,24 @@ def convert(self, model: nn.Module) -> None: |
         self._replace_linears_with_lora(model)
 
     def _replace_linears_with_lora(self, module: nn.Module) -> None:
+        # Collect router gate linears so we can skip them: routing scores
+        # must stay frozen to preserve expert load balancing.
+        router_gate_ids: set[int] = set()
+        for child in module.modules():
+            if isinstance(child, TokenChoiceTopKRouter):
+                router_gate_ids.add(id(child.gate))
+
         for _, child in list(module.named_modules()):
-            if isinstance(child, nn.Linear):
+            if isinstance(child, nn.Linear) and id(child) not in router_gate_ids:
                 apply_lora(child, self.rank, self.alpha)
+            elif isinstance(child, GroupedExperts):
+                apply_expert_lora(child, self.rank, self.alpha)
 
         # Expose a key filter and flag on the module so ModelWrapper can
         # partition the state dict without knowing about LoRA internals.
         def converter_key_filter(key: str) -> bool:
             """Return True if key was added by this converter (LoRA adapter weights)."""
-            return ".lora_a." in key or ".lora_b." in key
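+            # Trailing dot dropped: Linear adapters show up as submodule keys
+            # (".lora_a.weight") while expert adapters are direct parameters
+            # (".lora_a_w1"); the substring match covers both.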
+            return ".lora_a" in key or ".lora_b" in key
 
         object.__setattr__(module, "converter_key_filter", converter_key_filter)
         object.__setattr__(module, "save_converter_keys_only", self.save_adapter_only)