|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import logging |
| 16 | +import os |
16 | 17 | import re |
| 18 | +from collections import defaultdict |
17 | 19 | from collections.abc import Generator |
18 | 20 | from typing import Iterable, List, Mapping, Optional, Tuple, Union |
19 | 21 |
|
|
29 | 31 | "match_named_parameters", |
30 | 32 | "match_targets", |
31 | 33 | "match_modules_set", |
| 34 | + "get_lowest_common_ancestor_name", |
32 | 35 | "is_match", |
33 | 36 | "is_narrow_match", |
34 | 37 | ] |
@@ -157,68 +160,194 @@ def match_targets( |
157 | 160 | return matched_targets |
158 | 161 |
|
159 | 162 |
|
def get_lowest_common_ancestor_name(names: List[Optional[str]]) -> str:
    """
    Given a list of dotted names, returns the name of their lowest common
    ancestor, ignoring any `None` entries.

    Implementation is a small alteration of os.path.commonprefix
    https://docs.python.org/3/library/os.path.html#os.path.commonprefix

    Every name is wrapped in "." delimiters before the character-wise common
    prefix is taken, so that only whole name fragments can be shared:

    (names -> wrapped common prefix -> result)
    # case 0: multiple modules:  [abc.a, abc.b] -> .abc. -> abc
    # case 1: single module:     [abc]          -> .abc. -> abc
    # case 2: substring modules: [abc, ab]      -> .ab   -> ""
    # case 3: parent & child:    [ab, ab.a]     -> .ab.  -> ab

    :param names: dotted names to compare; `None` entries are skipped
    :return: the longest dotted name that is an ancestor of (or equal to)
        every non-None name, or "" if there is none or no names were given
    """
    # Wrap *before* comparing: the trailing "." can change the lexicographic
    # order of names containing characters that sort below "." (e.g. "-"),
    # so the prefix must be computed over the wrapped strings themselves
    wrapped_names = ["." + name + "." for name in names if name is not None]
    if not wrapped_names:
        return ""

    # 1) find longest shared prefix across all wrapped names
    common_prefix = os.path.commonprefix(wrapped_names)

    # 2) throw away right most dot and name fragment, throw away leftmost char
    # ".keep.thro" -> "keep", "." -> ""
    return common_prefix[1 : common_prefix.rfind(".")]
| 187 | + |
| 188 | + |
def match_modules_set(
    model: torch.nn.Module,
    targets: Optional[Iterable[str]],
    ignore: Optional[Iterable[str]] = None,
    error_on_module_rematch: bool = True,
) -> Generator[List[List[torch.nn.Module]]]:
    """
    Yields modules grouped by parent context.

    We group by parent context so that we can return ALL matches of a
    specific target that can be paired with another target. This is most
    relevant in the case of MoE modules with multiple modules for each
    expert i.e. post_attention_layernorm <-> mlp.expert.N.gate_proj,
    mlp.expert.N.up_proj for all N. The parent context will differ from
    one layer to another while being the same for one expert to another.

    Each yielded group is a list (of lists) with the same size and order
    as `targets`. The matches within each target's list, and the groups
    themselves, are ordered in the same way as `model.named_modules`.

    E.g. the following targets would yield modules belonging to the following layers:
    ```python3
    match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
        [
            [`layers.0.self_attn.q_proj`],
            [`layers.0.self_attn.k_proj`],
            [`layers.0.self_attn.v_proj`],
        ],
        [
            [`layers.1.self_attn.q_proj`],
            [`layers.1.self_attn.k_proj`],
            [`layers.1.self_attn.v_proj`],
        ],
        ...
    )
    ```

    This can be used to match layers to their corresponding downstream counterparts.
    For example, matching layer norms to their subsequent linear layers
    ```python3
    for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
        fuse_norm_linears(*norm, [*q, *k, *v])
    ```

    Alternatively for MoE you would get multiple matches
    per target per group, E.g.

    ```python3

    targets = [
        "post_attention_layernorm",
        "up_proj",
        "down_proj"
    ]
    match_modules_set(model, targets) == (
        [
            [`layers.0.post_attention_layernorm`],
            [
                `layers.0.mlp.experts.0.up_proj`,
                `layers.0.mlp.experts.1.up_proj`,
                ...
            ],
            [
                `layers.0.mlp.experts.0.down_proj`,
                `layers.0.mlp.experts.1.down_proj`,
                ...
            ]
        ], # <- first yield
        [
            [`layers.1.post_attention_layernorm`],
            [
                `layers.1.mlp.experts.0.up_proj`,
                `layers.1.mlp.experts.1.up_proj`,
                ...
            ],
            [
                `layers.1.mlp.experts.0.down_proj`,
                `layers.1.mlp.experts.1.down_proj`,
                ...
            ]
        ],
        ...
    )
    ```

    :param model: model containing modules to match against
    :param targets: target strings, potentially containing "re:" prefixes
    :param ignore: targets to ignore, potentially containing "re:" prefixes
    :param error_on_module_rematch: if True, errors when a module gets
      matched to multiple targets, if False, no error. (Defaults to True)
    :raises ValueError: if a single module matches more than one target while
      `error_on_module_rematch` is True, or if the modules are exhausted
      while the final set has matches for only some of the targets
    """
    # Materialize both iterables up front: `targets` is traversed once per
    # module (and again to build sets) and `ignore` is handed to `is_match`
    # for every (module, target) pair, so a one-shot iterable such as a
    # generator would otherwise be exhausted after its first use
    targets = list(targets or [])
    ignore = list(ignore or [])

    # as we iterate through modules and try to match them with targets,
    # the algorithm can be in 2 possible states:
    # 0) unmatched_targets > 0, i.e. some of the targets haven't been matched.
    #   Keep matching until all targets have at least one match
    # 1) unmatched_targets == 0 i.e. we have at least one match for each target.
    #   At this point we are unsure if we have a full set or if we need to add
    #   more matches.
    # There are 3 things that can happen once we're in state 1:
    # A) found a new match with same parent_context,
    #   (add it to matches and keep going)
    # B) found a new match with different parent_context, i.e. we found a match
    #   that requires a deeper parent context, this indicates that this match
    #   should be part of a new set.
    #   (yield current set [not including newest match] and go back to state 0)
    # C) ran out of modules, we will always yield the final remaining set when
    #   we've iterated through all the modules in the model.
    #   (yield final set then exit.)
    # Note: it's possible to iterate through all the modules in the model while
    # not having a full matched set if the user specified a bad matching, in
    # that case something has gone wrong and we error
    matches = defaultdict(list)
    parent_context = None
    unmatched_targets = set(targets)

    for name, module in model.named_modules():
        matched_targets_for_cur_module = set()
        for target in targets:
            if not is_match(name, module, target, ignore):
                continue

            # context the current set would have if it absorbed this module
            new_parent_context = get_lowest_common_ancestor_name(
                [name, parent_context]
            )

            # code for (B)
            if not unmatched_targets and new_parent_context != parent_context:
                yield [matches[tgt] for tgt in targets]
                # start a fresh set whose context is, so far, this module alone
                matches = defaultdict(list)
                new_parent_context = name
                unmatched_targets = set(targets)

            matches[target].append(module)
            parent_context = new_parent_context
            unmatched_targets.discard(target)
            matched_targets_for_cur_module.add(target)

        if len(matched_targets_for_cur_module) > 1 and error_on_module_rematch:
            raise ValueError(
                f"module: {name} was matched with multiple targets: "
                f"{matched_targets_for_cur_module} which is unexpected "
                "disable this check by setting `error_on_module_rematch = False`"
            )

    # never found anything
    if unmatched_targets == set(targets):
        return

    # code for (C)
    if not unmatched_targets:  # have a full matching
        yield [matches[tgt] for tgt in targets]
        return

    raise ValueError(
        f"Found a final incomplete set with matches found for keys: "
        f"{set(targets) - unmatched_targets} "
        f"but no matches found for keys: {unmatched_targets}"
    )
222 | 351 |
|
223 | 352 |
|
224 | 353 | def is_match( |
|
0 commit comments