Skip to content

Commit 1ec8bb6

Browse files
HDCharleskylesayrsfynnsu
authored
fix match_modules_set to work with MoE (#524)
* fix match_modules_set to work with MoE Summary: match_modules_set isn't currently as useful as it could be because it lacks the ability to match multiple results for each set like in the case of a moe model where you have 128 experts. ``` [`layers.32.mlp.experts.0.gate_up_proj`, ..., `layers.32.mlp.experts.127.gate_up_proj`] ``` In order to make is so this can still work for matching simple cases and moe cases we use the following approach. 1) match modules until we have at least 1 match per target 2) when we have 1 match per target, our set is 'full' and we calculate the common parent context 3) continue matching and for each match, check if parent context would change given the new match 4) if we find a match that changes the parent context, this is the first element of the next set. yield the existing matched set and then reset, using the current match as the first element of the new set. To facilitate this algorithm i also added get_lowest_common_module_name which basically copies a similar function in llm-compressor though significantly simpler. Signed-off-by: HDCharles <[email protected]> * Update src/compressed_tensors/utils/match.py Co-authored-by: Kyle Sayers <[email protected]> Signed-off-by: HDCharles <[email protected]> * Update src/compressed_tensors/utils/match.py Co-authored-by: Kyle Sayers <[email protected]> Signed-off-by: HDCharles <[email protected]> * Update src/compressed_tensors/utils/match.py Co-authored-by: Kyle Sayers <[email protected]> Signed-off-by: HDCharles <[email protected]> * Update src/compressed_tensors/utils/match.py Co-authored-by: Kyle Sayers <[email protected]> Signed-off-by: HDCharles <[email protected]> * format fixes and bug fixes Summary Signed-off-by: HDCharles <[email protected]> * formatting and fixes Summary Signed-off-by: HDCharles <[email protected]> * formatting Summary Signed-off-by: HDCharles <[email protected]> * formatting the formatting of format Summary Signed-off-by: HDCharles <[email protected]> * making it look nice Summary Signed-off-by: HDCharles <[email protected]> * improve name to lowest_common_ancestor Summary Signed-off-by: HDCharles <[email protected]> * check for multiple matches, formatting, List typehint Summary Signed-off-by: HDCharles <[email protected]> * error instead of warn Summary Signed-off-by: HDCharles <[email protected]> * fix Summary Signed-off-by: HDCharles <[email protected]> --------- Signed-off-by: HDCharles <[email protected]> Signed-off-by: HDCharles <[email protected]> Co-authored-by: Kyle Sayers <[email protected]> Co-authored-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 73c2cf9 commit 1ec8bb6

File tree

2 files changed

+317
-47
lines changed

2 files changed

+317
-47
lines changed

src/compressed_tensors/utils/match.py

Lines changed: 164 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import os
1617
import re
18+
from collections import defaultdict
1719
from collections.abc import Generator
1820
from typing import Iterable, List, Mapping, Optional, Tuple, Union
1921

@@ -29,6 +31,7 @@
2931
"match_named_parameters",
3032
"match_targets",
3133
"match_modules_set",
34+
"get_lowest_common_ancestor_name",
3235
"is_match",
3336
"is_narrow_match",
3437
]
@@ -157,68 +160,194 @@ def match_targets(
157160
return matched_targets
158161

159162

163+
def get_lowest_common_ancestor_name(names: list[str | None]) -> str:
164+
"""
165+
Given a list of names, returns the lowest-scope common name ignoring Nones.
166+
167+
Implementation is a small alteration of os.path.commonprefix
168+
https://docs.python.org/3/library/os.path.html#os.path.commonprefix
169+
170+
([s1, s2]->prefix->result)
171+
# case 0: multiple modules: [abc.a., abc.b.] -> .abc. -> abc
172+
# case 1: single module: [abc.] -> .abc. -> abc
173+
# case 2: substring modules: [abc., ab.] -> .ab -> ""
174+
# case 3: parent & child: [ab., ab.a.] -> .ab. -> ab
175+
"""
176+
names = [name for name in names if name is not None]
177+
if len(names) == 0:
178+
return ""
179+
180+
# 1) find longest shared prefix
181+
s1 = "." + min(names) + "."
182+
s2 = "." + max(names) + "."
183+
common_prefix = os.path.commonprefix([s1, s2])
184+
# 2) throw away right most dot and name fragment, throw away leftmost char
185+
# ".keep.thro" -> "keep", "." -> ""
186+
return common_prefix[1 : common_prefix.rfind(".")]
187+
188+
160189
def match_modules_set(
161190
model: torch.nn.Module,
162191
targets: Optional[Iterable[str]],
163192
ignore: Optional[Iterable[str]] = None,
164-
) -> Generator[Iterable[torch.nn.Module]]:
193+
error_on_module_rematch: bool = True,
194+
) -> Generator[List[List[torch.nn.Module]]]:
165195
"""
166-
Yields modules grouped with the same order and size as `targets`.
167-
Values are returned in order of `model.named_modules()`
196+
Yields modules grouped by parent context.
197+
198+
We group by parent context so that we can return ALL matches of a
199+
specific target that can be paired with another target. This is most
200+
relevant in the case of MoE modules with multiple modules for each
201+
expert i.e. post_attention_layernorm <-> mlp.expert.N.gate_proj,
202+
mlp.expert.N.up_proj for all N. The parent context will differ from
203+
one layer to another while being the same for one expert to another.
168204
169-
E.g. the following targets would yield module belonging to the following layers:
205+
Each returned group is a list (of lists) with the same size
206+
and order as `targets` while all matches for each target and
207+
the overall order of the groups are ordered in the same way
208+
as `model.named_modules`
209+
210+
211+
E.g. the following targets would yield modules belonging to the following layers:
170212
```python3
171213
match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
172-
(
173-
`model.layers.0.self_attn.q_proj`,
174-
`model.layers.0.self_attn.k_proj`,
175-
`model.layers.0.self_attn.v_proj`,
176-
),
177-
(
178-
`model.layers.1.self_attn.q_proj`,
179-
`model.layers.1.self_attn.k_proj`,
180-
`model.layers.1.self_attn.v_proj`,
181-
),
214+
[
215+
[`layers.0.self_attn.q_proj`],
216+
[`layers.0.self_attn.k_proj`],
217+
[`layers.0.self_attn.v_proj`],
218+
],
219+
[
220+
[`layers.1.self_attn.q_proj`],
221+
[`layers.1.self_attn.k_proj`],
222+
[`layers.1.self_attn.v_proj`],
223+
],
182224
...
183-
(
184-
`model.layers.32.self_attn.q_proj`,
185-
`model.layers.32.self_attn.k_proj`,
186-
`model.layers.32.self_attn.v_proj`,
187-
),
188225
)
189226
```
190227
191228
This can be used to match layers to their corresponding downstream counterparts.
192229
For example, matching layer norms to their subsequent linear layers
193230
```python3
194231
for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
195-
fuse_norm_linears(norm, [q, k, v])
232+
fuse_norm_linears(*norm, [*q, *k, *v])
233+
```
234+
235+
Alternatively for MoE you would get multiple matches
236+
per target per group, E.g.
237+
238+
```python3
239+
240+
targets = [
241+
"post_attention_layernorm",
242+
"up_proj",
243+
"down_proj"
244+
]
245+
match_modules_set(model, targets) == (
246+
[
247+
[layers.0.post_attention_layernorm],
248+
[
249+
`layers.0.mlp.experts.0.up_proj`,
250+
`layers.0.mlp.experts.1.up_proj`,
251+
...
252+
],
253+
[
254+
`layers.0.mlp.experts.0.down_proj`,
255+
`layers.0.mlp.experts.1.down_proj`,
256+
...
257+
258+
]
259+
], # <- first yield
260+
[
261+
[layers.1.post_attention_layernorm],
262+
[
263+
`layers.1.mlp.experts.0.up_proj`,
264+
`layers.1.mlp.experts.1.up_proj`,
265+
...
266+
],
267+
[
268+
`layers.1.mlp.experts.0.down_proj`,
269+
`layers.1.mlp.experts.1.down_proj`,
270+
...
271+
]
272+
],
273+
...
274+
)
275+
```
196276
197277
:param model: model containing modules to match against
198278
:param targets: target strings, potentially containing "re:" prefixes
199279
:param ignore: targets to ignore, potentially containing "re:" prefixes
280+
:param error_on_module_rematch: if True, errors when a module gets
281+
matched to multiple targets, if False, no error. (Defaults to True)
200282
"""
201283
targets = targets or []
202284
ignore = ignore or []
203285

204-
matches = dict.fromkeys(targets, None)
286+
# as we iterate through modules and try to match them with targets,
287+
# the algorithm can be in 2 possible states:
288+
# 0) unmatched_targets > 0, i.e. some of the targets haven't been matched.
289+
# Keep matching until all targets have at least one match
290+
# 1) unmatched_targets == 0 i.e. we have at least one match for each target.
291+
# At this point we are unsure if we have a full set or if we need to add
292+
# more matches.
293+
# There are 3 things that can happen once were in state 1:
294+
# A) found a new match with same parent_context,
295+
# (add it to matches and keep going)
296+
# B) found a new match with different parent_context, i.e. we found a match
297+
# that requires a deeper parent context, this indicates that this match
298+
# should be part of a new set.
299+
# (yield current set [not including newest match] and go back to state 0)
300+
# C) ran out of modules, we will always yield the final remaining set when
301+
# we we've iterated through all the modules in the model.
302+
# (yield final set then exit.)
303+
# Note: its possible to iterate through all the modules in the model while
304+
# not having a full matched set if the user specified a bad matching, in
305+
# that case something has gone wrong and we error
306+
matches = defaultdict(list)
307+
parent_context = None
308+
unmatched_targets = set(targets)
309+
205310
for name, module in model.named_modules():
206-
# match until we get a full set
311+
matched_targets_for_cur_module = set()
207312
for target in targets:
208313
if is_match(name, module, target, ignore):
209-
if matches[target] is not None:
210-
raise ValueError(f"Matched a {target} twice before completing set")
211-
matches[target] = module
212-
213-
# once we have a full set, yield and reset
214-
if targets and all((matches[target] is not None for target in targets)):
215-
yield [matches[target] for target in targets] # ensure correct ordering
216-
matches = dict.fromkeys(targets, None)
217-
218-
# check that none are left over
219-
unmatched_keys = [match for match, value in matches.items() if value is not None]
220-
if len(unmatched_keys):
221-
raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
314+
new_parent_context = get_lowest_common_ancestor_name(
315+
[name, parent_context]
316+
)
317+
318+
# code for (B)
319+
if not unmatched_targets and new_parent_context != parent_context:
320+
yield [matches[target] for target in targets]
321+
matches = defaultdict(list)
322+
new_parent_context = name
323+
unmatched_targets = set(targets)
324+
325+
matches[target].append(module)
326+
parent_context = new_parent_context
327+
unmatched_targets -= {target}
328+
matched_targets_for_cur_module |= {target}
329+
330+
if len(matched_targets_for_cur_module) > 1 and error_on_module_rematch:
331+
raise ValueError(
332+
f"module: {name} was matched with multiple targets: "
333+
f"{matched_targets_for_cur_module} which is unexpected "
334+
"disable this check by setting `error_on_module_rematch = False`"
335+
)
336+
337+
# never found anything
338+
if unmatched_targets == set(targets):
339+
return
340+
341+
# code for (C)
342+
if not unmatched_targets: # have a full matching
343+
yield [matches[target] for target in targets]
344+
return
345+
346+
raise ValueError(
347+
f"Found a final incomplete set with matches found for keys: "
348+
f"{set(targets) - unmatched_targets} "
349+
f"but no matches found for keys: {unmatched_targets}"
350+
)
222351

223352

224353
def is_match(

0 commit comments

Comments
 (0)